Skip to content

Commit f506c8b

Browse files
comaniaczhiics
andauthored
[BYOC] Refine AnnotateTarget and MergeCompilerRegion Passes (#5277)
* add target to region * refactor annotate_target * Make all unit test working * quick fix * enable BN, unit test failed * Fix vm test, unit test. Refactor annotate_target a bit. * quick fix fusion * revert fusion change * style fix * Refactor merge region pass * format * minor fix * Skip e2e test * lint * support AnnotateTarget multiple runs * Add HasAttr and revert DNNL codegen * address comment Co-authored-by: Zhi Chen <chzhi@amazon.com>
1 parent 5795539 commit f506c8b

15 files changed

Lines changed: 609 additions & 529 deletions

File tree

python/tvm/relay/op/contrib/dnnl.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,17 @@ def _func_wrapper(attrs, args):
5656
return _func_wrapper
5757

5858

59-
_register_external_op_helper("nn.batch_norm")
6059
_register_external_op_helper("nn.conv2d")
6160
_register_external_op_helper("nn.dense")
6261
_register_external_op_helper("nn.relu")
6362
_register_external_op_helper("add")
6463
_register_external_op_helper("subtract")
6564
_register_external_op_helper("multiply")
65+
66+
67+
@reg.register("nn.batch_norm", "target.dnnl")
68+
def batch_norm(attrs, args):
69+
"""Check if the external DNNL codegen should be used.
70+
FIXME(@zhiics, @comaniac): Turned off because multiple outputs are not supported yet.
71+
"""
72+
return False

python/tvm/relay/transform/transform.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -587,22 +587,24 @@ def PartitionGraph():
587587

588588

589589

590-
def AnnotateTarget(target):
590+
def AnnotateTarget(targets):
591591
"""Annotate ops in an expression with a provided compiler/target and then
592592
use it for codegen.
593593
594594
Parameters
595595
----------
596-
target : String
597-
The target compiler used for codegen.
596+
targets : str or List[str]
597+
The list of target compilers used for codegen.
598598
599599
Returns
600600
-------
601601
ret : tvm.relay.Pass
602602
The annotated pass that wraps ops with subgraph_start and
603603
subgraph_end.
604604
"""
605-
return _ffi_api.AnnotateTarget(target)
605+
if isinstance(targets, str):
606+
targets = [targets]
607+
return _ffi_api.AnnotateTarget([tvm.runtime.container.String(t) for t in targets])
606608

607609

608610
def Inline():

src/relay/analysis/annotated_region_set.cc

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include <tvm/relay/expr.h>
2323
#include <tvm/ir/error.h>
24+
#include <tvm/runtime/container.h>
2425

2526
#include <unordered_map>
2627
#include <vector>
@@ -31,7 +32,7 @@ namespace relay {
3132

3233
AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const {
3334
for (auto candidate : regions_) {
34-
if (candidate->nodes.find(expr) != candidate->nodes.end()) {
35+
if (candidate->nodes_.find(expr) != candidate->nodes_.end()) {
3536
return candidate;
3637
}
3738
}
@@ -45,26 +46,26 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
4546
}
4647

4748
// Merge src to dest and erase src.
48-
dest->nodes.insert(src->nodes.begin(), src->nodes.end());
49-
for (const auto& input : src->ins) {
50-
dest->ins.push_back(input);
49+
dest->nodes_.insert(src->nodes_.begin(), src->nodes_.end());
50+
for (const auto& input : src->ins_) {
51+
dest->ins_.push_back(input);
5152
}
52-
for (const auto& output : src->outs) {
53-
dest->outs.push_back(output);
53+
for (const auto& output : src->outs_) {
54+
dest->outs_.push_back(output);
5455
}
5556
// if any of the outputs of src are inputs of dest, they become internal nodes
5657
// so remove them from outs
5758
std::vector<Expr> ins_to_remove;
58-
for (const auto& input : dest->ins) {
59+
for (const auto& input : dest->ins_) {
5960
auto call = Downcast<Call>(input);
60-
auto it = src->nodes.find(call->args[0]);
61-
if (it != src->nodes.end()) {
62-
dest->outs.remove(*it);
61+
auto it = src->nodes_.find(call->args[0]);
62+
if (it != src->nodes_.end()) {
63+
dest->outs_.remove(*it);
6364
ins_to_remove.push_back(input);
6465
}
6566
}
6667
for (const auto& input : ins_to_remove) {
67-
dest->ins.remove(input);
68+
dest->ins_.remove(input);
6869
}
6970
regions_.erase(src);
7071
}
@@ -74,25 +75,21 @@ void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr)
7475
if (src.defined()) {
7576
MergeRegions(src, dest);
7677
} else {
77-
dest->nodes.insert(expr);
78+
dest->nodes_.insert(expr);
7879
}
7980
}
8081

81-
AnnotatedRegion AnnotatedRegionSetNode::MakeRegion() {
82+
AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) {
8283
auto ret = regions_.emplace(AnnotatedRegion());
83-
(*ret.first)->id = region_id_++;
84+
(*ret.first)->id_ = region_id_++;
85+
(*ret.first)->target_ = target;
8486
return *ret.first;
8587
}
8688

8789
class AnnotatedRegionSet::Creator : public ExprVisitor {
8890
public:
89-
Creator(const Op& region_begin_op, const Op& region_end_op) :
90-
begin_op_(region_begin_op), end_op_(region_end_op) {}
91-
92-
AnnotatedRegionSet Create(const Expr& expr) {
93-
VisitExpr(expr);
94-
return std::move(region_set_);
95-
}
91+
Creator(const Op& region_begin_op, const Op& region_end_op)
92+
: begin_op_(region_begin_op), end_op_(region_end_op) {}
9693

9794
void VisitExpr_(const CallNode* call) {
9895
auto op_node = call->op.as<OpNode>();
@@ -115,24 +112,35 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
115112
<< "Cannot find the corresponding region for start annotation:\n"
116113
<< AsText(GetRef<Call>(call), false));
117114
}
118-
region->ins.push_back(GetRef<Call>(call));
115+
region->ins_.push_back(GetRef<Call>(call));
119116
} else {
120117
CHECK_EQ(call->op, end_op_);
121118
// The annotation node is inserted on edge so it must have only one argument.
122119
CHECK_EQ(call->args.size(), 1U);
120+
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
123121

124122
// Check if the argument already belongs to a region
125123
auto region = region_set_->GetRegion(call->args[0]);
126124
if (!region.defined()) {
127-
region = region_set_->MakeRegion();
128-
region->nodes.insert(call->args[0]);
125+
// Create a new region if the argument does not belong to any region yet.
126+
region = region_set_->MakeRegion(target);
127+
region->nodes_.insert(call->args[0]);
128+
} else {
129+
// If the argument already belongs to a region, it must have the same target.
130+
// Otherwise we should see a region_begin op.
131+
CHECK_EQ(region->GetTarget(), target);
129132
}
130-
region->nodes.insert(GetRef<Call>(call));
131-
region->outs.push_back(GetRef<Call>(call));
133+
region->nodes_.insert(GetRef<Call>(call));
134+
region->outs_.push_back(GetRef<Call>(call));
132135
}
133136
ExprVisitor::VisitExpr_(call);
134137
}
135138

139+
AnnotatedRegionSet Create(const Expr& expr) {
140+
VisitExpr(expr);
141+
return std::move(region_set_);
142+
}
143+
136144
void VisitExpr_(const TupleNode* op) {
137145
auto region = region_set_->GetRegion(GetRef<Tuple>(op));
138146
if (region.defined()) {

src/relay/analysis/annotated_region_set.h

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include <tvm/relay/expr.h>
3333
#include <tvm/ir/error.h>
3434
#include <tvm/relay/expr_functor.h>
35+
#include <tvm/runtime/container.h>
3536
#include <tvm/relay/transform.h>
3637

3738
#include <string>
@@ -49,47 +50,55 @@ class AnnotatedRegionSet;
4950
class AnnotatedRegionNode : public Object {
5051
public:
5152
void VisitAttrs(AttrVisitor* v) {
52-
v->Visit("id", &id);
53-
Array<Expr> nodes_array(nodes.begin(), nodes.end());
53+
v->Visit("id", &id_);
54+
v->Visit("target", &target_);
55+
Array<Expr> nodes_array(nodes_.begin(), nodes_.end());
5456
v->Visit("nodes", &nodes_array);
55-
Array<Expr> args_array(ins.begin(), ins.end());
57+
Array<Expr> args_array(ins_.begin(), ins_.end());
5658
v->Visit("args", &args_array);
57-
Array<Expr> rets_array(outs.begin(), outs.end());
59+
Array<Expr> rets_array(outs_.begin(), outs_.end());
5860
v->Visit("rets", &rets_array);
5961
}
6062

6163
/*! \brief Get the region ID. */
6264
int GetID() const {
63-
return id;
65+
return id_;
66+
}
67+
68+
/*! \brief Get the region target. */
69+
std::string GetTarget() const {
70+
return target_;
6471
}
6572

6673
/*! \brief Get the region's inputs. */
6774
std::list<Expr> GetInputs() const {
68-
return ins;
75+
return ins_;
6976
}
7077

7178
/*! \brief Get the region's outputs. */
7279
std::list<Expr> GetOutputs() const {
73-
return outs;
80+
return outs_;
7481
}
7582

7683
/*! \brief Get the region's nodes. */
7784
std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const {
78-
return nodes;
85+
return nodes_;
7986
}
8087

8188
static constexpr const char* _type_key = "relay.AnnotatedRegion";
8289
TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionNode, Object);
8390

8491
protected:
8592
/*! \brief The region ID. */
86-
int id{-1};
93+
int id_{-1};
94+
/*! \brief The target for this region. */
95+
std::string target_ = "default";
8796
/*! \brief The inputs to this region. */
88-
std::list<Expr> ins;
97+
std::list<Expr> ins_;
8998
/*! \brief The outputs of this region */
90-
std::list<Expr> outs;
99+
std::list<Expr> outs_;
91100
/*! \brief Nodes in this region. */
92-
std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes;
101+
std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes_;
93102

94103
friend class AnnotatedRegionSet;
95104
friend class AnnotatedRegionSetNode;
@@ -184,11 +193,11 @@ class AnnotatedRegionSetNode : public Object {
184193
void AddToRegion(AnnotatedRegion dest, const Expr& expr);
185194

186195
/*!
187-
* \brief Make a new region.
196+
* \brief Make a new region for a target.
188197
*
189198
* \return The new region.
190199
*/
191-
AnnotatedRegion MakeRegion();
200+
AnnotatedRegion MakeRegion(const std::string& target);
192201

193202
std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> regions_;
194203
/*! \brief The next region ID to assign. */

src/relay/backend/contrib/dnnl/codegen.cc

Lines changed: 24 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,12 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
5353
}
5454

5555
void VisitExpr_(const TupleGetItemNode* op) final {
56-
VisitExpr(op->tuple);
57-
CHECK(out_.size() > static_cast<size_t>(op->index));
58-
59-
// Only keep the item we want for the child node.
60-
// FIXME(@comaniac): The other items should still be requried for the primary outputs.
61-
auto item = out_[op->index];
62-
out_.clear();
63-
out_.push_back(item);
56+
// Do nothing
6457
}
6558

6659
void VisitExpr_(const CallNode* call) final {
6760
std::ostringstream decl_stream;
68-
61+
std::ostringstream buf_stream;
6962
// Args: ID
7063
std::vector<std::string> args;
7164

@@ -103,52 +96,36 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
10396
}
10497
}
10598

106-
// Analyze the output buffers
107-
std::vector<Type> out_types;
108-
if (call->checked_type()->IsInstance<TupleTypeNode>()) {
109-
auto type_node = call->checked_type().as<TupleTypeNode>();
110-
for (auto field : type_node->fields) {
111-
CHECK(field->IsInstance<TensorTypeNode>());
112-
out_types.push_back(field);
113-
}
114-
} else if (call->checked_type()->IsInstance<TensorTypeNode>()) {
115-
CHECK(call->checked_type()->IsInstance<TensorTypeNode>());
116-
out_types.push_back(call->checked_type());
117-
} else {
118-
LOG(FATAL) << "Unrecognized type node: " << AsText(call->checked_type(), false);
119-
}
120-
121-
out_.clear();
122-
for (auto out_type : out_types) {
123-
const auto& dtype = GetDtypeString(out_type.as<TensorTypeNode>());
124-
125-
std::string out = "buf_" + std::to_string(buf_idx_++);
126-
auto out_shape = GetShape(out_type);
127-
int out_size = 1;
128-
for (size_t i = 0; i < out_shape.size(); ++i) {
129-
out_size *= out_shape[i];
130-
}
131-
this->PrintIndents();
132-
std::ostringstream buf_stream;
133-
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
134-
buf_decl_.push_back(buf_stream.str());
135-
decl_stream << ", " << out;
136-
137-
// Update output buffer
138-
Output output;
139-
output.name = out;
140-
output.dtype = dtype;
141-
output.need_copy = true;
142-
output.size = out_size;
143-
out_.push_back(output);
99+
// Analyze the output buffer
100+
auto type_node = call->checked_type().as<TensorTypeNode>();
101+
CHECK(type_node);
102+
const auto& dtype = GetDtypeString(type_node);
103+
std::string out = "buf_" + std::to_string(buf_idx_++);
104+
auto out_shape = GetShape(call->checked_type());
105+
int out_size = 1;
106+
for (size_t i = 0; i < out_shape.size(); ++i) {
107+
out_size *= out_shape[i];
144108
}
109+
this->PrintIndents();
110+
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
111+
buf_decl_.push_back(buf_stream.str());
112+
decl_stream << ", " << out;
145113

146114
// Attach attribute arguments
147115
for (size_t i = 0; i < args.size(); ++i) {
148116
decl_stream << ", " << args[i];
149117
}
150118
decl_stream << ");";
151119
ext_func_body.push_back(decl_stream.str());
120+
121+
// Update output buffer
122+
out_.clear();
123+
Output output;
124+
output.name = out;
125+
output.dtype = dtype;
126+
output.need_copy = true;
127+
output.size = out_size;
128+
out_.push_back(output);
152129
}
153130

154131
std::string JIT(void) {

src/relay/backend/vm/compiler.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -924,20 +924,19 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
924924
pass_seqs.push_back(transform::LambdaLift());
925925
pass_seqs.push_back(transform::InlinePrimitives());
926926

927-
// Manifest the allocations.
928-
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
929-
// Compute away possibly introduced constant computation.
930-
pass_seqs.push_back(transform::FoldConstant());
931-
// Fuse the shape functions.
932-
pass_seqs.push_back(transform::FuseOps());
933-
934927
// Inline the functions that are lifted to the module scope. We perform this
935928
// pass after all other optimization passes but before the memory allocation
936929
// pass. This is because memory allocation pass will insert `invoke_tvm_op`
937930
// and we use these ops to invoke the symbols in the module generated by
938931
// external codegen.
939932
pass_seqs.push_back(transform::Inline());
940933

934+
// Manifest the allocations.
935+
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
936+
// Compute away possibly introduced constant computation.
937+
pass_seqs.push_back(transform::FoldConstant());
938+
// Fuse the shape functions.
939+
pass_seqs.push_back(transform::FuseOps());
941940
// Manifest the allocations needed for the shape functions.
942941
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
943942

0 commit comments

Comments
 (0)