diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc index 12218955f110..ad4640bd7d86 100644 --- a/src/codegen/llvm/codegen_amdgpu.cc +++ b/src/codegen/llvm/codegen_amdgpu.cc @@ -182,9 +182,9 @@ runtime::Module BuildAMDGPU(Array funcs, std::string target) { TVMByteArray arr; arr.data = &obj[0]; arr.size = obj.length(); + std::string ll(data_ll.begin(), data_ll.end()); std::string hsaco = (*f)(arr); - std::string ll(data_ll.begin(), data_ll.end()); return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(funcs), ll); } diff --git a/src/codegen/llvm/intrin_rule_rocm.cc b/src/codegen/llvm/intrin_rule_rocm.cc new file mode 100644 index 000000000000..e2c7188c6ff1 --- /dev/null +++ b/src/codegen/llvm/intrin_rule_rocm.cc @@ -0,0 +1,78 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file intrin_rule_rocm.cc + */ +#ifdef TVM_LLVM_VERSION +#if TVM_ROCM_RUNTIME + +#include +#include +#include +#include +#include "./llvm_common.h" + +namespace tvm { +namespace codegen { +namespace llvm { + +using namespace ir; + +// num_signature means number of arguments used to query signature +template +inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { + Expr e = targs[0]; + const Call* call = e.as(); + CHECK(call != nullptr); + Array cargs; + // intrin id. + cargs.push_back(UIntImm::make(UInt(32), id)); + cargs.push_back(UIntImm::make(UInt(32), num_signature)); + + for (Expr arg : call->args) { + cargs.push_back(arg); + } + *rv = Call::make( + call->type, "llvm_intrin", cargs, Call::PureIntrinsic); +} + +template +inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { + Expr e = targs[0]; + const Call* call = e.as(); + CHECK(call != nullptr); + Array cargs; + // intrin id. + cargs.push_back(UIntImm::make(UInt(32), id)); + cargs.push_back(UIntImm::make(UInt(32), num_signature)); + for (Expr arg : call->args) { + cargs.push_back(arg); + } + *rv = Call::make( + call->type, "llvm_intrin", cargs, Call::Intrinsic); +} + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.rocm.prefetch") +.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 0>); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.rocm.exp") +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.rocm.fma") +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 1>); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.rocm.log") +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.rocm.sqrt") +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.rocm.pow") +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 1>); + +} // namespace llvm +} // namespace codegen +} // namespace tvm + +#endif // TVM_ROCM_RUNTIME + +#endif // LLVM_VERSION