@@ -301,6 +301,8 @@ PrimExpr max_value(const DataType& dtype, Span span) {
301301 } else if (dtype.bits () == 16 ) {
302302 return FloatImm (dtype, 65504.0 , span);
303303 }
304+ } else if (dtype.is_tfloat32 ()) {
305+ return FloatImm (dtype, std::numeric_limits<float >::max (), span);
304306 } else if (dtype.is_bfloat16 ()) {
305307 return FloatImm (dtype, std::numeric_limits<float >::max (), span);
306308 } else if (dtype.is_float8 ()) {
@@ -336,14 +338,7 @@ PrimExpr max_value(const DataType& dtype, Span span) {
336338PrimExpr min_value (const DataType& dtype, Span span) {
337339 using namespace tir ;
338340 ICHECK_EQ (dtype.lanes (), 1 );
339- if (datatype::Registry::Global ()->GetTypeRegistered (dtype.code ())) {
340- // TODO(tkonolige): need to convert all registered min functions to use the span.
341- auto f = datatype::GetMinFunc (dtype.code ());
342- ICHECK (f) << " No minimum function registered for custom dtype " << (unsigned int )dtype.code ();
343- // TODO(@hypercubestart) Document this change (and others associated with the overflowing
344- // floatimm min bug)
345- return (*f)(dtype.bits ()).cast <PrimExpr>();
346- } else if (dtype.is_int ()) {
341+ if (dtype.is_int ()) {
347342 if (dtype.bits () == 64 ) {
348343 return IntImm (dtype, std::numeric_limits<int64_t >::lowest (), span);
349344 } else if (dtype.bits () < 64 ) {
@@ -361,6 +356,9 @@ PrimExpr min_value(const DataType& dtype, Span span) {
361356 } else if (dtype.bits () == 16 ) {
362357 return FloatImm (dtype, -65504.0 , span);
363358 }
359+ }
360+ else if (dtype.is_tfloat32 ()) {
361+ return FloatImm (dtype, std::numeric_limits<float >::lowest (), span);
364362 } else if (dtype.is_bfloat16 ()) {
365363 return FloatImm (dtype, std::numeric_limits<float >::lowest (), span);
366364 } else if (dtype.is_float8 ()) {
@@ -888,7 +886,7 @@ PrimExpr abs(PrimExpr x, Span span) {
888886 return IntImm (x.dtype (), std::abs (px->value ), px->span );
889887 }
890888 return tir::Select (x >= make_zero (x.dtype ()), x, -x, span);
891- } else if (x.dtype ().is_float () || x.dtype ().is_bfloat ()) {
889+ } else if (x.dtype ().is_float () || x.dtype ().is_bfloat () || x. dtype (). is_tfloat () ) {
892890 using tir::FloatImmNode;
893891 const FloatImmNode* fx = x.as <FloatImmNode>();
894892 if (fx) {
0 commit comments