Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,32 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// call_extern calling convention with optional context
if (has_c_device_api_context) {
device_context = device_contexts_.Get(global_var).value();
args.push_back(device_context);

// call_extern has no further legalization steps, and
// requires the number of arguments to match exactly. For
// internal calls, conditionally append the device context.
bool requires_device_context = [&]() -> bool {
Optional<Integer> opt = num_arguments_.Get(global_var);
if (!opt.defined()) {
// For external calls, we must trust that the user has
// supplied a kernel that accepts a device_context
// argument.
return true;
}
int num_callee_params = opt.value()->value;
int num_args = call_lowered_props.arguments.size();
if (num_callee_params == num_args) {
return false;
} else if (num_callee_params == num_args + 1) {
return true;
} else {
LOG(FATAL) << "Callee " << global_var << " requires " << num_callee_params
<< ", but is called with " << num_args << " arguments.";
}
}();
if (requires_device_context) {
args.push_back(device_context);
}
}
func_call = tir::Evaluate(AddCheckReturn(
tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args)));
Expand Down Expand Up @@ -1007,6 +1032,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
Map<String, tir::Var> devices_;
/*! \brief map of GlobalVars to C Device API contexts */
Map<GlobalVar, tir::Var> device_contexts_;
/*! \brief map of GlobalVars to the number of arguments they require */
Map<GlobalVar, Integer> num_arguments_;
/*! \brief input and output variables belonging to the main function signature */
Array<tir::Var> main_signature_;
/*! \brief input and output variables belonging to the main function signature */
Expand Down Expand Up @@ -1183,6 +1210,15 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}

CollectDeviceVariables(lowered_mod->GetAttr<Map<GlobalVar, String>>("device_contexts").value());
num_arguments_ = [&]() -> Map<GlobalVar, Integer> {
Map<GlobalVar, Integer> arg_count;
for (const auto& [gvar, func] : lowered_mod->functions) {
if (const auto* prim_func = func.as<tir::PrimFuncNode>()) {
arg_count.Set(gvar, prim_func->params.size());
}
}
return arg_count;
}();
VisitExpr(lowered_main_func->body);

// Create the runner function. Please note that the function is not legal yet
Expand Down