Skip to content

Commit de9e85f

Browse files
committed
support AnnotateTarget multiple runs
1 parent dbd6301 commit de9e85f

2 files changed

Lines changed: 84 additions & 13 deletions

File tree

src/relay/transforms/annotate_target.cc

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,12 @@ namespace tvm {
3333
namespace relay {
3434
namespace annotate_target {
3535

36-
const PackedFunc* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
37-
const PackedFunc* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
36+
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
37+
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
38+
39+
const PackedFunc* make_begin_op =
40+
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
41+
const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
3842

3943
// A helper class to insert annotation boundaries for a program region that will
4044
// be handled by a specific compiler.
@@ -59,26 +63,40 @@ class AnnotateTargetWrapper : public ExprMutator {
5963
std::string ref_target = "";
6064
Array<Expr> compiler_ends;
6165
for (auto arg : args) {
62-
if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
63-
std::string arg_target = op_expr_to_target_[arg];
64-
compiler_ends.push_back(InsertAnnotation(arg, arg_target, end_op));
65-
if (ref_target == "") {
66-
ref_target = arg_target;
67-
} else if (ref_target != arg_target) {
68-
ref_target = "default";
66+
std::string arg_target = "defualt";
67+
const CallNode* call = arg.as<CallNode>();
68+
69+
if (call && call->op == compiler_begin_op) {
70+
// Argument is already compiler begin node meaning that this is not the first time
71+
// running this pass, so we simply remove it and will add a new one later.
72+
CHECK_EQ(call->args.size(), 1U);
73+
const CallNode* end = call->args[0].as<CallNode>();
74+
if (end->op == compiler_end_op) {
75+
arg_target = end->attrs.as<CompilerAttrs>()->compiler;
6976
}
77+
compiler_ends.push_back(call->args[0]);
78+
} else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
79+
arg_target = op_expr_to_target_[arg];
80+
compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
7081
} else {
7182
// Input vars.
7283
compiler_ends.push_back(arg);
7384
}
85+
86+
// Maintain reference target in case the target of the current node is unassigned.
87+
if (ref_target == "") {
88+
ref_target = arg_target;
89+
} else if (ref_target != arg_target) {
90+
ref_target = "default";
91+
}
7492
}
7593

7694
// Determine compiler begin target.
7795
std::string op_target = (target == "") ? ref_target : target;
7896

7997
Array<Expr> compiler_begins;
8098
for (const auto& end : compiler_ends) {
81-
compiler_begins.push_back(InsertAnnotation(end, op_target, begin_op));
99+
compiler_begins.push_back(InsertAnnotation(end, op_target, make_begin_op));
82100
}
83101

84102
return {op_target, compiler_begins};
@@ -87,15 +105,41 @@ class AnnotateTargetWrapper : public ExprMutator {
87105
Expr InsertAnnotation(const Expr& expr, const std::string& target, const PackedFunc* ann_op) {
88106
Expr new_op = (*ann_op)(expr, target);
89107
new_op->checked_type_ = expr->checked_type_;
90-
return new_op;
108+
return std::move(new_op);
91109
}
92110

93111
Expr VisitExpr_(const CallNode* cn) final {
94112
// Supported targets for this node. The order implies the priority.
95113
std::vector<std::string> supported_targets;
96114

115+
auto op_node = cn->op.as<OpNode>();
116+
117+
// This graph has annotations, meaning that this is not the first time running this pass.
118+
if (op_node && cn->op == compiler_begin_op) {
119+
// Bypass compiler begin due to lack of target information. It will be processed
120+
// when the following op handling arguments.
121+
CHECK_EQ(cn->args.size(), 1U);
122+
return VisitExpr(cn->args[0]);
123+
} else if (op_node && cn->op == compiler_end_op) {
124+
// Override compiler end with the new target.
125+
CHECK_EQ(cn->args.size(), 1U);
126+
auto input_expr = VisitExpr(cn->args[0]);
127+
CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end());
128+
return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op);
129+
}
130+
131+
// Peek the first argument. If it is compiler begin then this node had annotated by
132+
// another target before, so we also consider that target as a supported target.
133+
const CallNode* first_arg_call = cn->args[0].as<CallNode>();
134+
if (first_arg_call && first_arg_call->op == compiler_begin_op) {
135+
std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
136+
if (arg_target != "default") {
137+
supported_targets.push_back(arg_target);
138+
}
139+
}
140+
97141
// Check which targets this op can be offloaded.
98-
if (cn->op->IsInstance<OpNode>()) {
142+
if (op_node) {
99143
// TVM operators: Check target specific op checking function and add to supported_targets
100144
// if it is supported.
101145
Op op = Downcast<Op>(cn->op);
@@ -179,7 +223,7 @@ class AnnotateTargetWrapper : public ExprMutator {
179223
func = Downcast<Function>(new_e);
180224
new_body = func->body;
181225
if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) {
182-
new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], end_op);
226+
new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], make_end_op);
183227
op_expr_to_target_[new_body] = op_expr_to_target_[func->body];
184228
}
185229
}

tests/python/relay/test_pass_annotate_target.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,10 +338,37 @@ def after():
338338
assert tvm.ir.structural_equal(expected, result)
339339

340340

341+
def test_multiple_runs():
342+
@reg.register("nn.relu", "target.A")
343+
def relu(attrs, args): # pylint: disable=unused-variable
344+
return True
345+
346+
@reg.register("add", "target.B")
347+
def add(attrs, args): # pylint: disable=unused-variable
348+
return True
349+
350+
def before():
351+
x = relay.var("x", shape=(10, 5))
352+
a_1 = relay.nn.relu(x)
353+
a_2 = relay.abs(a_1)
354+
a_3 = relay.nn.relu(a_1)
355+
out = relay.add(a_2, a_3)
356+
357+
f = relay.Function([x], out)
358+
mod = tvm.IRModule.from_expr(f)
359+
return mod
360+
361+
mod = transform.AnnotateTarget("A")(before())
362+
mod = transform.AnnotateTarget("B")(mod)
363+
expected = transform.AnnotateTarget(["A", "B"])(before())
364+
assert tvm.ir.structural_equal(expected, mod)
365+
366+
341367
if __name__ == "__main__":
342368
test_extern_dnnl()
343369
test_composite_function()
344370
#test_extern_dnnl_mobilenet()
345371
test_multiple_ends()
346372
test_type_propagation()
347373
test_tuple()
374+
test_multiple_runs()

0 commit comments

Comments
 (0)