From e619731efd2d012157e2a31bb827702dcdad5e27 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 12 Apr 2024 18:39:25 -0700 Subject: [PATCH 1/2] [TIR] Make T.reinterpret nop when dtype is the same --- python/tvm/tir/op.py | 4 ++-- src/tir/op/op.cc | 2 ++ .../tvmscript/test_tvmscript_parser_tir.py | 22 +++++++++++++++++++ 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 8816880e7b52..6b72e63f2990 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1789,7 +1789,7 @@ def infinity(dtype: str, span: Optional[Span] = None) -> Any: return _ffi_api.infinity(dtype, span) # type: ignore -def reinterpret(dtype, value) -> Any: +def reinterpret(dtype, value, span: Optional[Span] = None) -> Any: """infinity value of dtype Parameters @@ -1808,7 +1808,7 @@ def reinterpret(dtype, value) -> Any: value : tvm.Expr The reinterpret cast value of dtype. """ - return call_intrin(dtype, "tir.reinterpret", value) + return _ffi_api.reinterpret(dtype, value, span) # type: ignore def exp(x): diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 7f47e660625b..29b2f0e27153 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -1083,6 +1083,8 @@ TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc); TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast); +TVM_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret); + // operator overloading, smarter than make #define REGISTER_MAKE_BINARY_OP(Node, Func) \ TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \ diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index 465ffa5cb602..530746a6fcb6 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -449,5 +449,27 @@ def func(a_handle: T.handle, b_handle: T.handle): tvm.ir.assert_structural_equal(func.struct_info, expected) +def test_reinterpret_nop(): + """Test builtin reinterpret op""" + + @T.prim_func + def func(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")) -> None: + T.func_attr({"global_symbol": "main"}) + for i in T.serial(0, 32): + with T.block(): + vi = T.axis.remap("S", [i]) + B[vi] = T.reinterpret("float32", A[vi]) + + @T.prim_func + def expected(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")) -> None: + T.func_attr({"global_symbol": "main"}) + for i in T.serial(0, 32): + with T.block(): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + + tvm.ir.assert_structural_equal(func, expected) + + if __name__ == "__main__": tvm.testing.main() From 4d57b00193f762a7ea2156cc1e9113b5f6445f65 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Sat, 13 Apr 2024 12:52:22 -0700 Subject: [PATCH 2/2] fix scalable vec handling --- src/tir/op/op.cc | 6 ++++-- tests/python/codegen/test_target_codegen_cuda.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 29b2f0e27153..b61363978615 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -409,8 +409,10 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) { // reinterpret PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) { if (value.dtype() == t) return value; - ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes()) - << "Bitcast requires size match " << t << " vs " << value.dtype(); + if (!t.is_scalable_vector() && !value.dtype().is_scalable_vector()) { + ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes()) + << "Bitcast requires size match " << t << " vs " << value.dtype(); + } return tir::Call(t, tir::builtin::reinterpret(), {value}, span); } diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py index 23ba0fc3ce3a..112c521d06d4 100644 --- a/tests/python/codegen/test_target_codegen_cuda.py +++ b/tests/python/codegen/test_target_codegen_cuda.py @@ -1120,7 +1120,7 @@ def test_invalid_reinterpret(): @T.prim_func def func(A: T.Buffer((4,), "uint32"), B: T.Buffer((4,), "uint8")) -> None: for tx in T.thread_binding(4, "threadIdx.x"): - B[tx] = T.reinterpret("uint8", A[tx]) + B[tx] = T.call_intrin("uint8", "tir.reinterpret", A[tx]) with pytest.raises(tvm.error.TVMError): tvm.build(func, target="cuda")