diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 584b3cbf58f4..c52027acba13 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -435,10 +435,19 @@ class IRConvertSSA final : public StmtExprMutator { private: struct ScopedRedefine { ScopedRedefine(IRConvertSSA* parent, Var old_var) : parent(parent), old_var(old_var) { + bool is_size_var = old_var->IsInstance(); if (old_var->type_annotation.defined()) { - new_var = Var(old_var->name_hint, old_var->type_annotation); + if (is_size_var) { + new_var = SizeVar(old_var->name_hint, old_var->type_annotation); + } else { + new_var = Var(old_var->name_hint, old_var->type_annotation); + } } else { - new_var = Var(old_var->name_hint, old_var->dtype); + if (is_size_var) { + new_var = SizeVar(old_var->name_hint, old_var->dtype); + } else { + new_var = Var(old_var->name_hint, old_var->dtype); + } } parent->scope_[old_var.get()].push_back(new_var); } diff --git a/tests/python/tir-transform/test_tir_transform_split_host_device.py b/tests/python/tir-transform/test_tir_transform_split_host_device.py index 6adfbeb81d54..2d0d8a68d83e 100644 --- a/tests/python/tir-transform/test_tir_transform_split_host_device.py +++ b/tests/python/tir-transform/test_tir_transform_split_host_device.py @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te import tvm.testing -from tvm.script import tir as T, ir as I +from tvm import te +from tvm.script import ir as I +from tvm.script import tir as T @tvm.testing.requires_cuda @@ -345,5 +346,25 @@ def default_function_kernel( tvm.ir.assert_structural_equal(expected, after) +def test_size_var(): + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle): + T.func_attr({"target": T.target("cuda")}) + m = T.int64(is_size_var=True) + A = T.match_buffer(var_A, (m,)) + B = T.match_buffer(var_B, (m,)) + T.attr(T.target("cuda"), "target", 0) + blockIdx_x = T.launch_thread("blockIdx.x", m) + B_1 = T.Buffer((m,), data=B.data) + A_1 = T.Buffer((m,), data=A.data) + B_1[blockIdx_x] = A_1[blockIdx_x] + + after = tvm.tir.transform.SplitHostDevice()(Module) + assert len(after["main_kernel"].params) == 3 + assert isinstance(after["main_kernel"].params[2], tvm.tir.SizeVar) + + if __name__ == "__main__": tvm.testing.main()