[Bugfix] [Relay] fix a bug of printing dataflow pattern#15350
[Bugfix] [Relay] fix a bug of printing dataflow pattern#15350masahi merged 3 commits intoapache:mainfrom
Conversation
|
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment. Generated by tvm-bot |
|
btw, does the dataflow pattern support recursion (loop) at all? |
|
|
||
| std::unordered_map<DFPattern, std::pair<size_t, std::string>, ObjectPtrHash, ObjectPtrEqual> | ||
| memo_{}; | ||
| std::vector<DFPattern> recursed_patterns{}; |
There was a problem hiding this comment.
Please document what you mean by "recursed_patterns". Probably "recursive_patterns" would be more correct.
There was a problem hiding this comment.
Documentation added. To be consistent with the printed text, I choose to name this variable to be auxiliary_patterns
| string_stream << "Main pattern is:" << std::endl; | ||
| string_stream << printer.string_stream.str(); | ||
| string_stream << std::endl; | ||
| string_stream << "Auxiliary patterns are:"; |
Surprisingly, the answer is yes. Indeed, I was thinking to add the support for recursion, only to find that the existing dataflow pattern language has alreadly (silently) support it. TVM_REGISTER_GLOBAL("relay.dataflow_pattern.my_pattern")
.set_body_typed([]() {
DFPattern dense_pattern = IsOp("nn.dense")({IsWildcard(), IsWildcard()});
ObjectPtr<CallPatternNode> the_pattern_ptr = make_object<CallPatternNode>();
the_pattern_ptr->op = IsOp("cast");
the_pattern_ptr->args.clear();
CallPattern the_pattern = CallPattern(the_pattern_ptr);
AltPattern or_pattern{the_pattern, dense_pattern};
the_pattern_ptr->args.push_back(or_pattern);
//LOG(INFO) << PrettyPrint(the_pattern); // TODO: BUG!
return the_pattern;
});This simple pattern matches a nn.dense followed by an arbitrary number of cast. You can test this pattern via the following python code: class TheRewrite(DFPatternCallback):
def __init__(self):
super(TheRewrite, self).__init__(rewrite_once = True)
pattern = tvm.get_global_func("relay.dataflow_pattern.my_pattern")()
self.pattern = pattern
def callback(self, pre, post, node_map):
return relay.nn.relu(post)
mod = create_model() # define a model
the_rewrite = TheRewrite()
out = rewrite(the_rewrite, mod["main"])Another application of recursion is PR #15362, which I do not know how to achieve without recursion. That PR is useful, and can really improve the computational graph for some quantized models. I would like to examine the pattern matching code further in the following days. |
|
@tvm-bot rerun |
When recursion of dataflow patterns is used, the pattern graph may not be a DAG. If recursion is encountered, ReprPrint of dataflow pattern may fall in a dead loop.
This PR solves the bug of dataflow pattern printing, and is the first PR for the pre-RFC.