Skip to content

Commit 1c5cdb8

Browse files
authored
Add tfloat32 datatype (apache#31)
* Add the tfloat32 datatype. * Fix: change the tfloat32 type code to 130. * Minor fixes.
1 parent 882a774 commit 1c5cdb8

7 files changed

Lines changed: 61 additions & 10 deletions

File tree

include/tvm/runtime/data_type.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ class DataType {
7272
kFloat6_e2m3fn = kDLFloat6_e2m3fn,
7373
kFloat6_e3m2fn = kDLFloat6_e3m2fn,
7474
kFloat4_e2m1fn = kDLFloat4_e2m1fn,
75-
kCustomBegin = 129
75+
kCustomBegin = 129,
76+
kTensorFloat32 = 130
7677
};
7778
/*! \brief default constructor */
7879
DataType() { data_ = DataType::Void(); }
@@ -109,6 +110,9 @@ class DataType {
109110
if (code == kFloat4_e2m1fn) {
110111
ICHECK_EQ(bits, 4);
111112
}
113+
if (code == kTensorFloat32) {
114+
ICHECK_EQ(bits, 32);
115+
}
112116
}
113117
/*! \return The type code. */
114118
int code() const { return static_cast<int>(data_.code); }
@@ -146,6 +150,8 @@ class DataType {
146150
bool is_float() const { return code() == DataType::kFloat; }
147151
/*! \return whether type is a bfloat type. */
148152
bool is_bfloat() const { return code() == DataType::kBFloat; }
153+
/*! \return whether type is a tfloat type. */
154+
bool is_tfloat() const { return code() == DataType::kTensorFloat32; }
149155
/*! \return whether type is any 8-bit custom Float8 variant. */
150156
bool is_float8() const {
151157
return bits() == 8 &&
@@ -185,6 +191,8 @@ class DataType {
185191
bool is_float6_e3m2fn() const { return bits() == 6 && code() == DataType::kFloat6_e3m2fn; }
186192
/*! \return whether type is Float4E2M1FN. */
187193
bool is_float4_e2m1fn() const { return bits() == 4 && code() == DataType::kFloat4_e2m1fn; }
194+
/*! \return whether type is a tfloat32 type. */
195+
bool is_tfloat32() const { return bits() == 32 && code() == DataType::kTensorFloat32; }
188196
/*! \return whether type is a float16 type. */
189197
bool is_float16() const { return is_float() && bits() == 16; }
190198
/*! \return whether type is a bfloat16 type. */
@@ -377,6 +385,14 @@ class DataType {
377385
* \return The constructed data type.
378386
*/
379387
static DataType Float4E2M1FN(int lanes = 1) { return DataType(kFloat4_e2m1fn, 4, lanes); }
388+
389+
/*!
390+
* \brief Construct a tensorfloat32 datatype.
391+
* \param lanes The number of lanes
392+
* \return The constructed data type.
393+
*/
394+
static DataType TensorFloat32(int lanes = 1) { return DataType(kTensorFloat32, 32, lanes); }
395+
380396
/*!
381397
* \brief Construct a bool type.
382398
* \param lanes The number of lanes.

include/tvm/script/ir_builder/tir/ir.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,8 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E3M2FN, DataType::Float
529529

530530
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1FN, DataType::Float4E2M1FN);
531531

532+
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(TensorFloat32, DataType::TensorFloat32);
533+
532534
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
533535
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
534536

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1584,6 +1584,13 @@ class bfloat16x8: ...
15841584
class bfloat16x16: ...
15851585
class bfloat16x32: ...
15861586
class bfloat16x64: ...
1587+
class tfloat32: ...
1588+
class tfloat32x2: ...
1589+
class tfloat32x4: ...
1590+
class tfloat32x8: ...
1591+
class tfloat32x16: ...
1592+
class tfloat32x32: ...
1593+
class tfloat32x64: ...
15871594
else:
15881595
# pylint: disable=invalid-name
15891596
int8 = func_gen(("Int8"))
@@ -1756,6 +1763,14 @@ class bfloat16x64: ...
17561763
bfloat16x16 = func_gen(("BFloat16x16"))
17571764
bfloat16x32 = func_gen(("BFloat16x32"))
17581765
bfloat16x64 = func_gen(("BFloat16x64"))
1766+
1767+
tfloat32 = func_gen(("TensorFloat32"))
1768+
tfloat32x2 = func_gen(("TensorFloat32x2"))
1769+
tfloat32x4 = func_gen(("TensorFloat32x4"))
1770+
tfloat32x8 = func_gen(("TensorFloat32x8"))
1771+
tfloat32x16 = func_gen(("TensorFloat32x16"))
1772+
tfloat32x32 = func_gen(("TensorFloat32x32"))
1773+
tfloat32x64 = func_gen(("TensorFloat32x64"))
17591774
# pylint: enable=invalid-name
17601775

17611776

@@ -2337,6 +2352,13 @@ def wrapped(*args, **kwargs):
23372352
"bfloat16x16",
23382353
"bfloat16x32",
23392354
"bfloat16x64",
2355+
"tfloat32",
2356+
"tfloat32x2",
2357+
"tfloat32x4",
2358+
"tfloat32x8",
2359+
"tfloat32x16",
2360+
"tfloat32x32",
2361+
"tfloat32x64",
23402362
"buffer",
23412363
"buffer_decl",
23422364
"prim_func",

src/script/ir_builder/tir/ir.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,13 @@ TVM_FFI_STATIC_INIT_BLOCK() {
894894
.TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN);
895895
}
896896

897+
TVM_FFI_STATIC_INIT_BLOCK() {
898+
namespace refl = tvm::ffi::reflection;
899+
refl::GlobalDef()
900+
.def("script.ir_builder.tir.TensorFloat32", TensorFloat32)
901+
.TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.TensorFloat32", TensorFloat32);
902+
}
903+
897904
TVM_FFI_STATIC_INIT_BLOCK() {
898905
namespace refl = tvm::ffi::reflection;
899906
refl::GlobalDef()

src/target/datatype/registry.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
4747
.def_packed("runtime._datatype_get_type_registered", [](ffi::PackedArgs args, ffi::Any* ret) {
4848
*ret = Registry::Global()->GetTypeRegistered(args[0].cast<int>());
4949
});
50+
// Register tfloat32 as a custom datatype with type code 130
51+
Registry::Global()->Register("tfloat32", 130);
5052
}
5153

5254
Registry* Registry::Global() {

src/target/source/intrin_rule_cuda.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ struct CUDAMath {
5252
default:
5353
return "";
5454
}
55+
} else if (t.is_tfloat32()) {
56+
if (name == "fabs") {
57+
return "abs";
58+
}
5559
} else if (t.is_bfloat16()) {
5660
if (name == "fabs") {
5761
return "__habs";

src/tir/op/op.cc

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ PrimExpr max_value(const DataType& dtype, Span span) {
301301
} else if (dtype.bits() == 16) {
302302
return FloatImm(dtype, 65504.0, span);
303303
}
304+
} else if (dtype.is_tfloat32()) {
305+
return FloatImm(dtype, std::numeric_limits<float>::max(), span);
304306
} else if (dtype.is_bfloat16()) {
305307
return FloatImm(dtype, std::numeric_limits<float>::max(), span);
306308
} else if (dtype.is_float8()) {
@@ -336,14 +338,7 @@ PrimExpr max_value(const DataType& dtype, Span span) {
336338
PrimExpr min_value(const DataType& dtype, Span span) {
337339
using namespace tir;
338340
ICHECK_EQ(dtype.lanes(), 1);
339-
if (datatype::Registry::Global()->GetTypeRegistered(dtype.code())) {
340-
// TODO(tkonolige): need to convert all registered min functions to use the span.
341-
auto f = datatype::GetMinFunc(dtype.code());
342-
ICHECK(f) << "No minimum function registered for custom dtype " << (unsigned int)dtype.code();
343-
// TODO(@hypercubestart) Document this change (and others associated with the overflowing
344-
// floatimm min bug)
345-
return (*f)(dtype.bits()).cast<PrimExpr>();
346-
} else if (dtype.is_int()) {
341+
if (dtype.is_int()) {
347342
if (dtype.bits() == 64) {
348343
return IntImm(dtype, std::numeric_limits<int64_t>::lowest(), span);
349344
} else if (dtype.bits() < 64) {
@@ -361,6 +356,9 @@ PrimExpr min_value(const DataType& dtype, Span span) {
361356
} else if (dtype.bits() == 16) {
362357
return FloatImm(dtype, -65504.0, span);
363358
}
359+
}
360+
else if (dtype.is_tfloat32()) {
361+
return FloatImm(dtype, std::numeric_limits<float>::lowest(), span);
364362
} else if (dtype.is_bfloat16()) {
365363
return FloatImm(dtype, std::numeric_limits<float>::lowest(), span);
366364
} else if (dtype.is_float8()) {
@@ -888,7 +886,7 @@ PrimExpr abs(PrimExpr x, Span span) {
888886
return IntImm(x.dtype(), std::abs(px->value), px->span);
889887
}
890888
return tir::Select(x >= make_zero(x.dtype()), x, -x, span);
891-
} else if (x.dtype().is_float() || x.dtype().is_bfloat()) {
889+
} else if (x.dtype().is_float() || x.dtype().is_bfloat() || x.dtype().is_tfloat()) {
892890
using tir::FloatImmNode;
893891
const FloatImmNode* fx = x.as<FloatImmNode>();
894892
if (fx) {

0 commit comments

Comments (0)