From 08f91b9431573199126375fe1230b4c57f9eb8c2 Mon Sep 17 00:00:00 2001 From: Huibin Wang Date: Thu, 9 May 2024 15:56:41 +0800 Subject: [PATCH 1/2] [BugFix][Relay] skip leaf args when matching 'path' part for dominator pattern --- src/relay/ir/dataflow_matcher.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 8e756a8aa2d3..0c0ff7290115 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -300,11 +300,17 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex // Recursively find the Dominator parent along all inputs paths. bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) { + // utilities + auto is_leaf_node = [](const Expr& expr) { + return expr.as() || expr.as(); + }; + + // logic auto call_node = expr.as(); auto index_node = expr_to_node(expr); size_t arg_counter{0}; for (auto node : index_node->inputs_) { - if (!(call_node && node->ref() == call_node->op)) { + if (!(call_node && (node->ref() == call_node->op || is_leaf_node(node->ref())))) { arg_counter += 1; memoize_ = true; if (!VisitDFPattern(op->parent, node->ref())) { From 26856ad45b4d27f10f8c3ec5d1c02a240b28ffd6 Mon Sep 17 00:00:00 2001 From: Huibin Wang Date: Sat, 11 May 2024 19:11:56 +0800 Subject: [PATCH 2/2] add testcase --- tests/python/relay/test_dataflow_pattern.py | 24 ++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 3950c02c08a4..d3e6874721ec 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -28,7 +28,7 @@ # convention. K_ELEMWISE = 0 K_BROADCAST = 1 - +K_INJECTIVE = 2 ## NODE TESTS def test_expr_pattern(): @@ -696,6 +696,28 @@ def test_match_dominator(): assert diamond.match(out) +def test_match_dominator2(): + # Pattern + conv2d_pat = is_op("nn.conv2d")(wildcard(), wildcard()) + eltwise_pat = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(None) + broadcast_pat = (wildcard().has_attr({"TOpPattern": K_BROADCAST}))(None) + path_pat = eltwise_pat | broadcast_pat + injective_pat = (wildcard().has_attr({"TOpPattern": K_INJECTIVE}))(wildcard()) + pattern = injective_pat.dominates(conv2d_pat, path_pat) + + # Graph + inp = relay.var("input") + weight = relay.var("weight") + bias = relay.var("bias") + conv2d = relay.op.nn.conv2d(inp, weight) + bias_add = relay.op.nn.bias_add(conv2d, bias) + relu = relay.op.nn.relu(bias_add) + reshape = relay.op.reshape(relu, newshape=[-1, 2, 8]) + + # Check + assert pattern.match(reshape) + + def test_not_match_dominator(): is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()) is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())