@@ -33,8 +33,12 @@ namespace tvm {
3333namespace relay {
3434namespace annotate_target {
3535
36- const PackedFunc* begin_op = runtime::Registry::Get(" relay.op.annotation._make.compiler_begin" );
37- const PackedFunc* end_op = runtime::Registry::Get(" relay.op.annotation._make.compiler_end" );
36+ static const Op& compiler_begin_op = Op::Get(" annotation.compiler_begin" );
37+ static const Op& compiler_end_op = Op::Get(" annotation.compiler_end" );
38+
39+ const PackedFunc* make_begin_op =
40+ runtime::Registry::Get (" relay.op.annotation._make.compiler_begin" );
41+ const PackedFunc* make_end_op = runtime::Registry::Get(" relay.op.annotation._make.compiler_end" );
3842
3943// A helper class to insert annotation boundaries for a program region that will
4044// be handled by a specific compiler.
@@ -59,26 +63,40 @@ class AnnotateTargetWrapper : public ExprMutator {
5963 std::string ref_target = " " ;
6064 Array<Expr> compiler_ends;
6165 for (auto arg : args) {
62- if (op_expr_to_target_.find (arg) != op_expr_to_target_.end ()) {
63- std::string arg_target = op_expr_to_target_[arg];
64- compiler_ends.push_back (InsertAnnotation (arg, arg_target, end_op));
65- if (ref_target == " " ) {
66- ref_target = arg_target;
67- } else if (ref_target != arg_target) {
68- ref_target = " default" ;
66+ std::string arg_target = " defualt" ;
67+ const CallNode* call = arg.as <CallNode>();
68+
69+ if (call && call->op == compiler_begin_op) {
70+ // Argument is already compiler begin node meaning that this is not the first time
71+ // running this pass, so we simply remove it and will add a new one later.
72+ CHECK_EQ (call->args .size (), 1U );
73+ const CallNode* end = call->args [0 ].as <CallNode>();
74+ if (end->op == compiler_end_op) {
75+ arg_target = end->attrs .as <CompilerAttrs>()->compiler ;
6976 }
77+ compiler_ends.push_back (call->args [0 ]);
78+ } else if (op_expr_to_target_.find (arg) != op_expr_to_target_.end ()) {
79+ arg_target = op_expr_to_target_[arg];
80+ compiler_ends.push_back (InsertAnnotation (arg, arg_target, make_end_op));
7081 } else {
7182 // Input vars.
7283 compiler_ends.push_back (arg);
7384 }
85+
86+ // Maintain reference target in case the target of the current node is unassigned.
87+ if (ref_target == " " ) {
88+ ref_target = arg_target;
89+ } else if (ref_target != arg_target) {
90+ ref_target = " default" ;
91+ }
7492 }
7593
7694 // Determine compiler begin target.
7795 std::string op_target = (target == " " ) ? ref_target : target;
7896
7997 Array<Expr> compiler_begins;
8098 for (const auto & end : compiler_ends) {
81- compiler_begins.push_back (InsertAnnotation (end, op_target, begin_op ));
99+ compiler_begins.push_back (InsertAnnotation (end, op_target, make_begin_op ));
82100 }
83101
84102 return {op_target, compiler_begins};
@@ -87,15 +105,41 @@ class AnnotateTargetWrapper : public ExprMutator {
87105 Expr InsertAnnotation (const Expr& expr, const std::string& target, const PackedFunc* ann_op) {
88106 Expr new_op = (*ann_op)(expr, target);
89107 new_op->checked_type_ = expr->checked_type_ ;
90- return new_op;
108+ return std::move ( new_op) ;
91109 }
92110
93111 Expr VisitExpr_ (const CallNode* cn) final {
94112 // Supported targets for this node. The order implies the priority.
95113 std::vector<std::string> supported_targets;
96114
115+ auto op_node = cn->op .as <OpNode>();
116+
117+ // This graph has annotations, meaning that this is not the first time running this pass.
118+ if (op_node && cn->op == compiler_begin_op) {
119+ // Bypass compiler begin due to lack of target information. It will be processed
120+ // when the following op handling arguments.
121+ CHECK_EQ (cn->args .size (), 1U );
122+ return VisitExpr (cn->args [0 ]);
123+ } else if (op_node && cn->op == compiler_end_op) {
124+ // Override compiler end with the new target.
125+ CHECK_EQ (cn->args .size (), 1U );
126+ auto input_expr = VisitExpr (cn->args [0 ]);
127+ CHECK (op_expr_to_target_.find (input_expr) != op_expr_to_target_.end ());
128+ return InsertAnnotation (input_expr, op_expr_to_target_[input_expr], make_end_op);
129+ }
130+
131+ // Peek the first argument. If it is compiler begin then this node had annotated by
132+ // another target before, so we also consider that target as a supported target.
133+ const CallNode* first_arg_call = cn->args [0 ].as <CallNode>();
134+ if (first_arg_call && first_arg_call->op == compiler_begin_op) {
135+ std::string arg_target = first_arg_call->attrs .as <CompilerAttrs>()->compiler ;
136+ if (arg_target != " default" ) {
137+ supported_targets.push_back (arg_target);
138+ }
139+ }
140+
97141 // Check which targets this op can be offloaded.
98- if (cn-> op -> IsInstance <OpNode>() ) {
142+ if (op_node ) {
99143 // TVM operators: Check target specific op checking function and add to supported_targets
100144 // if it is supported.
101145 Op op = Downcast<Op>(cn->op );
@@ -179,7 +223,7 @@ class AnnotateTargetWrapper : public ExprMutator {
179223 func = Downcast<Function>(new_e);
180224 new_body = func->body ;
181225 if (op_expr_to_target_.find (func->body ) != op_expr_to_target_.end ()) {
182- new_body = InsertAnnotation (func->body , op_expr_to_target_[func->body ], end_op );
226+ new_body = InsertAnnotation (func->body , op_expr_to_target_[func->body ], make_end_op );
183227 op_expr_to_target_[new_body] = op_expr_to_target_[func->body ];
184228 }
185229 }
0 commit comments