From be03f13c55e30392451f4a2698c64bbec5128f8f Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 9 Apr 2024 14:43:40 -0700 Subject: [PATCH 1/2] [TVMScript] Add parser and printer support for e4m3/e5m2 fp8 --- include/tvm/script/ir_builder/tir/ir.h | 12 ++++++ python/tvm/script/ir_builder/tir/ir.py | 39 +++++++++++++------ src/script/ir_builder/tir/ir.cc | 5 +++ .../codegen/test_target_codegen_cuda_fp8.py | 14 +++++++ .../tvmscript/test_tvmscript_printer_tir.py | 31 +++++++++++++++ 5 files changed, 89 insertions(+), 12 deletions(-) diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 735d5ba6c0a1..c4ba44f67359 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -489,6 +489,18 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int); + +#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, FDType) \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType, FDType(1)); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x4, FDType(4)); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x8, FDType(8)); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x16, FDType(16)); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, FDType(32)); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, FDType(64)); + +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E4M3Float8, DataType::NVFloat8E4M3); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E5M2Float8, DataType::NVFloat8E5M2); + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool()); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void()); diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index a5c09cf1a311..127d2a4356b1 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1408,30 +1408,39 @@ def func( uint32x64 = func_gen(("UInt32x64")) uint64x64 = func_gen(("UInt64x64")) -float8 = func_gen(("Float8")) float16 = func_gen(("Float16")) float32 = func_gen(("Float32")) float64 = func_gen(("Float64")) -float8x4 = func_gen(("Float8x4")) float16x4 = func_gen(("Float16x4")) float32x4 = func_gen(("Float32x4")) float64x4 = func_gen(("Float64x4")) -float8x8 = func_gen(("Float8x8")) float16x8 = func_gen(("Float16x8")) float32x8 = func_gen(("Float32x8")) float64x8 = func_gen(("Float64x8")) -float8x16 = func_gen(("Float8x16")) float16x16 = func_gen(("Float16x16")) float32x16 = func_gen(("Float32x16")) float64x16 = func_gen(("Float64x16")) -float8x32 = func_gen(("Float8x32")) float16x32 = func_gen(("Float16x32")) float32x32 = func_gen(("Float32x32")) float64x32 = func_gen(("Float64x32")) -float8x64 = func_gen(("Float8x64")) float16x64 = func_gen(("Float16x64")) float32x64 = func_gen(("Float32x64")) float64x64 = func_gen(("Float64x64")) + +e4m3_float8 = func_gen(("E4M3Float8")) +e4m3_float8x4 = func_gen(("E4M3Float8x4")) +e4m3_float8x8 = func_gen(("E4M3Float8x8")) +e4m3_float8x16 = func_gen(("E4M3Float8x16")) +e4m3_float8x32 = func_gen(("E4M3Float8x32")) +e4m3_float8x64 = func_gen(("E4M3Float8x64")) + +e5m2_float8 = func_gen(("E5M2Float8")) +e5m2_float8x4 = func_gen(("E5M2Float8x4")) +e5m2_float8x8 = func_gen(("E5M2Float8x8")) +e5m2_float8x16 = func_gen(("E5M2Float8x16")) +e5m2_float8x32 = func_gen(("E5M2Float8x32")) +e5m2_float8x64 = func_gen(("E5M2Float8x64")) + # pylint: enable=invalid-name @@ -1954,27 +1963,33 @@ def wrapped(*args, **kwargs): "uint16x64", "uint32x64", "uint64x64", - "float8", + "e4m3_float8", + "e5m2_float8", "float16", "float32", "float64", - "float8x4", + "e4m3_float8x4", + "e5m2_float8x4", "float16x4", "float32x4", "float64x4", - "float8x8", + "e4m3_float8x8", + "e5m2_float8x8", "float16x8", "float32x8", "float64x8", - "float8x16", + "e4m3_float8x16", + "e5m2_float8x16", "float16x16", "float32x16", "float64x16", - "float8x32", + "e4m3_float8x32", + "e5m2_float8x32", "float16x32", "float32x32", "float64x32", - "float8x64", + "e4m3_float8x64", + "e5m2_float8x64", "float16x64", "float32x64", "float64x64", diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 1ae1051d254d..ccb5a8b57b5b 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -751,6 +751,11 @@ TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float); TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt); TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.E4M3Float8").set_body_typed(E4M3Float8); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.E5M2Float8").set_body_typed(E5M2Float8); +TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E4M3Float8", E4M3Float8); +TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E5M2Float8", E5M2Float8); + TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void); diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py index dade970418f9..b8924e17b20e 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp8.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py @@ -799,5 +799,19 @@ def test_main(self, weight_shape, model_dtype, target_str, compiled_functions): tvm.testing.assert_allclose(weight_np, dequant_weight_np, atol=10, rtol=5e-2) +@pytest.mark.parametrize("dtype", ["e5m2_float8", "e4m3_float8"]) +def test_const(dtype): + @T.prim_func + def func(A: T.Buffer((4,), dtype)) -> None: + A_local = T.alloc_buffer((4,), dtype=dtype, scope="local") + for tx in T.thread_binding(0, 4, "threadIdx.x"): + for i in T.vectorized(4): + A_local[i] = T.float32(1.0).astype(dtype) + A[tx] = A_local[tx] + + mod = tvm.IRModule({"main": func}) + tvm.build(mod, target="cuda") + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index 97a6b889c011..edc6da31636b 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -917,5 +917,36 @@ def func(): _assert_print(func, expected_output) +@pytest.mark.parametrize("dtype", ["e4m3_float8", "e5m2_float8"]) +def test_float8(dtype): + from tvm.script import tir as T + + def get_func(dtype): + if dtype == "e4m3_float8": + + @T.prim_func + def func(): + T.evaluate(T.e4m3_float8(0.0)) + + return func + elif dtype == "e5m2_float8": + + @T.prim_func + def func(): + T.evaluate(T.e5m2_float8(0.0)) + + return func + + expected_output = f""" +# from tvm.script import tir as T + +@T.prim_func +def func(): + T.evaluate(T.{dtype}(0)) + """ + func = get_func(dtype) + _assert_print(func, expected_output) + + if __name__ == "__main__": tvm.testing.main() From e74706b9616e78eb59ac1633358ec0cae88b1afd Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 9 Apr 2024 15:46:18 -0700 Subject: [PATCH 2/2] remove unrelated --- .../python/codegen/test_target_codegen_cuda_fp8.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py index b8924e17b20e..dade970418f9 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp8.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py @@ -799,19 +799,5 @@ def test_main(self, weight_shape, model_dtype, target_str, compiled_functions): tvm.testing.assert_allclose(weight_np, dequant_weight_np, atol=10, rtol=5e-2) -@pytest.mark.parametrize("dtype", ["e5m2_float8", "e4m3_float8"]) -def test_const(dtype): - @T.prim_func - def func(A: T.Buffer((4,), dtype)) -> None: - A_local = T.alloc_buffer((4,), dtype=dtype, scope="local") - for tx in T.thread_binding(0, 4, "threadIdx.x"): - for i in T.vectorized(4): - A_local[i] = T.float32(1.0).astype(dtype) - A[tx] = A_local[tx] - - mod = tvm.IRModule({"main": func}) - tvm.build(mod, target="cuda") - - if __name__ == "__main__": tvm.testing.main()