Skip to content

Commit 880126a

Browse files
Merge branch 'apache:main' into compute_at_floordivmod
2 parents fff360b + 7131411 commit 880126a

2 files changed

Lines changed: 21 additions & 5 deletions

File tree

src/tir/transforms/split_host_device.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,12 @@ PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& g
108108

109109
HostDeviceSplitter splitter(device_mod, name_prefix);
110110

111-
auto body = splitter(func->body);
112-
113-
if (!body.same_as(func->body)) {
111+
if (auto body = splitter(func->body); !body.same_as(func->body)) {
114112
func.CopyOnWrite()->body = body;
115-
auto target_host = target->GetHost().value_or(Target("llvm"));
116-
func = WithAttr(std::move(func), tvm::attr::kTarget, target_host);
113+
}
114+
115+
if (auto target_host = target->GetHost()) {
116+
func = WithAttr(std::move(func), tvm::attr::kTarget, target_host.value());
117117
}
118118

119119
return func;

tests/python/unittest/test_tir_transform_split_host_device.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,5 +168,21 @@ def main_kernel(n: T.int32):
168168
return mod
169169

170170

171+
class TestSplitHostDevice(BaseCompare):
172+
"""Like TestSplitHostDevice, but no device regions to extract
173+
174+
Even if there are no device regions, the host-side function should
175+
still have its "target" attribute updated.
176+
"""
177+
178+
def before():
179+
T.func_attr({"target": T.target("ext_dev", host="llvm")})
180+
T.evaluate(0)
181+
182+
def expected():
183+
T.func_attr({"target": T.target("llvm")})
184+
T.evaluate(0)
185+
186+
171187
if __name__ == "__main__":
172188
tvm.testing.main()

0 commit comments

Comments
 (0)