Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
d63ef67
init
Kathryn-cat Nov 25, 2025
882f7bd
upd
Kathryn-cat Nov 27, 2025
c4514c4
upd
Kathryn-cat Nov 27, 2025
7e76a59
upd
Kathryn-cat Nov 28, 2025
9df7435
upd
Kathryn-cat Nov 28, 2025
e8d1657
upd
Kathryn-cat Nov 28, 2025
5601ddb
addressed segfault
Kathryn-cat Nov 28, 2025
9040d59
remove host-side deps; unit test passed except CUtensorMap
Kathryn-cat Nov 28, 2025
568412e
fixed int8 tests
Kathryn-cat Nov 28, 2025
d10e9fe
CUtensorMap patch
Kathryn-cat Nov 28, 2025
a498b1f
dual compilation problem fixed
Kathryn-cat Nov 29, 2025
5f905e2
TVM_CUDA_COMPILE_MODE
Kathryn-cat Nov 29, 2025
925022f
unit tests
Kathryn-cat Nov 29, 2025
073f51e
remove deps in cmake
Kathryn-cat Nov 30, 2025
4008da0
update call site
Kathryn-cat Nov 30, 2025
7a28348
gpu ci env
Kathryn-cat Nov 30, 2025
734bf71
lint
Kathryn-cat Nov 30, 2025
9c36b0b
skip test if cuda-python is not available
Kathryn-cat Nov 30, 2025
c598ba3
robustify CUDA header files search
Kathryn-cat Nov 30, 2025
c8969ec
fix CI
Kathryn-cat Nov 30, 2025
6158357
fixed nvshmem
Kathryn-cat Dec 2, 2025
fe4780e
nvrtc nvshmem compile
Kathryn-cat Dec 2, 2025
7856b5c
remove nvshmem tests
Kathryn-cat Dec 12, 2025
143707e
fall back to nvcc for nvshmem
Kathryn-cat Dec 12, 2025
2bcbc03
update error message
Kathryn-cat Dec 12, 2025
b7fb6ca
lint
Kathryn-cat Dec 12, 2025
e5a9c0e
lint
Kathryn-cat Dec 12, 2025
92514c1
lint
Kathryn-cat Dec 12, 2025
94f1f56
lint
Kathryn-cat Dec 12, 2025
4b11de2
lint
Kathryn-cat Dec 12, 2025
4b14e38
lint
Kathryn-cat Dec 12, 2025
4a92047
add fast math to enable perf equal for nvcc and nvrtc
Kathryn-cat Dec 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
lint
  • Loading branch information
Kathryn-cat committed Dec 12, 2025
commit e5a9c0e73d93b9782588926296e582e4ae1121c9
46 changes: 13 additions & 33 deletions python/tvm/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ def compile_cuda(
# is not yet implemented
use_nvshmem = "#include <nvshmem.h>" in code or "#include <nvshmemx.h>" in code
if compiler == "nvcc" or use_nvshmem:
return _compile_cuda_nvcc(
code, target_format, arch, options, path_target, use_nvshmem
)
return _compile_cuda_nvcc(code, target_format, arch, options, path_target, use_nvshmem)
elif compiler == "nvrtc":
return _compile_cuda_nvrtc(code, target_format, arch, options)
else:
Expand Down Expand Up @@ -276,9 +274,7 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
"or set compiler='nvcc' for fatbin compilation."
)
if target_format not in ["cubin", "ptx"]:
raise ValueError(
f"target_format must be 'cubin' or 'ptx', got: {target_format}"
)
raise ValueError(f"target_format must be 'cubin' or 'ptx', got: {target_format}")

# Validate options
if options is not None and not isinstance(options, (str, list)):
Expand Down Expand Up @@ -312,9 +308,7 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
str.encode(code_filtered), b"tvm_kernels.cu", 0, None, None
)
if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError(
f"Failed to create NVRTC program: {nvrtc.nvrtcGetErrorString(result)}"
)
raise RuntimeError(f"Failed to create NVRTC program: {nvrtc.nvrtcGetErrorString(result)}")

# Prepare compilation options
cuda_path = find_cuda_path()
Expand Down Expand Up @@ -344,9 +338,7 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
include_paths.append(arch_include)

