diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index db70ef6a9cec..cf8934c372e2 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -1190,8 +1190,17 @@ class ExprPatternRewriter : ExprMutator { if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) { auto matches = opt_matches.value(); - for (const auto& pat : *matches_top_level) { - matches.Set(pat, expr); + + // Append any additional matches that from the unwrapped + // `OrPattern`. When matching against `pat = pat_lhs | + // pat_rhs`, we call `ExtractMatchedExpr` on `pat_lhs` and + // `pat_rhs` separately. The top-level `pat` is never seen by + // `ExtractMatchedExpr`, and must be re-added afterward. + if (matches_top_level->size()) { + auto matched_expr = TryGetValOfVar(expr, bindings_); + for (const auto& pat : *matches_top_level) { + matches.Set(pat, matched_expr); + } } Expr rewritten_expr = rewriter_func_(expr, matches); diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 81cd8da7fe71..24c36d20dc18 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -1952,5 +1952,38 @@ def expected(): tvm.ir.assert_structural_equal(expected, after) +def test_backtrack_for_no_op_rewriter_does_not_match_on_var(): + """The matches should always contain the bound value + + This is a regression test. In versions from + https://github.com/apache/tvm/pull/16732 to + https://github.com/apache/tvm/pull/16828, the `rewrite_call` + function could erroneously call the rewriter with `expr` and + `matches[pat]` set to a variable (`C`) instead of the value to + which it is bound (`R.add(A,B)`). + """ + pat_a = is_op("relax.add")(wildcard(), wildcard()) + pat_b = is_op("relax.add")(wildcard(), wildcard()) + pat = pat_a | pat_b + + def rewriter(expr, matches): + assert isinstance(matches[pat], rx.Call) + return expr + + @R.function(private=True) + def before(): + with R.dataflow(): + A = R.ones([64, 128], "int32") + B = R.zeros([64, 128], "int32") + C = R.add(A, B) + + R.output(C) + return C + + expected = before + after = rewrite_call(pat, rewriter, before) + tvm.ir.assert_structural_equal(expected, after) + + if __name__ == "__main__": tvm.testing.main()