Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
bugfix
  • Loading branch information
yzh119 committed May 16, 2023
commit 91b0fb191816203c7385ad0c2e85bffef47eaa64
80 changes: 22 additions & 58 deletions src/tir/transforms/dtype_conversion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,17 @@ PrimExpr ReinterpretAsUInt(PrimExpr value) {

/*!
 * \brief Build the unsigned integer data type that mirrors the layout of another data type.
 * \param dtype The data type whose bit width and lane count are reused.
 * \return DataType::UInt constructed from dtype.bits() and dtype.lanes().
 */
DataType GetStorageUIntDType(DataType dtype) {
  const auto storage_bits = dtype.bits();
  const auto storage_lanes = dtype.lanes();
  return DataType::UInt(storage_bits, storage_lanes);
}

/*!
* \brief Conversion from one floating point data type to another floating point data type.
* \param src_value The floating point value to be converted.
* \param tgt_dtype The target floating point data type.
* \param round_mode The rounding mode to use, defaults to kHalfToEven.
*/
PrimExpr FpToFp(PrimExpr src_value, DataType tgt_dtype, RoundingMode round_mode) {
PrimExpr DTypeConversion(PrimExpr src_value, DataType tgt_dtype, RoundingMode round_mode) {
DataType src_dtype = src_value.dtype();
// Step 1: check dtype
// The lanes of src dtype and target dtype must match.
CHECK_EQ(src_dtype.lanes(), tgt_dtype.lanes())
<< "The lanes for data type for source value must matches the target datatype.";
auto is_floating_point = [](DataType dtype) {
return dtype.is_float() || dtype.is_float8() || dtype.is_bfloat16();
};
// Both source dtype and target dtype should be floating point.
CHECK(is_floating_point(src_dtype) && is_floating_point(tgt_dtype));
FloatConfig src_fp = FloatConfig::FromDataType(src_value.dtype()),
tgt_fp = FloatConfig::FromDataType(tgt_dtype);
int exponent_delta = tgt_fp.exponent - src_fp.exponent;
Expand All @@ -60,14 +64,13 @@ PrimExpr FpToFp(PrimExpr src_value, DataType tgt_dtype, RoundingMode round_mode)
PrimExpr ret = src_uint_value;
if (mantissa_delta >= 0) {
ret = cast(tgt_uint, ret) << mantissa_delta;
if (bias_delta != 0) {
ret = ret + (make_const(tgt_uint, bias_delta) << tgt_fp.mantissa);
}
} else { // mantissa_delta < 0
ret = cast(tgt_uint, ret >> (-mantissa_delta));
if (bias_delta != 0) {
ret = ret + (make_const(tgt_uint, bias_delta) << tgt_fp.mantissa);
}
}
if (bias_delta > 0) {
ret = ret + (make_const(tgt_uint, bias_delta) << tgt_fp.mantissa);
} else if (bias_delta < 0) {
ret = ret - (make_const(tgt_uint, -bias_delta) << tgt_fp.mantissa);
}
return reinterpret(tgt_dtype, ret);
} else {
Expand All @@ -77,55 +80,16 @@ PrimExpr FpToFp(PrimExpr src_value, DataType tgt_dtype, RoundingMode round_mode)
: (src_uint_value << (-mantissa_delta)))) &
make_const(tgt_uint, (int64_t(1) << (tgt_fp.mantissa)) - 1);
PrimExpr ret_exponent =
max(cast(tgt_uint, (((src_uint_value << 1) >> (src_fp.mantissa + 1)) + bias_delta)),
make_const(tgt_uint, 0))
<< tgt_fp.mantissa;
(bias_delta > 0)
? (cast(tgt_uint, ((src_uint_value << 1) >> (src_fp.mantissa + 1)) + bias_delta)
<< tgt_fp.mantissa)
: (cast(tgt_uint, max(((src_uint_value << 1) >> (src_fp.mantissa + 1)) - (-bias_delta),
make_const(tgt_uint, 0)))
<< tgt_fp.mantissa);
PrimExpr ret_sign = make_const(tgt_uint, int64_t(1) << (tgt_fp.mantissa + tgt_fp.exponent));
return reinterpret(tgt_dtype, ret_mantissa | ret_exponent | ret_sign);
}
}

