|
| 1 | + |
| 2 | +/* |
| 3 | + * Licensed to the Apache Software Foundation (ASF) under one |
| 4 | + * or more contributor license agreements. See the NOTICE file |
| 5 | + * distributed with this work for additional information |
| 6 | + * regarding copyright ownership. The ASF licenses this file |
| 7 | + * to you under the Apache License, Version 2.0 (the |
| 8 | + * "License"); you may not use this file except in compliance |
| 9 | + * with the License. You may obtain a copy of the License at |
| 10 | + * |
| 11 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 12 | + * |
| 13 | + * Unless required by applicable law or agreed to in writing, |
| 14 | + * software distributed under the License is distributed on an |
| 15 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 16 | + * KIND, either express or implied. See the License for the |
| 17 | + * specific language governing permissions and limitations |
| 18 | + * under the License. |
| 19 | + */ |
| 20 | +#include <tvm/relay/expr_functor.h> |
| 21 | +#include <tvm/relay/transform.h> |
| 22 | +#include <tvm/tir/buffer.h> |
| 23 | +#include <tvm/tir/builtin.h> |
| 24 | +#include <tvm/tir/expr.h> |
| 25 | +#include <tvm/tir/function.h> |
| 26 | +#include <tvm/tir/op.h> |
| 27 | + |
| 28 | +namespace tvm { |
| 29 | +namespace relay { |
| 30 | +namespace contrib { |
| 31 | +namespace example_target_hooks { |
| 32 | + |
| 33 | +class ConvertAddToSubtract : public MixedModeMutator { |
| 34 | + public: |
| 35 | + explicit ConvertAddToSubtract(IRModule ir_module, Target host_target) |
| 36 | + : ir_module_(ir_module), host_target_(host_target) {} |
| 37 | + |
| 38 | + IRModule Mutate() { |
| 39 | + GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); |
| 40 | + BaseFunc main = ir_module_->Lookup(main_global_var); |
| 41 | + Function main_func = GetRef<Function>(main.as<FunctionNode>()); |
| 42 | + |
| 43 | + // Copy everything across and mutate the body |
| 44 | + Function mutated_main = |
| 45 | + Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type, |
| 46 | + main_func->type_params, main_func->attrs, main_func->span); |
| 47 | + |
| 48 | + ir_module_->Update(main_global_var, mutated_main); |
| 49 | + |
| 50 | + return ir_module_; |
| 51 | + } |
| 52 | + |
| 53 | + private: |
| 54 | + tir::Load LoadIndex(const tir::Buffer& buffer, const PrimExpr& index) { |
| 55 | + return tir::Load(DataType::Float(32), buffer->data, index, tir::const_true()); |
| 56 | + } |
| 57 | + |
| 58 | + void ReplaceAddWithSubtractPrimFunc(const GlobalVar& new_global_var, const Function& func) { |
| 59 | + tir::Buffer x_buffer = tir::decl_buffer({8}, DataType::Float(32), "x"); |
| 60 | + tir::Buffer y_buffer = tir::decl_buffer({8}, DataType::Float(32), "y"); |
| 61 | + tir::Buffer out_buffer = tir::decl_buffer({8}, DataType::Float(32)); |
| 62 | + |
| 63 | + tir::Var x_var("x", DataType::Handle()); |
| 64 | + tir::Var y_var("y", DataType::Handle()); |
| 65 | + tir::Var out_var("out", DataType::Handle()); |
| 66 | + |
| 67 | + Map<String, ObjectRef> dict_attrs; |
| 68 | + dict_attrs.Set("global_symbol", new_global_var->name_hint); |
| 69 | + dict_attrs.Set("tir.noalias", Bool(true)); |
| 70 | + |
| 71 | + te::Var index("index", DataType::Int(32)); |
| 72 | + tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index)); |
| 73 | + tir::Stmt math_body = tir::Store(out_buffer->data, indexed_sub, index, tir::const_true()); |
| 74 | + tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body); |
| 75 | + |
| 76 | + Map<tir::Var, tir::Buffer> buffer_map = { |
| 77 | + {x_var, x_buffer}, |
| 78 | + {y_var, y_buffer}, |
| 79 | + {out_var, out_buffer}, |
| 80 | + }; |
| 81 | + |
| 82 | + tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(), |
| 83 | + buffer_map, DictAttrs(dict_attrs)); |
| 84 | + replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_); |
| 85 | + ir_module_->Add(new_global_var, replacement_func); |
| 86 | + } |
| 87 | + |
| 88 | + Expr Rewrite_(const CallNode* pre, const Expr& post) override { |
| 89 | + if (const CallNode* call = post.as<CallNode>()) { |
| 90 | + auto* func = call->op.as<FunctionNode>(); |
| 91 | + if (func == nullptr) { |
| 92 | + return post; |
| 93 | + } |
| 94 | + |
| 95 | + auto func_name = func->GetAttr<String>(::tvm::attr::kGlobalSymbol); |
| 96 | + if (func_name.defined() && func_name == "replace_add_with_subtract") { |
| 97 | + // Introduce a new global var to map the function to and copy the source type |
| 98 | + // over for InferType |
| 99 | + GlobalVar new_global_var(func_name.value()); |
| 100 | + new_global_var->checked_type_ = func->checked_type(); |
| 101 | + ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef<Function>(func)); |
| 102 | + return Call(new_global_var, call->args, call->attrs, call->type_args, call->span); |
| 103 | + } |
| 104 | + } |
| 105 | + |
| 106 | + return post; |
| 107 | + } |
| 108 | + |
| 109 | + public: |
| 110 | + IRModule ir_module_; |
| 111 | + Target host_target_; |
| 112 | +}; |
| 113 | + |
| 114 | +transform::Pass RelayToTIR() { |
| 115 | + runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func = |
| 116 | + [=](IRModule ir_module, transform::PassContext pass_context) { |
| 117 | + auto relay_to_tir = ConvertAddToSubtract(ir_module, Target("c")); |
| 118 | + return relay_to_tir.Mutate(); |
| 119 | + }; |
| 120 | + return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIR", {}); |
| 121 | +} |
| 122 | + |
| 123 | +} // namespace example_target_hooks |
| 124 | +} // namespace contrib |
| 125 | +} // namespace relay |
| 126 | + |
| 127 | +TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) |
| 128 | + .set_attr<tvm::transform::Pass>("FTVMRelayToTIR", |
| 129 | + relay::contrib::example_target_hooks::RelayToTIR()); |
| 130 | + |
| 131 | +} // namespace tvm |
0 commit comments