2121
2222#include < tvm/relay/expr.h>
2323#include < tvm/ir/error.h>
24+ #include < tvm/runtime/container.h>
2425
2526#include < unordered_map>
2627#include < vector>
@@ -31,7 +32,7 @@ namespace relay {
3132
3233AnnotatedRegion AnnotatedRegionSetNode::GetRegion (const Expr& expr) const {
3334 for (auto candidate : regions_) {
34- if (candidate->nodes .find (expr) != candidate->nodes .end ()) {
35+ if (candidate->nodes_ .find (expr) != candidate->nodes_ .end ()) {
3536 return candidate;
3637 }
3738 }
@@ -45,26 +46,26 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
4546 }
4647
4748 // Merge src to dest and erase src.
48- dest->nodes .insert (src->nodes .begin (), src->nodes .end ());
49- for (const auto & input : src->ins ) {
50- dest->ins .push_back (input);
49+ dest->nodes_ .insert (src->nodes_ .begin (), src->nodes_ .end ());
50+ for (const auto & input : src->ins_ ) {
51+ dest->ins_ .push_back (input);
5152 }
52- for (const auto & output : src->outs ) {
53- dest->outs .push_back (output);
53+ for (const auto & output : src->outs_ ) {
54+ dest->outs_ .push_back (output);
5455 }
5556 // if any of the outputs of src are inputs of dest, they become internal nodes
5657 // so remove them from outs
5758 std::vector<Expr> ins_to_remove;
58- for (const auto & input : dest->ins ) {
59+ for (const auto & input : dest->ins_ ) {
5960 auto call = Downcast<Call>(input);
60- auto it = src->nodes .find (call->args [0 ]);
61- if (it != src->nodes .end ()) {
62- dest->outs .remove (*it);
61+ auto it = src->nodes_ .find (call->args [0 ]);
62+ if (it != src->nodes_ .end ()) {
63+ dest->outs_ .remove (*it);
6364 ins_to_remove.push_back (input);
6465 }
6566 }
6667 for (const auto & input : ins_to_remove) {
67- dest->ins .remove (input);
68+ dest->ins_ .remove (input);
6869 }
6970 regions_.erase (src);
7071}
@@ -74,25 +75,21 @@ void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr)
7475 if (src.defined ()) {
7576 MergeRegions (src, dest);
7677 } else {
77- dest->nodes .insert (expr);
78+ dest->nodes_ .insert (expr);
7879 }
7980}
8081
81- AnnotatedRegion AnnotatedRegionSetNode::MakeRegion () {
82+ AnnotatedRegion AnnotatedRegionSetNode::MakeRegion (const std::string& target ) {
8283 auto ret = regions_.emplace (AnnotatedRegion ());
83- (*ret.first )->id = region_id_++;
84+ (*ret.first )->id_ = region_id_++;
85+ (*ret.first )->target_ = target;
8486 return *ret.first ;
8587}
8688
8789class AnnotatedRegionSet ::Creator : public ExprVisitor {
8890 public:
89- Creator (const Op& region_begin_op, const Op& region_end_op) :
90- begin_op_ (region_begin_op), end_op_(region_end_op) {}
91-
92- AnnotatedRegionSet Create (const Expr& expr) {
93- VisitExpr (expr);
94- return std::move (region_set_);
95- }
91+ Creator (const Op& region_begin_op, const Op& region_end_op)
92+ : begin_op_(region_begin_op), end_op_(region_end_op) {}
9693
9794 void VisitExpr_ (const CallNode* call) {
9895 auto op_node = call->op .as <OpNode>();
@@ -115,24 +112,35 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
115112 << " Cannot find the corresponding region for start annotation:\n "
116113 << AsText (GetRef<Call>(call), false ));
117114 }
118- region->ins .push_back (GetRef<Call>(call));
115+ region->ins_ .push_back (GetRef<Call>(call));
119116 } else {
120117 CHECK_EQ (call->op , end_op_);
121118 // The annotation node is inserted on edge so it must have only one argument.
122119 CHECK_EQ (call->args .size (), 1U );
120+ std::string target = call->attrs .as <CompilerAttrs>()->compiler ;
123121
124122 // Check if the argument already belongs to a region
125123 auto region = region_set_->GetRegion (call->args [0 ]);
126124 if (!region.defined ()) {
127- region = region_set_->MakeRegion ();
128- region->nodes .insert (call->args [0 ]);
125+ // Create a new region if the argument is not belonged to any regions yet.
126+ region = region_set_->MakeRegion (target);
127+ region->nodes_ .insert (call->args [0 ]);
128+ } else {
129+ // If the argument is belonged to a region, it must have the same target.
130+ // Otherwise we should see a region_begin op.
131+ CHECK_EQ (region->GetTarget (), target);
129132 }
130- region->nodes .insert (GetRef<Call>(call));
131- region->outs .push_back (GetRef<Call>(call));
133+ region->nodes_ .insert (GetRef<Call>(call));
134+ region->outs_ .push_back (GetRef<Call>(call));
132135 }
133136 ExprVisitor::VisitExpr_ (call);
134137 }
135138
139+ AnnotatedRegionSet Create (const Expr& expr) {
140+ VisitExpr (expr);
141+ return std::move (region_set_);
142+ }
143+
136144 void VisitExpr_ (const TupleNode* op) {
137145 auto region = region_set_->GetRegion (GetRef<Tuple>(op));
138146 if (region.defined ()) {
0 commit comments