
[Feature] Add cubin launcher utility as an extra header #283

Merged
yaoyaoding merged 35 commits into apache:main from yaoyaoding:cubin-loader
Nov 25, 2025

Conversation

@yaoyaoding (Contributor) commented Nov 22, 2025

This PR adds include/tvm/ffi/extra/cubin_launcher.h, a header-only utility that loads a CUBIN from a byte buffer or a file and launches its kernels. The utility is based on the CUDA Driver API. It is not compiled into libtvm_ffi.so, but is shipped with the package.


Usage

import torch
from tvm_ffi import cpp
from tvm_ffi.cpp import nvrtc

# Step 1: Define CUDA kernel source
cuda_source = """
extern "C" __global__ void add_one(float* x, float* y, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        y[idx] = x[idx] + 1.0f;
    }
}
"""

# Step 2: Compile to CUBIN using NVRTC
cubin_bytes = nvrtc.nvrtc_compile(cuda_source, name="kernel.cu")

# Step 3: Define C++ wrapper with embedded CUBIN
cpp_wrapper = """
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>

// Declare embedded CUBIN module
TVM_FFI_EMBED_CUBIN(my_cubin);

void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  // Get kernel from embedded CUBIN (cached for efficiency)
  static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(my_cubin, "add_one");

  // Prepare kernel arguments
  int64_t n = x.size(0);
  void* x_ptr = x.data_ptr();
  void* y_ptr = y.data_ptr();
  void* args[] = {&x_ptr, &y_ptr, &n};

  // Configure launch parameters
  tvm::ffi::dim3 grid((n + 255) / 256);
  tvm::ffi::dim3 block(256);

  // Get CUDA stream and launch
  DLDevice device = x.device();
  CUstream stream = static_cast<CUstream>(
      TVMFFIEnvGetStream(device.device_type, device.device_id));

  CUresult result = kernel.Launch(args, grid, block, stream);
  TVM_FFI_CHECK_CUDA_DRIVER_ERROR(result);
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, AddOne);
"""

# Step 4: Load module with embedded CUBIN
mod = cpp.load_inline(
    "my_module",
    cuda_sources=cpp_wrapper,
    embed_cubin={"my_cubin": cubin_bytes}
)

# Step 5: Use the kernel
x = torch.arange(1024, dtype=torch.float32, device="cuda")
y = torch.empty_like(x)
mod.add_one(x, y)

# Verify results
assert torch.allclose(y, x + 1)

C++ Core Usage

// Load CUBIN module from memory
tvm::ffi::CubinModule module(cubin_data);

// Get kernel by name
tvm::ffi::CubinKernel kernel = module["my_kernel"];

// Launch kernel (same as embedded example)
void* args[] = {...};
tvm::ffi::dim3 grid(...);
tvm::ffi::dim3 block(...);
CUstream stream = ...;

CUresult result = kernel.Launch(args, grid, block, stream);
TVM_FFI_CHECK_CUDA_DRIVER_ERROR(result);
  • documentation: see the guide at docs/guides/cubin_launcher.rst
  • more examples: see examples/cubin_launcher
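
For concreteness, below is an editor's sketch that fills in the placeholders of the C++ core usage snippet above, reusing the add_one kernel from the Python example. It assumes cubin_data holds the raw CUBIN bytes (the exact buffer type accepted by the CubinModule constructor may differ) and that the caller supplies the CUDA stream; in real use you would cache the module and kernel instead of reloading them on every call.

#include <cuda.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>

#include <string>

// Sketch: launch the add_one kernel from a CUBIN held in memory.
// x and y are device pointers to n floats; stream is the target CUDA stream.
void LaunchAddOne(const std::string& cubin_data, float* x, float* y, int n,
                  CUstream stream) {
  // Load the module and look up the kernel (cache these in real code).
  tvm::ffi::CubinModule module(cubin_data);
  tvm::ffi::CubinKernel kernel = module["add_one"];

  // Kernel arguments: a pointer to each parameter value, in declaration order.
  void* x_ptr = x;
  void* y_ptr = y;
  void* args[] = {&x_ptr, &y_ptr, &n};

  // One thread per element, 256 threads per block.
  tvm::ffi::dim3 grid((n + 255) / 256);
  tvm::ffi::dim3 block(256);

  CUresult result = kernel.Launch(args, grid, block, stream);
  TVM_FFI_CHECK_CUDA_DRIVER_ERROR(result);
}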

Benchmark

Kernel Launch Overhead Benchmark: Triton vs TVM-FFI
============================================================
              CPU: Intel(R) Core(TM) Ultra 9 285K
      CUDA device: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
  PyTorch version: 2.9.1+cu128
   Triton version: 3.5.1

Compiling empty Triton kernel to CUBIN...
Compiled CUBIN: 7,480 bytes

Benchmarking kernel launch overhead (10,000 calls)...
============================================================
  Triton launch                 :    4.911 μs/call
  TVM-FFI launch                :    2.034 μs/call

  Overhead: -58.59%
  TVM-FFI is 2.41x faster
============================================================

Benchmark script at examples/cubin_launcher/benchmark_overhead.py

Note: we did not check the dtype and shape constraints in Triton's launch path, so the number above is a lower bound on a typical Triton launch overhead; when those constraints are checked, Triton's launch overhead will be higher. TVM-FFI, on the other hand, checks the constraints in C++ and remains fast.


About the Triton Example
The Triton example is tricky, since the CUBIN's kernel parameters are not aligned with the Triton kernel's definition in Python. Checking the generated PTX, there are 5 parameters while there are only 3 in the kernel definition; the last two are unused. We need a Triton expert to write the host-side code that prepares the kernel parameters. To make this widely usable, we also need better documentation of Triton's calling convention for the generated CUBIN, but that is out of scope for this PR.

Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
@gemini-code-assist (Contributor)

Summary of Changes

Hello @yaoyaoding, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant new feature by adding a CUBIN launcher utility to the TVM-FFI framework. This utility simplifies the process of loading and executing pre-compiled CUDA kernels, offering greater flexibility for advanced GPU programming. It enables developers to dynamically load CUBIN files or embed them directly into libraries, facilitating integration with external GPU code generation tools and custom kernel execution workflows.

Highlights

  • New CUBIN Launcher Utility: Introduced include/tvm/ffi/extra/cubin_launcher.h, a header-only C++ utility that provides a lightweight wrapper around the CUDA Driver API for loading CUBIN modules and launching kernels. It supports loading CUBIN from both memory buffers and files.
  • CUDA Driver API Integration: The utility leverages CUDA Driver API's Library Management functions (cuLibraryLoadData, cuLibraryGetKernel, cuLaunchKernel) for robust and flexible CUBIN handling, including multi-GPU support via primary contexts and RAII-based resource management.
  • Comprehensive Examples and Documentation: Added a new examples/cubin_launcher directory with detailed README.md and Python examples demonstrating various usage patterns: embedding CUBIN data at compile time, dynamic CUBIN file loading at runtime, and an experimental integration with Triton kernels.
  • CMake Build System Enhancements: Updated the main CMakeLists.txt to prioritize virtual environments when searching for Python, and added a dedicated CMake configuration for the CUBIN launcher examples to handle CUBIN compilation and embedding.

@gemini-code-assist (Bot) left a comment


Code Review

This PR introduces a useful header-only utility for loading and launching CUBINs, along with comprehensive examples. The code is well-structured, particularly the cubin_launcher.h header which correctly uses RAII for resource management. I've found a critical resource leak in one of the C++ examples and a few typos in filenames and documentation. My review includes suggestions to fix the leak, improve thread-safety, and correct the typos.

Comment thread examples/cubin_launcher/src/lib_embedded.cc Outdated
Comment thread examples/cubin_launcher/README.md Outdated
Comment thread examples/cubin_launcher/README.md Outdated
Comment thread examples/cubin_launcher/embedded_cubin/main.py
Comment thread CMakeLists.txt
@yaoyaoding yaoyaoding requested a review from tqchen November 22, 2025 07:20
@yaoyaoding (Contributor, Author) commented Nov 22, 2025

The Triton kernel:

    # Define the kernel dynamically
    @triton.jit
    def square_kernel(X_ptr, Y_ptr, n, BLOCK: tl.constexpr = 1024):  # noqa
        pid = tl.program_id(0)
        start = pid * BLOCK
        offsets = start + tl.arange(0, BLOCK)
        mask = offsets < n
        x = tl.load(X_ptr + offsets, mask=mask, other=0.0)
        y = x * x
        tl.store(Y_ptr + offsets, y, mask=mask)

    # Trigger kernel compilation by doing a dummy call
    x_dummy = torch.ones(1024, dtype=torch.float32, device="cuda")
    y_dummy = torch.empty(1024, dtype=torch.float32, device="cuda")
    square_kernel[1, 1](x_dummy, y_dummy, 1024)

The PTX:

.visible .entry square_kernel(
        .param .u64 .ptr .global .align 1 square_kernel_param_0,
        .param .u64 .ptr .global .align 1 square_kernel_param_1,
        .param .u32 square_kernel_param_2,
        .param .u64 .ptr .global .align 1 square_kernel_param_3,
        .param .u64 .ptr .global .align 1 square_kernel_param_4
)
.reqntid 128
{
        .reg .pred      %p<5>;
        .reg .b32       %r<33>;
        .reg .b64       %rd<8>;
        .loc    1 53 0                          // example_triton_cubin.py:53:0
$L__func_begin0:
        .loc    1 53 0                          // example_triton_cubin.py:53:0

// %bb.0:
        ld.param.b64    %rd5, [square_kernel_param_0];
        ld.param.b64    %rd6, [square_kernel_param_1];
$L__tmp0:
        .loc    1 54 24                         // example_triton_cubin.py:54:24
        mov.u32         %r25, %ctaid.x;
        .loc    1 55 18                         // example_triton_cubin.py:55:18
        shl.b32         %r26, %r25, 10;
        ld.param.b32    %r27, [square_kernel_param_2];
        .loc    1 56 35                         // example_triton_cubin.py:56:35
        mov.u32         %r28, %tid.x;
        shl.b32         %r29, %r28, 2;
        and.b32         %r30, %r29, 508;
        .loc    1 56 22                         // example_triton_cubin.py:56:22
        or.b32  %r31, %r30, %r26;
        or.b32  %r32, %r31, 512;
        .loc    1 57 21                         // example_triton_cubin.py:57:21
        setp.lt.s32     %p1, %r31, %r27;
        setp.lt.s32     %p2, %r32, %r27;
        .loc    1 58 24                         // example_triton_cubin.py:58:24
        mul.wide.s32    %rd7, %r31, 4;
        add.s64         %rd1, %rd5, %rd7;
        add.s64         %rd2, %rd1, 2048;
        mov.b32         %r5, 0;
        .loc    1 58 16                         // example_triton_cubin.py:58:16
        // begin inline asm
        mov.u32 %r1, %r5;
        mov.u32 %r2, %r5;
        mov.u32 %r3, %r5;
        mov.u32 %r4, %r5;
        @%p1 ld.global.v4.b32 { %r1, %r2, %r3, %r4 }, [ %rd1 + 0 ];
        // end inline asm
        // begin inline asm
        mov.u32 %r9, %r5;
        mov.u32 %r10, %r5;
        mov.u32 %r11, %r5;
        mov.u32 %r12, %r5;
        @%p2 ld.global.v4.b32 { %r9, %r10, %r11, %r12 }, [ %rd2 + 0 ];
        // end inline asm
        .loc    1 59 12                         // example_triton_cubin.py:59:12
        mul.f32         %r17, %r1, %r1;
        mul.f32         %r18, %r2, %r2;
        mul.f32         %r19, %r3, %r3;
        mul.f32         %r20, %r4, %r4;
        mul.f32         %r21, %r9, %r9;
        mul.f32         %r22, %r10, %r10;
        mul.f32         %r23, %r11, %r11;
        mul.f32         %r24, %r12, %r12;
        .loc    1 60 21                         // example_triton_cubin.py:60:21
        add.s64         %rd3, %rd6, %rd7;
        add.s64         %rd4, %rd3, 2048;
        .loc    1 60 30                         // example_triton_cubin.py:60:30
        // begin inline asm
        @%p1 st.global.v4.b32 [ %rd3 + 0 ], { %r17, %r18, %r19, %r20 };
        // end inline asm
        // begin inline asm
        @%p2 st.global.v4.b32 [ %rd4 + 0 ], { %r21, %r22, %r23, %r24 };
        // end inline asm
        .loc    1 60 4                          // example_triton_cubin.py:60:4
        ret;
$L__tmp1:
$L__func_end0:
                                        // -- End function
}

The two extra parameters are not used. Maybe they are for features that this simple kernel does not exercise.

Comment thread include/tvm/ffi/extra/cuda/cubin_launcher.h Outdated
Comment thread include/tvm/ffi/extra/cubin_launcher.h Outdated
Comment thread include/tvm/ffi/extra/cubin_launcher.h Outdated
Comment thread include/tvm/ffi/extra/cubin_launcher.h Outdated
Comment thread examples/cubin_launcher/src/lib_embedded.cc Outdated

// External symbols for embedded CUBIN data (linked via objcopy)
extern "C" const char __cubin_data[];
extern "C" const char __cubin_data_end[];
Member

One thing to consider is how to avoid symbol conflicts when we link multiple such embedded files.

Here is what Gemini suggests, done in the following steps. Maybe we can have a Python tool, tvm.ffi.cpp.embed_cubin(output_object, input_object, cubin, key="env"), to do that; we can also provide a CMake macro for those who prefer the tvm-ffi CMake.

  • Step 1: Compile the C++ Source (source.o) Compile your C++ code normally. Ensure you declare the symbols as extern so the compiler creates an "undefined reference" (a hole to be filled later). g++ -c source.cc -o source.o
  • Step 2: Convert Binary to Object (blob_raw.o) Use objcopy to wrap the raw binary file into a linkable object file. This creates Global symbols by default. objcopy -I binary -O elf64-x86-64 kernel.cubin blob_raw.o
  • Step 3: Rename Symbols (blob_renamed.o) Change the auto-generated names (e.g., _binary_kernel_start) to the specific names your C++ code expects (tvm_ffi...). objcopy --redefine-sym old_name=new_name blob_raw.o blob_renamed.o
  • Step 4: Partial Link / Merge (merged.o) Use ld -r to fuse the code (source.o) and the data (blob_renamed.o) together. This resolves the "undefined reference." ld -r source.o blob_renamed.o -o merged.o
  • Step 5: Localize Symbols (final.o) Crucial Last Step: Now that the code and data are in the same file, use objcopy to change the symbols from Global to Local. This hides them from the outside world (Internal Linkage). objcopy --localize-symbol=__tvm_ffi__cubin_data merged.o final.o
# ==========================================
# Configuration
# ==========================================

# Files
BINARY_FILE := kernel.cubin
SOURCE_FILE := source.cc
OUTPUT_OBJ  := final_module.o

# The symbol names your C++ code uses (extern "C")
SYM_NAME      := __tvm_ffi__cubin_data
SYM_NAME_END  := __tvm_ffi__cubin_data_end

# Compiler settings
CXX      := g++
CXXFLAGS := -O2 -Wall -fPIC
LD       := ld
OBJCOPY  := objcopy

# ------------------------------------------
# Internal Calculation for objcopy default names
# objcopy converts "kernel.cubin" -> "_binary_kernel_cubin_start"
# We replace dots and slashes with underscores to match objcopy's behavior.
# ------------------------------------------
BINARY_FLAT   := $(subst /,_,$(subst .,_,$(BINARY_FILE)))
DEFAULT_START := _binary_$(BINARY_FLAT)_start
DEFAULT_END   := _binary_$(BINARY_FLAT)_end
DEFAULT_SIZE  := _binary_$(BINARY_FLAT)_size

# ==========================================
# Rules
# ==========================================

.PHONY: all clean check

all: $(OUTPUT_OBJ)

# 1. Compile the C++ source into an object file.
#    (Contains undefined references to the symbols)
source.o: $(SOURCE_FILE)
	@echo "[1/5] Compiling Source..."
	$(CXX) $(CXXFLAGS) -c $< -o $@

# 2. Convert the raw binary into an ELF object file.
#    (Symbols are Global and named _binary_kernel_cubin_start)
blob_raw.o: $(BINARY_FILE)
	@echo "[2/5] Converting Binary to Object..."
	$(OBJCOPY) -I binary -O elf64-x86-64 $< $@

# 3. Rename the symbols to match your C++ declaration.
#    (Still Global, but names match __tvm_ffi__...)
blob_renamed.o: blob_raw.o
	@echo "[3/5] Renaming Symbols..."
	$(OBJCOPY) \
		--redefine-sym $(DEFAULT_START)=$(SYM_NAME) \
		--redefine-sym $(DEFAULT_END)=$(SYM_NAME_END) \
		--strip-symbol=$(DEFAULT_SIZE) \
		$< $@

# 4. Partial Link (Merge).
#    (Fuses source.o and blob.o. Code can now see Data.)
merged.o: source.o blob_renamed.o
	@echo "[4/5] Linking (Partial Merge)..."
	$(LD) -r source.o blob_renamed.o -o $@

# 5. Localize Symbols.
#    (Hides the symbols from the outside world. Global D -> Local d)
$(OUTPUT_OBJ): merged.o
	@echo "[5/5] Finalizing: Hiding Symbols..."
	$(OBJCOPY) \
		--localize-symbol=$(SYM_NAME) \
		--localize-symbol=$(SYM_NAME_END) \
		$< $@
	@echo "Success! Created $(OUTPUT_OBJ)"

# ==========================================
# Utilities
# ==========================================

# Helper to prove the symbols are local
check: $(OUTPUT_OBJ)
	@echo "Checking symbol visibility in $(OUTPUT_OBJ)..."
	@echo "Look for lowercase 'd' (local data) or 'r' (local read-only):"
	@nm $(OUTPUT_OBJ) | grep __tvm_ffi__

clean:
	rm -f *.o

Contributor Author

Claude Sonnet gives slightly different steps. See the documentation of the macro TVM_FFI_EMBED_CUBIN.

@yaoyaoding (Contributor, Author) Nov 25, 2025

I've added a Python script to embed a CUBIN into an object file:

usage: python -m tvm_ffi.utils.embed_cubin [-h] --output-obj PATH --input-obj PATH --cubin PATH --name NAME [-v]

Embed CUBIN data into existing object files that use TVM_FFI_EMBED_CUBIN macro

options:
  -h, --help         show this help message and exit
  --output-obj PATH  Output object file path (e.g., new.o)
  --input-obj PATH   Input object file path containing TVM_FFI_EMBED_CUBIN usage (e.g., old.o)
  --cubin PATH       Input CUBIN file path (e.g., kernel.cubin)
  --name NAME        Name used in TVM_FFI_EMBED_CUBIN macro (e.g., my_kernels)
  -v, --verbose      Print detailed command output

Examples:
  # Basic usage
  python -m tvm_ffi.utils.embed_cubin \
      --output-obj new.o \
      --input-obj old.o \
      --cubin kernel.cubin \
      --name my_kernels

  # With verbose output
  python -m tvm_ffi.utils.embed_cubin \
      --output-obj new.o \
      --input-obj old.o \
      --cubin kernel.cubin \
      --name my_kernels \
      --verbose

Workflow:
  1. Compile C++ code that uses TVM_FFI_EMBED_CUBIN to create old.o
  2. Compile CUDA kernel to CUBIN (e.g., using nvcc or NVRTC)
  3. Use this tool to merge them into new.o
  4. Link new.o into your final shared library

Usage in C++ code (source compiled to old.o):
  TVM_FFI_EMBED_CUBIN(my_kernels);
  auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(my_kernels, "kernel_name");

Requirements:
  - GNU binutils (ld and objcopy) must be available in PATH
  - Linux/Unix platform (Windows uses different embedding mechanisms)

For CMake, we have two utility functions in cmake/Utils/EmbedCubin.cmake:

tvm_ffi_generate_cubin(
  OUTPUT <output_cubin_file>
  SOURCE <cuda_source_file>
  [ARCH <architecture>]
  [OPTIONS <extra_nvcc_options>...]
  [DEPENDS <additional_dependencies>...]
)
tvm_ffi_embed_cubin(
  OUTPUT <output_object_file>
  SOURCE <source_file>
  CUBIN <cubin_file>
  NAME <symbol_name>
  [DEPENDS <additional_dependencies>...]
)

Comment thread examples/cubin_launcher/example_triton_cubin.py Outdated
@tqchen (Member) commented Nov 22, 2025

API-wise, I think we can streamline it a bit further:

cpp_source = """
#include <tvm/ffi/extra/cuda/cubin_launcher.h>

TVM_FFI_EMBED_CUBIN(env);

void AddTwo(TensorView a, TensorView b) {
   static ffi::CubinKernel kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(env, "add_two");
   kernel.launch(...);
}
"""

cubin : bytes = compile_cubin_from_nvrtc(cuda_source);

cubin : bytes = compile_cubin_from_triton();
tvm_ffi.cpp.load_inline(cpp_source, embed_cubin={"env": cubin});

Comment thread include/tvm/ffi/extra/cuda/cubin_launcher.h
@vinx13 (Member) commented Nov 23, 2025

The last two args of the Triton kernel are scratch memory; they need to be allocated using the sizes in the metadata. See https://github.com/triton-lang/triton/blob/main/third_party/nvidia/backend/driver.py#L700-L714
You will also need to handle constexpr args that are not in the cubin args.

@yaoyaoding (Contributor, Author)

The last two args of the Triton kernel are scratch memory; they need to be allocated using the sizes in the metadata. See https://github.com/triton-lang/triton/blob/main/third_party/nvidia/backend/driver.py#L700-L714
You will also need to handle constexpr args that are not in the cubin args.

Cool, thanks for the information!
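
To make the calling convention concrete, here is an editor's sketch (not code from this PR or from Triton) of how the host-side argument array for the five-parameter square_kernel PTX entry above could be packed. The scratch-buffer handling is simplified: their sizes must come from Triton's compilation metadata as noted above, and they are assumed to have been allocated by the caller.

#include <cuda.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>

// Editor's sketch: launch square_kernel from its CUBIN. BLOCK (1024) is a
// constexpr baked into the CUBIN and is not passed at launch time. The two
// trailing parameters are the scratch buffers described above; they are
// assumed to be pre-allocated with the sizes from Triton's metadata.
void LaunchSquareKernel(tvm::ffi::CubinKernel& kernel, float* x, float* y, int n,
                        void* scratch0, void* scratch1, CUstream stream) {
  // Order must match the PTX parameters: X_ptr, Y_ptr, n, then the two scratch pointers.
  void* args[] = {&x, &y, &n, &scratch0, &scratch1};

  // Each block covers BLOCK = 1024 elements; the PTX requires 128 threads
  // (.reqntid 128), so each thread handles 8 elements via the vectorized accesses.
  tvm::ffi::dim3 grid((n + 1023) / 1024);
  tvm::ffi::dim3 block(128);

  TVM_FFI_CHECK_CUDA_DRIVER_ERROR(kernel.Launch(args, grid, block, stream));
}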

@Ubospica (Contributor) left a comment


Thanks for the PR! This is very helpful for further integrating Triton kernels.

Comment thread python/tvm_ffi/cpp/nvrtc.py
Comment thread include/tvm/ffi/extra/cuda/cubin_launcher.h
Comment thread examples/cubin_launcher/src/lib_dynamic.cc Outdated
Comment thread docs/guides/cubin_launcher.md Outdated
@yaoyaoding (Contributor, Author) commented Nov 25, 2025

Hi @tqchen @junrushao @Ubospica, the PR is ready for review; could you take a pass?

@tqchen (Member) left a comment


looking great, some final comments

Comment thread include/tvm/ffi/extra/cuda/cubin_launcher.h Outdated
Comment thread include/tvm/ffi/extra/cuda/cubin_launcher.h
Comment thread include/tvm/ffi/extra/cuda/cubin_launcher.h Outdated
Comment thread include/tvm/ffi/extra/cuda/cubin_launcher.h
Comment thread python/tvm_ffi/cpp/extension.py Outdated
for (int device_id = 0; device_id < device_count; ++device_id) {
// Query device's maximum shared memory per block
cudaDeviceProp prop;
err = cudaGetDeviceProperties(&prop, device_id);


// Set max dynamic shared memory for all devices during initialization
// This allows the kernel to use maximum available shared memory when needed
int device_count = 0;
Member

Move this to a function kernel.SetMaxDynamicSharedMemory(size_t static_mem_size, int64_t dynamic_smem_max = -1); consider making it private and a friend of CubinModule.

Here -1 means deducing the maximum from the device's maximum value minus static_mem_size.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add CubinModule.GetKernelWithMaxDynamicSharedMemory(name, static_mem_size, dynamic_smem_max);

This is advanced mode since not all kernels need it
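
For illustration, a hypothetical call site for the proposed API could look like the sketch below; the method name and signature follow the suggestion above and are not part of the header in this PR.

#include <tvm/ffi/extra/cuda/cubin_launcher.h>

#include <cstddef>

// Hypothetical sketch of the proposed API (reviewer's suggestion only;
// GetKernelWithMaxDynamicSharedMemory does not exist in this PR's header).
tvm::ffi::CubinKernel GetKernelWithLargeSmem(tvm::ffi::CubinModule& module,
                                             size_t static_mem_size) {
  // dynamic_smem_max = -1: deduce the limit as the device maximum minus static_mem_size.
  return module.GetKernelWithMaxDynamicSharedMemory("my_kernel", static_mem_size,
                                                    /*dynamic_smem_max=*/-1);
}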

@tqchen (Member) left a comment

LGTM, please fix the CI issue.

@tqchen (Member) commented Nov 25, 2025

/gemini review

@gemini-code-assist (Bot) left a comment

Code Review

This is an excellent pull request that introduces a powerful cubin_launcher utility. The feature is well-designed with a clean C++ API, comprehensive build system integration via CMake and Python utilities, and thorough documentation and examples. The code quality is high, and the addition of NVRTC and Triton integration examples is particularly valuable. I have a couple of minor suggestions for improving the documentation clarity, but overall this is a fantastic contribution.

Comment thread docs/guides/cubin_launcher.rst Outdated
Comment thread include/tvm/ffi/extra/cuda/cubin_launcher.h Outdated
@yaoyaoding (Contributor, Author)

/gemini review

@gemini-code-assist (Bot) left a comment

Code Review

This PR introduces a powerful CUBIN launcher utility, enabling efficient execution of pre-compiled CUDA kernels. The changes are comprehensive, including a new C++ header, CMake and Python utilities for embedding CUBINs, NVRTC integration for runtime compilation, and extensive examples and documentation. The implementation is well-structured and the new features are a great addition. I've found a couple of minor issues: one correctness issue in the C++ header and a potential caching issue in the Python build extension. After addressing these, this PR will be in excellent shape.

Comment thread include/tvm/ffi/extra/cuda/cubin_launcher.h
Comment thread python/tvm_ffi/cpp/extension.py
Comment thread python/tvm_ffi/cpp/extension.py Outdated
@yaoyaoding (Contributor, Author)

/gemini review

@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces a powerful new feature for launching pre-compiled CUDA kernels from CUBIN files. The implementation is comprehensive, including a C++ header-only library, CMake and Python build system integrations, runtime compilation utilities via NVRTC, and extensive examples and documentation. The code is well-structured and robust. I have one suggestion to improve the reliability of the Python build extension in different environments.

Comment thread python/tvm_ffi/cpp/extension.py
@yaoyaoding yaoyaoding merged commit d49effd into apache:main Nov 25, 2025
7 checks passed
@yaoyaoding yaoyaoding deleted the cubin-loader branch November 25, 2025 23:35