Skip to content

Commit 869798c

Browse files
committed
Implementation of relay_to_tir target hook
This the first new hook proposed in the Additional Target Hooks RFC, longer term the compilation should move to using `Target` proper but this unblocks our current work whilst illustrating the eventual interface via `Target` in `src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc` Ideally the host target would be annotated onto the `IRModule` so as this `Pass` could use it instead of defaulting to C but this is fine for now.
1 parent 60a4db6 commit 869798c

9 files changed

Lines changed: 328 additions & 13 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ include(cmake/modules/contrib/EthosU.cmake)
407407
include(cmake/modules/contrib/BLAS.cmake)
408408
include(cmake/modules/contrib/CODEGENC.cmake)
409409
include(cmake/modules/contrib/DNNL.cmake)
410+
include(cmake/modules/contrib/ExampleTargetHooks.cmake)
410411
include(cmake/modules/contrib/Random.cmake)
411412
include(cmake/modules/contrib/Posit.cmake)
412413
include(cmake/modules/contrib/MicroStandaloneRuntime.cmake)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
file(GLOB EXAMPLE_TARGET_HOOKS_SRC src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc)
19+
list(APPEND COMPILER_SRCS ${EXAMPLE_TARGET_HOOKS_SRC})

include/tvm/relay/transform.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,13 @@ TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
426426
*/
427427
TVM_DLL Pass SimplifyExpr();
428428

429+
/*!
430+
* \brief Run any registered RelayToTIR passes registered on the functions in a module.
431+
*
432+
* \return The pass.
433+
*/
434+
TVM_DLL Pass RelayToTIRTargetHook();
435+
429436
/*!
430437
* \brief A pass for manifesting explicit memory allocations and rewriting
431438
* specific dialects.
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
2+
/*
3+
* Licensed to the Apache Software Foundation (ASF) under one
4+
* or more contributor license agreements. See the NOTICE file
5+
* distributed with this work for additional information
6+
* regarding copyright ownership. The ASF licenses this file
7+
* to you under the Apache License, Version 2.0 (the
8+
* "License"); you may not use this file except in compliance
9+
* with the License. You may obtain a copy of the License at
10+
*
11+
* http://www.apache.org/licenses/LICENSE-2.0
12+
*
13+
* Unless required by applicable law or agreed to in writing,
14+
* software distributed under the License is distributed on an
15+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
* KIND, either express or implied. See the License for the
17+
* specific language governing permissions and limitations
18+
* under the License.
19+
*/
20+
#include <tvm/relay/expr_functor.h>
21+
#include <tvm/relay/transform.h>
22+
#include <tvm/tir/buffer.h>
23+
#include <tvm/tir/builtin.h>
24+
#include <tvm/tir/expr.h>
25+
#include <tvm/tir/function.h>
26+
#include <tvm/tir/op.h>
27+
28+
namespace tvm {
29+
namespace relay {
30+
namespace contrib {
31+
namespace example_target_hooks {
32+
33+
class ConvertAddToSubtract : public MixedModeMutator {
34+
public:
35+
explicit ConvertAddToSubtract(IRModule ir_module, Target host_target)
36+
: ir_module_(ir_module), host_target_(host_target) {}
37+
38+
IRModule Mutate() {
39+
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
40+
BaseFunc main = ir_module_->Lookup(main_global_var);
41+
Function main_func = GetRef<Function>(main.as<FunctionNode>());
42+
43+
// Copy everything across and mutate the body
44+
Function mutated_main =
45+
Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type,
46+
main_func->type_params, main_func->attrs, main_func->span);
47+
48+
ir_module_->Update(main_global_var, mutated_main);
49+
50+
return ir_module_;
51+
}
52+
53+
private:
54+
tir::Load LoadIndex(const tir::Buffer& buffer, const PrimExpr& index) {
55+
return tir::Load(DataType::Float(32), buffer->data, index, tir::const_true());
56+
}
57+
58+
void ReplaceAddWithSubtractPrimFunc(const GlobalVar& new_global_var, const Function& func) {
59+
tir::Buffer x_buffer = tir::decl_buffer({8}, DataType::Float(32), "x");
60+
tir::Buffer y_buffer = tir::decl_buffer({8}, DataType::Float(32), "y");
61+
tir::Buffer out_buffer = tir::decl_buffer({8}, DataType::Float(32));
62+
63+
tir::Var x_var("x", DataType::Handle());
64+
tir::Var y_var("y", DataType::Handle());
65+
tir::Var out_var("out", DataType::Handle());
66+
67+
Map<String, ObjectRef> dict_attrs;
68+
dict_attrs.Set("global_symbol", new_global_var->name_hint);
69+
dict_attrs.Set("tir.noalias", Bool(true));
70+
71+
te::Var index("index", DataType::Int(32));
72+
tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index));
73+
tir::Stmt math_body = tir::Store(out_buffer->data, indexed_sub, index, tir::const_true());
74+
tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body);
75+
76+
Map<tir::Var, tir::Buffer> buffer_map = {
77+
{x_var, x_buffer},
78+
{y_var, y_buffer},
79+
{out_var, out_buffer},
80+
};
81+
82+
tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(),
83+
buffer_map, DictAttrs(dict_attrs));
84+
replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_);
85+
ir_module_->Add(new_global_var, replacement_func);
86+
}
87+
88+
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
89+
if (const CallNode* call = post.as<CallNode>()) {
90+
auto* func = call->op.as<FunctionNode>();
91+
if (func == nullptr) {
92+
return post;
93+
}
94+
95+
auto func_name = func->GetAttr<String>(::tvm::attr::kGlobalSymbol);
96+
if (func_name.defined() && func_name == "replace_add_with_subtract") {
97+
// Introduce a new global var to map the function to and copy the source type
98+
// over for InferType
99+
GlobalVar new_global_var(func_name.value());
100+
new_global_var->checked_type_ = func->checked_type();
101+
ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef<Function>(func));
102+
return Call(new_global_var, call->args, call->attrs, call->type_args, call->span);
103+
}
104+
}
105+
106+
return post;
107+
}
108+
109+
public:
110+
IRModule ir_module_;
111+
Target host_target_;
112+
};
113+
114+
transform::Pass RelayToTIR() {
115+
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
116+
[=](IRModule ir_module, transform::PassContext pass_context) {
117+
auto relay_to_tir = ConvertAddToSubtract(ir_module, Target("c"));
118+
return relay_to_tir.Mutate();
119+
};
120+
return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIR", {});
121+
}
122+
123+
} // namespace example_target_hooks
124+
} // namespace contrib
125+
} // namespace relay
126+
127+
TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU)
128+
.set_attr<tvm::transform::Pass>("FTVMRelayToTIR",
129+
relay::contrib::example_target_hooks::RelayToTIR());
130+
131+
} // namespace tvm

