diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 945290f70265..f698c654d6d8 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -454,7 +454,32 @@ class AOTExecutorCodegen : public MixedModeVisitor { // call_extern calling convention with optional context if (has_c_device_api_context) { device_context = device_contexts_.Get(global_var).value(); - args.push_back(device_context); + + // call_extern has no further legalization steps, and + // requires the number of arguments to match exactly. For + // internal calls, conditionally append the device context. + bool requires_device_context = [&]() -> bool { + Optional opt = num_arguments_.Get(global_var); + if (!opt.defined()) { + // For external calls, we must trust that the user has + // supplied a kernel that accepts a device_context + // argument. + return true; + } + int num_callee_params = opt.value()->value; + int num_args = call_lowered_props.arguments.size(); + if (num_callee_params == num_args) { + return false; + } else if (num_callee_params == num_args + 1) { + return true; + } else { + LOG(FATAL) << "Callee " << global_var << " requires " << num_callee_params + << ", but is called with " << num_args << " arguments."; + } + }(); + if (requires_device_context) { + args.push_back(device_context); + } } func_call = tir::Evaluate(AddCheckReturn( tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args))); @@ -1007,6 +1032,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { Map devices_; /*! \brief map of GlobalVars to C Device API contexts */ Map device_contexts_; + /*! \brief map of GlobalVars to the number of arguments they require */ + Map num_arguments_; /*! \brief input and output variables belonging to the main function signature */ Array main_signature_; /*! \brief input and output variables belonging to the main function signature */ @@ -1183,6 +1210,15 @@ class AOTExecutorCodegen : public MixedModeVisitor { } CollectDeviceVariables(lowered_mod->GetAttr>("device_contexts").value()); + num_arguments_ = [&]() -> Map { + Map arg_count; + for (const auto& [gvar, func] : lowered_mod->functions) { + if (const auto* prim_func = func.as()) { + arg_count.Set(gvar, prim_func->params.size()); + } + } + return arg_count; + }(); VisitExpr(lowered_main_func->body); // Create the runner function. Please note that the function is not legal yet