diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index 9aeb289e2ae9..6b88446893cf 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -91,18 +91,21 @@ class CanonicalizePlanner : public ExprVisitor { bound_to = opt.value(); } - if (bound_var.as() || !bound_to.as()) { + if (bound_var.as() || !bound_to.as() || + !visitor.used_outside_home_dataflow_.count(bound_var)) { // Case 1: Var = Var // Case 2: DataflowVar = Var // Case 3: DataflowVar = DataflowVar + // Case 4a: Var = DataflowVar, where the Var is not used + // outside the DataflowBlock containing the binding // - // For these three cases, the trivial binding can be - // unwrapped, using the bound variable directly at the point - // of use. + // For these four cases, the trivial binding can be unwrapped, + // using the bound variable directly at the point of use. plan.replace_usage.Set(bound_var->vid, bound_to); plan.bindings_to_remove.insert(bound_var->vid); } else { - // Case 4: Var = DataflowVar + // Case 4b: Var = DataflowVar, where the Var is used somewhere + // outside the DataflowBlock containing the binding // // Replacing a Var with a DataflowVar could result in illegal // use of a DataflowVar outside of a DataflowBlock. Instead, diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index 7d7b74bf5961..d513c0cf6c6d 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -977,5 +977,39 @@ def main(): verify(TestChainAssignments, Expected) +def test_trivial_binding_of_replaced_non_dataflow_var(): + @I.ir_module + class Before: + @R.function + def main(param_tuple: R.Tuple([R.Tensor])): + with R.dataflow(): + A = param_tuple[0] + B = A + C = R.add(A, B) + R.output(A, B, C) + return C + + @I.ir_module + class Expected: + @R.function + def main(param_tuple: R.Tuple([R.Tensor])): + with R.dataflow(): + A = param_tuple[0] + C = R.add(A, A) + R.output(C) + return C + + After = CanonicalizeBindings()(Before) + tvm.ir.assert_structural_equal(After, Expected) + + def _get_binding_names(mod): + return [binding.var.name_hint for binding in mod["main"].body.blocks[0].bindings] + + expected_names = _get_binding_names(Expected) + after_names = _get_binding_names(After) + + assert after_names == expected_names + + if __name__ == "__main__": tvm.testing.main()