From 3c511ed1b4281ab1285bf0b93ce70d1e570cc935 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 21 Aug 2025 23:51:15 -0400 Subject: [PATCH] [NVSHMEM] Fix compatibility with CUDA code without nvshmem use This PR fixes two bugs that prevent normal TIR functions (ones that don't use any NVSHMEM API) from compiling and running when `set(USE_NVSHMEM xxx)` is enabled. Co-authored-by: Bohan Hou --- python/tvm/contrib/nvcc.py | 9 ++++++--- src/runtime/contrib/nvshmem/init.cc | 6 ++++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index c79305a739cd..e9d8fac761c0 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -57,10 +57,13 @@ def compile_cuda(code, target_format=None, arch=None, options=None, path_target= """ # Check for NVSHMEM dependency nvshmem_include_path, nvshmem_lib_path = None, None - use_nvshmem = ( - tvm.get_global_func("runtime.nvshmem.cumodule_init", allow_missing=True) is not None - ) + use_nvshmem = "#include <nvshmem.h>" in code or "#include <nvshmemx.h>" in code if use_nvshmem: + # NOTE: we cannot check whether nvshmem is used based on whether + # the global function "runtime.nvshmem.cumodule_init" is defined. + # The reason is that if the input code does not use any NVSHMEM functions + # while the global function is defined, using cubin to compile the + # code may cause a compilation error. target_format = "cubin" nvshmem_include_path, nvshmem_lib_path = find_nvshmem_paths() diff --git a/src/runtime/contrib/nvshmem/init.cc b/src/runtime/contrib/nvshmem/init.cc index 1b0a65f4f1fc..4cb0558d611b 100644 --- a/src/runtime/contrib/nvshmem/init.cc +++ b/src/runtime/contrib/nvshmem/init.cc @@ -114,8 +114,10 @@ void NVSHMEMXCumoduleInit(void* cuModule) { // nvshmemx_cumodule_init. If not, we skip the cumodule initialization. 
if (status == NVSHMEM_STATUS_IS_INITIALIZED || status == NVSHMEM_STATUS_LIMITED_MPG || status == NVSHMEM_STATUS_FULL_MPG) { - int result = nvshmemx_cumodule_init(mod); - ICHECK_EQ(result, 0) << "nvshmemx_cumodule_init failed with error code: " << result; + // NOTE: we do not check the return value of nvshmemx_cumodule_init. + // The reason is that the input cuModule might not use any NVSHMEM functions, + // in which case nvshmemx_cumodule_init will fail. + nvshmemx_cumodule_init(mod); } }