106 changes: 78 additions & 28 deletions python/tvm/tir/build.py
@@ -17,8 +17,7 @@

# pylint: disable=invalid-name
"""The build utils in python."""
from typing import Union, Optional, Dict
import enum
from typing import Union, Optional, Dict, Tuple

import tvm
from tvm import ir
@@ -28,44 +27,95 @@
from tvm.target import Target


def split_host_device_mods(mod):
def split_host_device_mods(mod: IRModule) -> Tuple[IRModule, Dict[Target, IRModule]]:
"""Split an IRModule into host and device modules.

This function takes an IRModule containing functions with different target attributes
and separates them into host (CPU) and device (GPU/accelerator) modules. Functions
are categorized based on their target attribute in func_attr.

Parameters
----------
mod : tvm.IRModule
The input module to split
The input module to split.
The module should contain functions with target attributes in their func_attr.
Functions with "cpu" in their target string are considered host functions,
while others are considered device functions.

Returns
-------
host_mod : tvm.IRModule
The module containing host functions
The module containing host functions (CPU-targeted functions)
device_mod_dict : Dict[Target, tvm.IRModule]
A dict mapping targets to device modules
A dict mapping targets to device modules. Each device module contains
functions targeting the same device (e.g., CUDA GPU, OpenCL, etc.)

Examples
--------
Given an IRModule with the following functions:

.. code-block:: python

@I.ir_module
class Module:
@T.prim_func(private=True)
def add(a: T.int32, b: T.int32) -> T.int32:
T.func_attr({"target": T.target({"arch": "sm_90", "keys": ["cuda", "gpu"],
"kind": "cuda", "max_num_threads": 1024}))
return a + b

@T.prim_func(private=True)
def add_host(a: T.int32, b: T.int32) -> T.int32:
T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "c"}))
return a + b

@T.prim_func
def main_kernel(A: T.handle, B: T.handle, C: T.handle, length: T.int32):
T.func_attr({"target": T.target({"arch": "sm_90", "keys": ["cuda", "gpu"],
"kind": "cuda"}),
"calling_conv": 2, # kDeviceKernelLaunch for device kernels
"tir.is_global_func": True})
# ... kernel implementation

@T.prim_func
def main(self_handle: T.handle, args: T.handle, num_args: T.int32, result: T.handle):
T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "c"}),
"calling_conv": 1, # kCPackedFunc for entry functions
"tir.is_entry_func": True})
# ... main function implementation

The function will return:
- host_mod: Contains `add_host` and `main` functions (CPU targets)
- device_mod_dict: Contains a CUDA module with `add` and `main_kernel` functions

Notes
-----
- Functions are categorized based on string matching of their target attribute
- Functions with "cpu" in the target string are considered host functions
- Device functions are grouped by their target to create separate modules
- The function uses string-based target matching due to target hash limitations
- Functions are expected to carry a `calling_conv` attribute in their func_attr
  (kDefault is assumed when the attribute is absent):
    - Private helper functions (private=True): `calling_conv: 0` (kDefault)
    - Public entry functions: `calling_conv: 1` (kCPackedFunc)
    - Device kernel functions: `calling_conv: 2` (kDeviceKernelLaunch)
"""

class CallConv(enum.IntEnum):
"""Enum representing different calling conventions.
Corresponds to the C++ tvm::ir::CallingConv enum.
"""

kDefault = 0
kCPackedFunc = 1
kDeviceKernelLaunch = 2

host_mod = tvm.tir.transform.Filter(
lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault))
!= int(CallConv.kDeviceKernelLaunch)
)(mod)
device_mod = tvm.tir.transform.Filter(
lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault))
== int(CallConv.kDeviceKernelLaunch)
)(mod)
device_mod_dict = {}
host_mod = tvm.tir.transform.Filter(lambda f: "cpu" in str(f.attrs.get("target", "cpu")))(mod)
device_mod = tvm.tir.transform.Filter(lambda f: "cpu" not in str(f.attrs.get("target", "cpu")))(
mod
)
# TODO(syfeng): Here we use str as key since target hash is not correct
target_str2target = {}
device_func_dict = {}
device_mod_dict: Dict[Target, IRModule] = {}
for gv, func in device_mod.functions.items():
device_mod_dict.setdefault(func.attrs.get("target", None), dict()).update({gv: func})
for target, funcs in device_mod_dict.items():
device_mod_dict[target] = tvm.IRModule(funcs, attrs=device_mod.attrs)
target = func.attrs.get("target", None)
target_str = str(target) if target is not None else ""
target_str2target[target_str] = target # This might be overridden by the last one
Contributor:
We want to make sure in which cases different Target objects might end up with the same string representation target_str.

Member:
For now the string uniquely maps to a target, so it is OK here, but it would be good to document this invariant, and as @Hzfengsy commented, we can fix it once target hash is supported.

device_func_dict.setdefault(target_str, dict()).update({gv: func})
for target_str in target_str2target.keys():
target = target_str2target[target_str]
device_mod_dict[target] = tvm.IRModule(device_func_dict[target_str], attrs=device_mod.attrs)
return host_mod, device_mod_dict
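For reference, a minimal usage sketch of the new splitting behavior. This is not part of the diff; it assumes the `Module` class from the docstring example above and that the helper is importable from `tvm.tir.build`:

```python
# Hedged sketch: only illustrates the intended behavior of split_host_device_mods.
import tvm
from tvm.tir.build import split_host_device_mods

# `Module` is assumed to be the IRModule defined in the docstring example above.
host_mod, device_mod_dict = split_host_device_mods(Module)

# The host module keeps the CPU-targeted functions (`add_host`, `main`).
print(host_mod)

# Device functions are grouped per target; each value is an IRModule.
for target, dev_mod in device_mod_dict.items():
    print(target, [str(gv) for gv in dev_mod.functions])
```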


@@ -162,7 +212,7 @@ def build(
# Step 3: Bind the target to the input module
mod = tvm.tir.transform.BindTarget(target_to_bind)(mod)

# Step 4: Apply the tir pipeline
if pipeline is not None:
# custom pipeline
if isinstance(pipeline, str):
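As a rough illustration of the pipeline handling above (Steps 3-4), a hedged sketch; the exact `tir.build` call and target string are assumptions for illustration, not part of this diff:

```python
# Hedged sketch of driving the build flow shown above; details are assumed.
import tvm
from tvm.script import tir as T

@T.prim_func
def add_one(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    for i in range(16):
        B[i] = A[i] + T.float32(1.0)

mod = tvm.IRModule({"add_one": add_one})

# Leaving pipeline unset applies the default tir pipeline; passing a string
# would select a registered pipeline by name (the isinstance(pipeline, str)
# branch above).
rt_mod = tvm.tir.build(mod, target="llvm")
```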
4 changes: 3 additions & 1 deletion src/target/build_common.h
@@ -66,7 +66,9 @@ inline std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
}
}
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
fmap[static_cast<std::string>(global_symbol.value())] = info;
if (global_symbol) {
fmap[static_cast<std::string>(global_symbol.value())] = info;
}
}
return fmap;
}
9 changes: 6 additions & 3 deletions src/target/opt/build_cuda_on.cc
@@ -134,9 +134,12 @@ runtime::Module BuildCUDA(IRModule mod, Target target) {
for (auto [gvar, base_func] : mod->functions) {
ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
auto prim_func = Downcast<PrimFunc>(base_func);
auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
auto calling_conv =
prim_func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDefault));
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch ||
calling_conv == CallingConv::kDefault)
<< "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch or "
"CallingConv::kDefault";
functions.Set(gvar, prim_func);
}

14 changes: 13 additions & 1 deletion src/target/source/codegen_cuda.cc
@@ -140,7 +140,19 @@ void CodeGenCUDA::Init(bool output_ssa) {
ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
}

void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ "; }
void CodeGenCUDA::PrintFunctionSignature(const String& function_name, const PrimFunc& func,
std::ostream& os) {
auto calling_conv =
func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDefault));
if (calling_conv == CallingConv::kDeviceKernelLaunch) {
os << "extern \"C\" __global__ ";
} else if (calling_conv == CallingConv::kDefault) {
os << "extern \"C\" __device__ ";
} else {
LOG(FATAL) << "Unsupported calling convention for cuda codegen: " << calling_conv;
}
CodeGenC::PrintFunctionSignature(function_name, func, os);
}

class ThreadIdxExtractor : public tir::StmtVisitor {
private:
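The practical effect of the new `PrintFunctionSignature` is that kernel entry points (kDeviceKernelLaunch) keep the `extern "C" __global__` prefix while private device helpers with the default calling convention are emitted as `extern "C" __device__` functions. A hedged way to observe this from Python, assuming the docstring's `Module` and that the CUDA source is exposed by the first imported module:

```python
# Hedged sketch: inspect the generated CUDA source after building for CUDA.
rt_mod = tvm.tir.build(Module, target="cuda")
cuda_src = rt_mod.imported_modules[0].get_source()

# Kernels (calling_conv == kDeviceKernelLaunch) keep the __global__ prefix.
assert 'extern "C" __global__' in cuda_src
# Private helpers such as `add` (kDefault) should now appear as __device__.
assert 'extern "C" __device__' in cuda_src
```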
3 changes: 2 additions & 1 deletion src/target/source/codegen_cuda.h
@@ -46,7 +46,8 @@ class CodeGenCUDA final : public CodeGenC {
enable_fp4_ || need_math_constants_h_ || need_mma_h_);
}
// override behavior
void PrintFuncPrefix(std::ostream& os) final;
void PrintFunctionSignature(const String& function_name, const PrimFunc& func,
std::ostream& os) final;
void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final; // NOLINT(*)
void VisitStmt_(const ForNode* op) final;
void PrintStorageSync(const CallNode* op) final;