From 685cc7605f2d8b380d88b650c2f46b4e113d8b42 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 1 Apr 2024 07:10:39 -0500 Subject: [PATCH 1/2] [Relax][Bugfix] Provide the full Expr to pattern-match rewriter This resolves a bug that was introduced in https://github.com/apache/tvm/pull/16732. If a rewriter function returned a no-op, and the pattern-match continued, then the `matches` provided to the rewriter function in subsequent calls would contain a variable to which the matched expression was bound, not the matched expression itself. (e.g. For a match of `C = R.add(A,B)`, passing `C` to the rewriter instead of `R.add(A,B)`.) This bug was caused by incorrect re-wrapping of `OrPattern` in `ExprPatternRewriter`. Prior to https://github.com/apache/tvm/pull/16732, all pattern-match results were populated by `ExtractMatchExpr`, and contained the result after applying `TryGetValOfVar`. When re-wrapping the result of an `OrPattern`, https://github.com/apache/tvm/pull/16732 populated the additional matches with the result before applying `TryGetValOfVar`. This commit fixes the bug by applying `TryGetValOfVar`. --- src/relax/ir/dataflow_matcher.cc | 13 ++++++-- tests/python/relax/test_dataflow_pattern.py | 33 +++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index db70ef6a9cec..cf8934c372e2 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -1190,8 +1190,17 @@ class ExprPatternRewriter : ExprMutator { if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) { auto matches = opt_matches.value(); - for (const auto& pat : *matches_top_level) { - matches.Set(pat, expr); + + // Append any additional matches that from the unwrapped + // `OrPattern`. When matching against `pat = pat_lhs | + // pat_rhs`, we call `ExtractMatchedExpr` on `pat_lhs` and + // `pat_rhs` separately. The top-level `pat` is never seen by + // `ExtractMatchedExpr`, and must be re-added afterward. + if (matches_top_level->size()) { + auto matched_expr = TryGetValOfVar(expr, bindings_); + for (const auto& pat : *matches_top_level) { + matches.Set(pat, matched_expr); + } } Expr rewritten_expr = rewriter_func_(expr, matches); diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 81cd8da7fe71..590160fc5e87 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -1952,5 +1952,38 @@ def expected(): tvm.ir.assert_structural_equal(expected, after) +def test_backtrack_for_no_op_rewriter_does_not_match_on_var(): + """The matches should always contain the bound value + + This is a regression test. In versions from + https://github.com/apache/tvm/pull/16732 to + https://github.com/apache/tvm/pull/#####, the `rewrite_call` + function could erroneously call the rewriter with `expr` and + `matches[pat]` set to a variable (`C`) instead of the value to + which it is bound (`R.add(A,B)`). + """ + pat_a = is_op("relax.add")(wildcard(), wildcard()) + pat_b = is_op("relax.add")(wildcard(), wildcard()) + pat = pat_a | pat_b + + def rewriter(expr, matches): + assert isinstance(matches[pat], rx.Call) + return expr + + @R.function(private=True) + def before(): + with R.dataflow(): + A = R.ones([64, 128], "int32") + B = R.zeros([64, 128], "int32") + C = R.add(A, B) + + R.output(C) + return C + + expected = before + after = rewrite_call(pat, rewriter, before) + tvm.ir.assert_structural_equal(expected, after) + + if __name__ == "__main__": tvm.testing.main() From 65ba7b88d6bcee9e52c7b608c9b17674c661ab50 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 1 Apr 2024 08:25:57 -0500 Subject: [PATCH 2/2] Update with PR link of bugfix --- tests/python/relax/test_dataflow_pattern.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 590160fc5e87..24c36d20dc18 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -1957,7 +1957,7 @@ def test_backtrack_for_no_op_rewriter_does_not_match_on_var(): This is a regression test. In versions from https://github.com/apache/tvm/pull/16732 to - https://github.com/apache/tvm/pull/#####, the `rewrite_call` + https://github.com/apache/tvm/pull/16828, the `rewrite_call` function could erroneously call the rewriter with `expr` and `matches[pat]` set to a variable (`C`) instead of the value to which it is bound (`R.add(A,B)`).