From 47adfd632e2d549e539c45af4a73c1ab75e37d3e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 17 Jan 2024 17:06:33 +0000 Subject: [PATCH] [Unity][Transform] Raise error in FuseOpsByPattern for SSA violation Internally, `FuseOpsByPattern` makes a mapping from relax variables to the fused group containing that variable. If the input module violates SSA, this map may be ill-formed. While not strictly necessary for FuseOps to handle ill-formed inputs, checking it at this level provides better error handling than propagating it to downstream passes. This commit checks for ill-formed inputs that would produce invalid fused outputs and raises an error. --- src/relax/transform/fuse_ops.cc | 9 +++++++- .../test_transform_fuse_ops_by_pattern.py | 21 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 0dbee3667061..db586443b278 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -1282,7 +1282,14 @@ IRModule FuseOpsByPattern(const tvm::Array& patterns, pattern->annotation_patterns, pattern->check.value_or(nullptr), entry.second, &arena, pattern->attrs_getter.value_or(nullptr)); - group_map.insert(map.begin(), map.end()); + for (const auto& [key, value] : map) { + CHECK(!group_map.count(key)) + << "ValueError: " + << "IRModule is invalid. " + << "The object " << GetRef(key) << " appears in multiple partitions, " + << "which can occur when the IRModule was not single-site assignment"; + group_map.insert({key, value}); + } } mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ !bind_constants); } diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py index de356fd5480e..90b713b4f348 100644 --- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py +++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py @@ -1055,5 +1055,26 @@ def test_multple_runs(): ) +@pytest.mark.skip_well_formed_check_before_transform +def test_error_on_repeated_variable_definitions(): + """Raise error for SSA violations + + Internally, `FuseOpsByPattern` makes a mapping from relax + variables to the fused group containing that variable. If the + input module violates SSA, this map may be ill-formed. + + While not strictly necessary for FuseOps to handle ill-formed + inputs, checking it at this level provides better error handling + than propagating it to downstream passes. + """ + mod = Conv2dReLU.clone() + mod["copy"] = mod["main"].with_attr("global_symbol", "copy") + + patterns = [("dnnl.conv2d_relu", conv2d_relu_pat)] + + with pytest.raises(ValueError): + relax.transform.FuseOpsByPattern(patterns)(mod) + + if __name__ == "__main__": pytest.main([__file__])