Skip to content

Commit 532db33

Browse files
authored
[TIR] Fix host/device function check for build (apache#18199)
This PR fixes a bug in deciding whether a function is a host or a device function in the TIR build. Previously, the decision was made by checking whether `"cpu"` is a substring of the target string. This check fails for ROCm targets, which usually come with an `"mcpu"` attribute whose name also contains `"cpu"`, so ROCm device functions were misclassified as host functions. This PR fixes the issue by checking the target kind instead: functions whose target kind is `"llvm"` or `"c"` are treated as host functions, and all others as device functions.
1 parent d6c3dea commit 532db33

2 files changed

Lines changed: 15 additions & 11 deletions

File tree

python/tvm/tir/build.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717

1818
# pylint: disable=invalid-name
1919
"""The build utils in python."""
20-
from typing import Union, Optional, Dict, Tuple
20+
from typing import Dict, Optional, Tuple, Union
2121

2222
import tvm
2323
from tvm import ir
24-
from tvm.runtime import ndarray
25-
from tvm.tir import PrimFunc
2624
from tvm.ir.module import IRModule
25+
from tvm.runtime import ndarray
2726
from tvm.target import Target
27+
from tvm.tir import PrimFunc
2828

2929

3030
def split_host_device_mods(mod: IRModule) -> Tuple[IRModule, Dict[Target, IRModule]]:
@@ -100,10 +100,12 @@ def main(self_handle: T.handle, args: T.handle, num_args: T.int32, result: T.han
100100
- Device kernel functions: use `calling_conv: 2` (kDeviceKernelLaunch)
101101
"""
102102

103-
host_mod = tvm.tir.transform.Filter(lambda f: "cpu" in str(f.attrs.get("target", "cpu")))(mod)
104-
device_mod = tvm.tir.transform.Filter(lambda f: "cpu" not in str(f.attrs.get("target", "cpu")))(
105-
mod
106-
)
103+
def is_host_func(f):
104+
target = f.attrs.get("target", tvm.target.Target("llvm"))
105+
return str(target.kind) in ["llvm", "c"]
106+
107+
host_mod = tvm.tir.transform.Filter(is_host_func)(mod)
108+
device_mod = tvm.tir.transform.Filter(lambda f: not is_host_func(f))(mod)
107109
# TODO(syfeng): Here we use str as key since target hash is not correct
108110
target_str2target = {}
109111
device_func_dict = {}

tests/python/codegen/test_target_codegen_cuda.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -820,10 +820,12 @@ def main(
820820
for tx in T.thread_binding(length, "threadIdx.x"):
821821
C[bx, tx] = Module.add(A[bx, tx], B[bx, tx]) # Call from device
822822

823-
# If we set host to llvm, it will raise an error of
824-
# "the tir.ret should be transformed to return zero before the llvm code generation."
825-
# Need to revisit this.
826-
target = tvm.target.Target("cuda", host="c")
823+
# 1. If we set host to llvm, it will raise an error of
824+
# "the tir.ret should be transformed to return zero before the llvm code generation."
825+
# Need to revisit this.
826+
# 2. We set a dummy mcpu value for testing purposes,
827+
# in order to avoid checking whether a function is host or device based on the "cpu" substring.
828+
target = tvm.target.Target({"kind": "cuda", "mcpu": "dummy_mcpu"}, host="c")
827829
lib = tvm.compile(Module, target=target)
828830
cuda_code = lib.mod.imported_modules[0].get_source()
829831
assert 'extern "C" __device__ int add(int a, int b) {\n return (a + b);\n}' in cuda_code

0 commit comments

Comments
 (0)