Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/relax/transform/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1282,7 +1282,14 @@ IRModule FuseOpsByPattern(const tvm::Array<transform::FusionPattern>& patterns,
pattern->annotation_patterns,
pattern->check.value_or(nullptr), entry.second,
&arena, pattern->attrs_getter.value_or(nullptr));
group_map.insert(map.begin(), map.end());
for (const auto& [key, value] : map) {
CHECK(!group_map.count(key))
<< "ValueError: "
<< "IRModule is invalid. "
<< "The object " << GetRef<ObjectRef>(key) << " appears in multiple partitions, "
<< "which can occur when the IRModule was not single-site assignment";
group_map.insert({key, value});
}
}
mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ !bind_constants);
}
Expand Down
21 changes: 21 additions & 0 deletions tests/python/relax/test_transform_fuse_ops_by_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,5 +1055,26 @@ def test_multple_runs():
)


@pytest.mark.skip_well_formed_check_before_transform
def test_error_on_repeated_variable_definitions():
"""Raise error for SSA violations

Internally, `FuseOpsByPattern` makes a mapping from relax
variables to the fused group containing that variable. If the
input module violates SSA, this map may be ill-formed.

While not strictly necessary for FuseOps to handle ill-formed
inputs, checking it at this level provides better error handling
than propagating it to downstream passes.
"""
mod = Conv2dReLU.clone()
mod["copy"] = mod["main"].with_attr("global_symbol", "copy")

patterns = [("dnnl.conv2d_relu", conv2d_relu_pat)]

with pytest.raises(ValueError):
relax.transform.FuseOpsByPattern(patterns)(mod)


if __name__ == "__main__":
pytest.main([__file__])