# Verify we can find essential CUDA headers
if not any(
os.path.isfile(os.path.join(p, "cuda_runtime.h")) for p in include_paths
):
if not any(os.path.isfile(os.path.join(p, "cuda_runtime.h")) for p in include_paths):
raise RuntimeError(
f"Cannot find CUDA headers in {cuda_path}. "
f"Searched in: {include_paths}. "
Expand All @@ -362,9 +354,7 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
if isinstance(options, str):
compile_opts.append(options.encode())
else:
compile_opts.extend(
[opt.encode() if isinstance(opt, str) else opt for opt in options]
)
compile_opts.extend([opt.encode() if isinstance(opt, str) else opt for opt in options])

# Compile
(result,) = nvrtc.nvrtcCompileProgram(prog, len(compile_opts), compile_opts)
Expand All @@ -389,30 +379,22 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
result, binary_size = nvrtc.nvrtcGetCUBINSize(prog)
if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
nvrtc.nvrtcDestroyProgram(prog)
raise RuntimeError(
f"Failed to get CUBIN size: {nvrtc.nvrtcGetErrorString(result)}"
)
raise RuntimeError(f"Failed to get CUBIN size: {nvrtc.nvrtcGetErrorString(result)}")
binary_buf = bytearray(binary_size)
(result,) = nvrtc.nvrtcGetCUBIN(prog, binary_buf)
if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
nvrtc.nvrtcDestroyProgram(prog)
raise RuntimeError(
f"Failed to get CUBIN: {nvrtc.nvrtcGetErrorString(result)}"
)
raise RuntimeError(f"Failed to get CUBIN: {nvrtc.nvrtcGetErrorString(result)}")
else: # ptx
result, binary_size = nvrtc.nvrtcGetPTXSize(prog)
if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
nvrtc.nvrtcDestroyProgram(prog)
raise RuntimeError(
f"Failed to get PTX size: {nvrtc.nvrtcGetErrorString(result)}"
)
raise RuntimeError(f"Failed to get PTX size: {nvrtc.nvrtcGetErrorString(result)}")
binary_buf = bytearray(binary_size)
(result,) = nvrtc.nvrtcGetPTX(prog, binary_buf)
if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
nvrtc.nvrtcDestroyProgram(prog)
raise RuntimeError(
f"Failed to get PTX: {nvrtc.nvrtcGetErrorString(result)}"
)
raise RuntimeError(f"Failed to get PTX: {nvrtc.nvrtcGetErrorString(result)}")

# Clean up
nvrtc.nvrtcDestroyProgram(prog)
Expand Down Expand Up @@ -543,9 +525,9 @@ def find_nvshmem_paths() -> Tuple[str, str]:
if os.path.isfile(os.path.join(include_path, "nvshmem.h")):
for lib_path in lib_paths_to_check:
# Check for both static (.a) and shared (.so) libraries
if os.path.isfile(
os.path.join(lib_path, "libnvshmem.a")
) or os.path.isfile(os.path.join(lib_path, "libnvshmem.so")):
if os.path.isfile(os.path.join(lib_path, "libnvshmem.a")) or os.path.isfile(
os.path.join(lib_path, "libnvshmem.so")
):
return include_path, lib_path

error_message = [
Expand Down Expand Up @@ -604,9 +586,7 @@ def tvm_callback_cuda_compile(code, target): # pylint: disable=unused-argument
if compiler == "nvcc":
return compile_cuda(code, target_format="fatbin", compiler="nvcc")

raise ValueError(
f"Invalid TVM_CUDA_COMPILE_MODE: {compiler}. Expected 'nvcc' or 'nvrtc'."
)
raise ValueError(f"Invalid TVM_CUDA_COMPILE_MODE: {compiler}. Expected 'nvcc' or 'nvrtc'.")


@tvm_ffi.register_global_func("tvm_callback_libdevice_path")
Expand Down
12 changes: 3 additions & 9 deletions python/tvm/script/ir_builder/tir/external_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ def _format_tvm_module_metadata(self, kernel_name, arg_types, launch_param_tags)
)
return tvm_metadata

def _create_cuda_module(
self, ptx, kernel_arg_types, launch_param_tags, kernel_name
):
def _create_cuda_module(self, ptx, kernel_arg_types, launch_param_tags, kernel_name):
"""
Create a CUDA module from PTX and metadata.

Expand Down Expand Up @@ -120,15 +118,11 @@ def compile_to_device_module( # pylint: disable=arguments-differ
"['threadIdx.x', 'threadIdx.y', 'threadIdx.z']"
)
assert isinstance(grid[0], (list, tuple)) and isinstance(grid[1], (list, tuple))
launch_param_tags = ["blockIdx.x", "blockIdx.y", "blockIdx.z"][
: len(grid[0])
] + [
launch_param_tags = ["blockIdx.x", "blockIdx.y", "blockIdx.z"][: len(grid[0])] + [
"threadIdx.x",
"threadIdx.y",
"threadIdx.z",
][
: len(grid[1])
]
][: len(grid[1])]
runtime_args = [arg if hasattr(arg, "dtype") else const(arg) for arg in args]
kernel_arg_types = [arg.dtype for arg in runtime_args]
runtime_args = runtime_args + list(grid[0]) + list(grid[1])
Expand Down