From 3b5d135def5d5f46ca2186487a014dc0bda2877c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 12 Jul 2023 09:36:48 -0500 Subject: [PATCH] [AOT] Avoid call_extern() with incorrect argument count Prior to this commit, if device initialization was required, the AOT main function produced a `call_extern()` that included the device context as input. This commit updates the AOT main function to provide the device context only if the function being called accepts a device context as input. If an extra device context argument were included at the call site, the C codegen would produce a function signature that includes the device context for the caller's compilation unit, but a signature without the device context for the callee's compilation unit. While this can compile and run in some cases, it is undefined behavior for the signature to vary between compilation units, and this should be avoided. This was initially discovered while debugging https://github.com/apache/tvm/pull/14985, in which changes to the lowering flow resulted in the caller and callee being within the same compilation unit. --- src/relay/backend/aot_executor_codegen.cc | 38 ++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 945290f70265..f698c654d6d8 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -454,7 +454,32 @@ class AOTExecutorCodegen : public MixedModeVisitor { // call_extern calling convention with optional context if (has_c_device_api_context) { device_context = device_contexts_.Get(global_var).value(); - args.push_back(device_context); + + // call_extern has no further legalization steps, and + // requires the number of arguments to match exactly. For + // internal calls, conditionally append the device context. 
+ bool requires_device_context = [&]() -> bool { + Optional opt = num_arguments_.Get(global_var); + if (!opt.defined()) { + // For external calls, we must trust that the user has + // supplied a kernel that accepts a device_context + // argument. + return true; + } + int num_callee_params = opt.value()->value; + int num_args = call_lowered_props.arguments.size(); + if (num_callee_params == num_args) { + return false; + } else if (num_callee_params == num_args + 1) { + return true; + } else { + LOG(FATAL) << "Callee " << global_var << " requires " << num_callee_params + << ", but is called with " << num_args << " arguments."; + } + }(); + if (requires_device_context) { + args.push_back(device_context); + } } func_call = tir::Evaluate(AddCheckReturn( tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args))); @@ -1007,6 +1032,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { Map devices_; /*! \brief map of GlobalVars to C Device API contexts */ Map device_contexts_; + /*! \brief map of GlobalVars to the number of arguments they require */ + Map num_arguments_; /*! \brief input and output variables belonging to the main function signature */ Array main_signature_; /*! \brief input and output variables belonging to the main function signature */ @@ -1183,6 +1210,15 @@ class AOTExecutorCodegen : public MixedModeVisitor { } CollectDeviceVariables(lowered_mod->GetAttr>("device_contexts").value()); + num_arguments_ = [&]() -> Map { + Map arg_count; + for (const auto& [gvar, func] : lowered_mod->functions) { + if (const auto* prim_func = func.as()) { + arg_count.Set(gvar, prim_func->params.size()); + } + } + return arg_count; + }(); VisitExpr(lowered_main_func->body); // Create the runner function. Please note that the function is not legal yet