src/relay/backend/te_compiler.cc

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class TECompilerImpl : public TECompilerNode {
131131
Array<tvm::runtime::Module> ret;
132132
std::unordered_map<std::string, std::string> cached_symbol;
133133
std::vector<CCacheKey> cached_ext_funcs;
134+
134135
for (const auto& it : cache_) {
135136
auto src_func = it.first->source_func;
136137
ICHECK(src_func.defined());
@@ -383,10 +384,9 @@ class LowerTensorExprMutator : public ExprMutator {
383384
* \brief Returns the primitive function associated with \p expr, or
384385
* nullptr if none.
385386
*/
386-
Function ResolveToPrimitive(Expr expr) {
387+
BaseFunc ResolveToPrimitive(Expr expr) {
387388
if (const GlobalVarNode* gvn = expr.as<GlobalVarNode>()) {
388-
BaseFunc base_func = module_->Lookup(GetRef<GlobalVar>(gvn));
389-
return ResolveToPrimitive(base_func);
389+
return module_->Lookup(GetRef<GlobalVar>(gvn));
390390
} else if (const VarNode* vn = expr.as<VarNode>()) {
391391
auto itr = primitive_functions_.find(GetRef<Var>(vn));
392392
return itr == primitive_functions_.end() ? Function() : itr->second;
@@ -516,10 +516,17 @@ class LowerTensorExprMutator : public ExprMutator {
516516
Expr VisitExpr_(const LetNode* let) override {
517517
Var var = Downcast<Var>(Mutate(let->var));
518518
Expr value = Mutate(let->value);
519-
Function prim_func = ResolveToPrimitive(value);
519+
BaseFunc prim_func = ResolveToPrimitive(value);
520+
520521
if (prim_func.defined()) {
522+
// Already lowered by other means, no need to mutate the Let node
523+
if (prim_func->IsInstance<tir::PrimFuncNode>()) {
524+
return GetRef<Let>(let);
525+
}
526+
521527
// Remember let var is bound to (possibly indirectly) to a primitive.
522-
primitive_functions_.emplace(let->var, prim_func);
528+
Function func = GetRef<Function>(prim_func.as<FunctionNode>());
529+
primitive_functions_.emplace(let->var, func);
523530
}
524531
Expr body = Mutate(let->body);
525532
if (prim_func.defined()) {
@@ -537,7 +544,7 @@ class LowerTensorExprMutator : public ExprMutator {
537544
Call expr = GetRef<Call>(call);
538545

539546
// Look for (indirect) calls to primitives.
540-
Function prim_func = ResolveToPrimitive(call->op);
547+
BaseFunc prim_func = ResolveToPrimitive(call->op);
541548
if (!prim_func.defined()) {
542549
// Not a call to a primitive function.
543550
if (const FunctionNode* fn = call->op.as<FunctionNode>()) {
@@ -546,6 +553,12 @@ class LowerTensorExprMutator : public ExprMutator {
546553
return ExprMutator::VisitExpr_(call);
547554
}
548555

556+
// Already lowered by other means so we don't need to mutate
557+
// the call
558+
if (prim_func->IsInstance<tir::PrimFuncNode>()) {
559+
return expr;
560+
}
561+
549562
// Find the desired target device.
550563
Target target;
551564
if (prim_func->GetAttr<String>(attr::kCompiler).defined()) {
@@ -565,7 +578,8 @@ class LowerTensorExprMutator : public ExprMutator {
565578
}
566579

567580
// Lower the primitive function for that target.
568-
std::pair<GlobalVar, Attrs> pair = LowerFunction(prim_func, target);
581+
Function func = GetRef<Function>(prim_func.as<FunctionNode>());
582+
std::pair<GlobalVar, Attrs> pair = LowerFunction(func, target);
569583

570584
// Similarly transform arguments.
571585
Array<Expr> args;
@@ -648,8 +662,6 @@ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) {
648662

649663
backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap targets,
650664
Map<Expr, backend::StorageInfo> storage_info_map) {
651-
CHECK_EQ(mod->functions.size(), 1)
652-
<< "There should only be one function in the module passed to UpdateMainWorkspaceSize";
653665
Function func = Downcast<Function>(mod->Lookup("main"));
654666

655667
// This is a Map<device,Map<storage_id, size>>
@@ -926,8 +938,10 @@ Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
926938
PassContext ctx) {
927939
return LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
928940
};
929-
return tvm::transform::Sequential(
930-
{tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}), InferType()});
941+
942+
return tvm::transform::Sequential({tvm::relay::transform::RelayToTIRTargetHook(),
943+
tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}),
944+
InferType()});
931945
}
932946
} // namespace tec
933947
} // namespace relay
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file target_hooks.cc
22+
* \brief Relay passes for processing Target Hooks which have been registered on functions within
23+
* the IRModule
24+
*/
25+
26+
#include <tvm/relay/expr_functor.h>
27+
#include <tvm/relay/transform.h>
28+
29+
namespace tvm {
30+
namespace relay {
31+
namespace transform {
32+
33+
class TargetHookVisitor : public tvm::relay::MixedModeVisitor {
34+
/*! \brief Collected pass list for all nodes */
35+
std::vector<Pass> pass_list_;
36+
/*! \brief Attribute map for all registered targets */
37+
TargetKindAttrMap<Pass> target_attr_map_;
38+
39+
public:
40+
TargetHookVisitor() : target_attr_map_(tvm::TargetKind::GetAttrMap<Pass>("FTVMRelayToTIR")) {}
41+
42+
std::vector<Pass> Visit(const IRModule& ir_mod) {
43+
for (const auto& it : ir_mod->functions) {
44+
const BaseFunc& base_func = it.second;
45+
VisitExpr(base_func);
46+
}
47+
return pass_list_;
48+
}
49+
50+
void VisitExpr_(const CallNode* call) override {
51+
// Descend the call tree
52+
for (auto arg : call->args) {
53+
VisitExpr(arg);
54+
}
55+
56+
if (const FunctionNode* func = call->op.as<FunctionNode>()) {
57+
if (!func->GetAttr<String>(attr::kCompiler).defined()) {
58+
return;
59+
}
60+
String code_gen_name = func->GetAttr<String>(attr::kCompiler).value();
61+
Optional<TargetKind> target_kind = tvm::TargetKind::Get(code_gen_name);
62+
if (!target_kind || !target_attr_map_.count(target_kind.value())) {
63+
return;
64+
}
65+
Pass custom_target_pass = target_attr_map_[target_kind.value()];
66+
if (std::find(pass_list_.begin(), pass_list_.end(), custom_target_pass) == pass_list_.end()) {
67+
pass_list_.push_back(custom_target_pass);
68+
}
69+
}
70+
}
71+
};
72+
73+
Pass RelayToTIRTargetHook() {
74+
auto pass_func = [=](IRModule mod, const PassContext& pass_ctx) {
75+
auto target_hook_visitor = TargetHookVisitor();
76+
std::vector<Pass> pass_list = target_hook_visitor.Visit(mod);
77+
Sequential run_hooks(pass_list);
78+
79+
return run_hooks(mod);
80+
};
81+
return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIRTargetHook", {});
82+
}
83+
84+
} // namespace transform
85+
} // namespace relay
86+
} // namespace tvm

src/target/target_kind.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
* \brief Target kind registry
2323
*/
2424
#include <tvm/ir/expr.h>
25+
#include <tvm/ir/transform.h>
2526
#include <tvm/runtime/device_api.h>
2627
#include <tvm/runtime/registry.h>
2728
#include <tvm/target/target.h>

0 commit comments

Comments
 (0)