/*!
 * \brief Conversion from integer to floating point data type.
 * \param src_value The integer value to be converted.
 * \param tgt_dtype The target floating point data type.
 * \param round_mode The rounding mode to use, defaults to kHalfToEven.
 * \note This is currently a stub: the body reads none of its parameters and only
 *       reports "Not implemented." through LOG(FATAL).
 */
PrimExpr IntToFp(PrimExpr src_value, DataType tgt_dtype, RoundingMode round_mode) {
// TODO(tvm-team): implement integer to floating point conversion with clz primitive.
// No return statement follows; this presumably relies on LOG(FATAL) not returning
// control to the caller -- TODO confirm against the logging implementation.
LOG(FATAL) << "Not implemented.";
}

/*!
 * \brief Conversion from floating point data type to integer.
 * \param src_value The floating point value to be converted.
 * \param tgt_dtype The target integer data type.
 * \param round_mode The rounding mode to use, defaults to kHalfToEven.
 * \note Not implemented yet: the function reports "Not implemented." through LOG(FATAL),
 *       in the same way as IntToFp, and reads none of its parameters.
 */
PrimExpr FpToInt(PrimExpr src_value, DataType tgt_dtype, RoundingMode round_mode) {
  // A function returning PrimExpr must not fall off its end (undefined behavior),
  // so fail loudly until a real implementation exists.
  // TODO(tvm-team): implement floating point to integer conversion.
  LOG(FATAL) << "Not implemented.";
}

/*!
 * \brief Dispatch a data type conversion to the routine matching the source/target categories.
 *
 * The category of each data type is decided by two local predicates:
 *  - floating point: is_float() || is_float8() || is_bfloat16()
 *  - integer: is_int() || is_uint()
 *
 * \param src_value The value to be converted; its dtype() is the source data type.
 * \param tgt_dtype The target data type; must have the same number of lanes as the source.
 * \param round_mode The rounding mode, forwarded unchanged to the selected routine.
 * \return The result of FpToFp, IntToFp or FpToInt, depending on the dtype categories.
 * \note Any other combination (e.g. integer to integer) ends in LOG(FATAL).
 */
PrimExpr DTypeConversion(PrimExpr src_value, DataType tgt_dtype, RoundingMode round_mode) {
  DataType src_dtype = src_value.dtype();
  CHECK_EQ(src_dtype.lanes(), tgt_dtype.lanes())
      << "The lanes for data type for source value must matches the target datatype.";
  auto is_floating_point = [](DataType dtype) {
    return dtype.is_float() || dtype.is_float8() || dtype.is_bfloat16();
  };
  auto is_integer = [](DataType dtype) { return dtype.is_int() || dtype.is_uint(); };

  const bool src_is_fp = is_floating_point(src_dtype);
  const bool tgt_is_fp = is_floating_point(tgt_dtype);

  if (src_is_fp && tgt_is_fp) {
    return FpToFp(src_value, tgt_dtype, round_mode);
  }
  if (is_integer(src_dtype) && tgt_is_fp) {
    return IntToFp(src_value, tgt_dtype, round_mode);
  }
  // The integer check must be made on the *target* dtype: a source dtype can never be
  // both floating point and integer, which would make this branch unreachable.
  if (src_is_fp && is_integer(tgt_dtype)) {
    return FpToInt(src_value, tgt_dtype, round_mode);
  }
  LOG(FATAL) << "Not Implemented yet";
}

} // namespace tir
} // namespace tvm
7 changes: 4 additions & 3 deletions src/tir/transforms/dtype_conversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,10 @@ PrimExpr ReinterpretAsUInt(PrimExpr value);
DataType GetStorageUIntDType(DataType dtype);

/*!
* \brief Conversion routine from value stored in one data type to target data type.
* \param src_value The value to be converted.
* \param tgt_dtype The target data type.
* \brief Conversion routine from value stored in one floating point data type to another floating
* point data type.
* \param src_value The floating point value to be converted.
* \param tgt_dtype The target floating point data type.
* \param round_mode The rounding mode to use, defaults to kHalfToEven.
* \note Used when there is no native data type conversion implementation.
*/
Expand Down
47 changes: 30 additions & 17 deletions src/tir/transforms/unsupported_dtype_legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@
* \file unsupported_dtype_legalize.cc
* \brief legalize bf16/fp8 type by adding cast_to_fp32
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/runtime/registry.h>


#include <cmath>
#include <tuple>
Expand All @@ -46,8 +45,9 @@ class ComputeLegalizePlanner : public StmtExprVisitor {
public:
ComputeLegalizePlanner(
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>* buffer_remap,
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap)
: buffer_remap_(buffer_remap), var_remap_(var_remap) {}
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap,
DataType promote_dtype)
: buffer_remap_(buffer_remap), var_remap_(var_remap), promote_dtype_(promote_dtype) {}

// run planning to populate buffer remap and var remap.
void Plan(PrimFunc func) {
Expand Down Expand Up @@ -76,9 +76,9 @@ class ComputeLegalizePlanner : public StmtExprVisitor {
virtual bool MatchDType(DataType dtype) const = 0;

void VisitStmt_(const AllocateNode* op) final {
// remap all intermediate constant buffer to fp32
// remap every matched intermediate buffer with a constant allocation size to the promoted data type (e.g. fp16/fp32)
if (MatchDType(op->dtype) && op->ConstantAllocationSize() != 0) {
DataType dtype = DataType::Float(32, op->dtype.lanes());
DataType dtype = promote_dtype_.with_lanes(op->dtype.lanes());
Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype)));
(*var_remap_)[op->buffer_var] = buffer_var;
}
Expand Down Expand Up @@ -113,7 +113,7 @@ class ComputeLegalizePlanner : public StmtExprVisitor {
auto var_it = var_remap_->find(buf->data);
if (var_it == var_remap_->end()) return;

Buffer new_buffer(var_it->second, DataType::Float(32, buf->dtype.lanes()), buf->shape,
Buffer new_buffer(var_it->second, promote_dtype_.with_lanes(buf->dtype.lanes()), buf->shape,
buf->strides, buf->elem_offset, buf->name, buf->data_alignment,
buf->offset_factor, buf->buffer_type, buf->axis_separators, buf->span);
(*buffer_remap_)[buf] = new_buffer;
Expand All @@ -122,23 +122,26 @@ class ComputeLegalizePlanner : public StmtExprVisitor {
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>* buffer_remap_;
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap_;
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> opaque_var_access_;
DataType promote_dtype_;
};

class BF16ComputeLegalizePlanner : public ComputeLegalizePlanner {
public:
explicit BF16ComputeLegalizePlanner(
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>* buffer_remap,
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap)
: ComputeLegalizePlanner(buffer_remap, var_remap) {}
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap,
DataType promote_dtype)
: ComputeLegalizePlanner(buffer_remap, var_remap, promote_dtype) {}
bool MatchDType(DataType dtype) const { return dtype.is_bfloat16(); }
};

class FP8ComputeLegalizePlanner : public ComputeLegalizePlanner {
public:
explicit FP8ComputeLegalizePlanner(
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>* buffer_remap,
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap)
: ComputeLegalizePlanner(buffer_remap, var_remap) {}
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>* var_remap,
DataType promote_dtype)
: ComputeLegalizePlanner(buffer_remap, var_remap, promote_dtype) {}
bool MatchDType(DataType dtype) const { return dtype.is_float8(); }
};

Expand All @@ -163,14 +166,15 @@ class ComputeLegalizer : public StmtExprMutator {
public:
ComputeLegalizer(DataType promote_dtype) : promote_dtype_(promote_dtype) {}

PrimFunc Legalize(PrimFunc func) {
BF16ComputeLegalizePlanner planner(&buffer_remap_, &var_remap_);
planner.Plan(func);
PrimFunc LegalizeWithPlanner(PrimFunc func, ComputeLegalizePlanner* planner) {
planner->Plan(func);
auto* n = func.CopyOnWrite();
n->body = this->VisitStmt(std::move(n->body));
return func;
}

virtual PrimFunc Legalize(PrimFunc func) = 0;

virtual bool MatchDType(DataType dtype) const = 0;

protected:
Expand All @@ -179,7 +183,7 @@ class ComputeLegalizer : public StmtExprMutator {

// all casts to a matched data type (fp8/bf16) become casts to the promoted data type
if (MatchDType(op->dtype)) {
return cast(DataType::Float(32, op->dtype.lanes()), op_val);
return cast(promote_dtype_.with_lanes(op->dtype.lanes()), op_val);
}

if (op_val.same_as(op->value)) {
Expand Down Expand Up @@ -229,7 +233,7 @@ class ComputeLegalizer : public StmtExprMutator {
auto fmutate = [this](const PrimExpr& e) { return PromoteToTarget(this->VisitExpr(e)); };
Array<PrimExpr> args = op->args.Map(fmutate);
if (MatchDType(op->dtype)) {
return Call(DataType::Float(32, op->dtype.lanes()), op->op, args);
return Call(promote_dtype_.with_lanes(op->dtype.lanes()), op->op, args);
}
if (args.same_as(op->args)) {
return GetRef<PrimExpr>(op);
Expand All @@ -240,7 +244,7 @@ class ComputeLegalizer : public StmtExprMutator {

PrimExpr VisitExpr_(const FloatImmNode* op) final {
if (MatchDType(op->dtype)) {
return FloatImm(DataType::Float(32), op->value);
return FloatImm(promote_dtype_, op->value);
}
return GetRef<PrimExpr>(op);
}
Expand Down Expand Up @@ -432,6 +436,7 @@ class ComputeLegalizer : public StmtExprMutator {
return buf;
}

protected:
DataType promote_dtype_;
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
Expand All @@ -440,12 +445,20 @@ class ComputeLegalizer : public StmtExprMutator {
class BF16ComputeLegalizer : public ComputeLegalizer {
public:
BF16ComputeLegalizer() : ComputeLegalizer(DataType::Float(32)) {}
PrimFunc Legalize(PrimFunc func) {
BF16ComputeLegalizePlanner planner(&buffer_remap_, &var_remap_, promote_dtype_);
return LegalizeWithPlanner(func, &planner);
}
bool MatchDType(DataType dtype) const { return dtype.is_bfloat16(); }
};

class FP8ComputeLegalizer : public ComputeLegalizer {
public:
FP8ComputeLegalizer(DataType promote_dtype) : ComputeLegalizer(promote_dtype) {}
PrimFunc Legalize(PrimFunc func) {
FP8ComputeLegalizePlanner planner(&buffer_remap_, &var_remap_, promote_dtype_);
return LegalizeWithPlanner(func, &planner);
}
bool MatchDType(DataType dtype) const { return dtype.is_float8(); }
};

Expand Down
10 changes: 5 additions & 5 deletions tests/python/unittest/test_tir_transform_fp8_legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ def main(
Aptr: T.handle(dtype), Bptr: T.handle(dtype), Dptr: T.handle(dtype)
):
T.func_attr({"global_symbol": "main"})
A = T.decl_buffer((100,), "dtype", data=Aptr)
B = T.decl_buffer((100,), "dtype", data=Bptr)
D = T.decl_buffer((100,), "dtype", data=Dptr)
C = T.decl_buffer((100,), "dtype")
A = T.decl_buffer((100,), dtype, data=Aptr)
B = T.decl_buffer((100,), dtype, data=Bptr)
D = T.decl_buffer((100,), dtype, data=Dptr)
C = T.decl_buffer((100,), dtype)
for i in T.grid(100):
C[i] = A[i] + B[i]
D[i] = T.exp(C[i])
Expand All @@ -54,7 +54,7 @@ def main(
A = T.decl_buffer((100,), dtype, data=Aptr)
B = T.decl_buffer((100,), dtype, data=Bptr)
D = T.decl_buffer((100,), dtype, data=Dptr)
C = T.decl_buffer((100,), dtype)
C = T.decl_buffer((100,), promote_dtype)
for i in T.grid(100):
C[i] = promote_f8(dtype, promote_dtype, A[i]) + promote_f8(dtype, promote_dtype, B[i])
D[i] = cast_to_f8(dtype, promote_dtype, T.exp(C[i]))
Expand Down