Merged
Changes from 17 commits
3 changes: 2 additions & 1 deletion docker/install/ubuntu_install_python_package.sh
@@ -44,4 +44,5 @@ pip3 install --upgrade \
junitparser==2.4.2 \
six \
tornado \
pytest-lazy-fixture
pytest-lazy-fixture \
ml_dtypes
34 changes: 34 additions & 0 deletions include/tvm/runtime/data_type.h
@@ -32,6 +32,7 @@

namespace tvm {
namespace runtime {

/*!
* \brief Runtime primitive data type.
*
@@ -54,6 +55,8 @@ class DataType {
kFloat = kDLFloat,
kHandle = TVMArgTypeCode::kTVMOpaqueHandle,
kBFloat = kDLBfloat,
kE4M3Float = 6U,
kE5M2Float = 7U,
kCustomBegin = 129
};
/*! \brief default constructor */
@@ -76,6 +79,9 @@
if (code == kBFloat) {
ICHECK_EQ(bits, 16);
}
if (code == kE4M3Float || code == kE5M2Float) {
ICHECK_EQ(bits, 8);
}
}
/*! \return The type code. */
int code() const { return static_cast<int>(data_.code); }
@@ -91,6 +97,12 @@
bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
/*! \return whether type is a float type. */
bool is_float() const { return code() == DataType::kFloat; }
/*! \return whether type is a float8 type. */
bool is_float8() const {
return (code() == DataType::kFloat || code() == DataType::kE4M3Float ||
code() == DataType::kE5M2Float) &&
bits() == 8;
}
/*! \return whether type is a float16 type. */
bool is_float16() const { return is_float() && bits() == 16; }
/*! \return whether type is a bfloat16 type. */
@@ -183,6 +195,18 @@
* \return The constructed data type.
*/
static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); }
/*!
* \brief Construct NV float8 e4m3 datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kE4M3Float, 8, lanes); }
/*!
* \brief Construct NV float8 e5m2 datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); }
/*!
* \brief Construct a bool type.
* \param lanes The number of lanes
@@ -308,6 +332,10 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
return "handle";
case kDLBfloat:
return "bfloat";
case DataType::kE4M3Float:
return "e4m3_float";
case DataType::kE5M2Float:
return "e5m2_float";
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
}
@@ -376,6 +404,12 @@ inline DLDataType String2DLDataType(std::string s) {
} else if (s.substr(0, 6) == "bfloat") {
t.code = DataType::kBFloat;
scan = s.c_str() + 6;
} else if (s.substr(0, 10) == "e4m3_float") {
t.code = DataType::kE4M3Float;
scan = s.c_str() + 10;
} else if (s.substr(0, 10) == "e5m2_float") {
t.code = DataType::kE5M2Float;
scan = s.c_str() + 10;
} else if (s.substr(0, 6) == "custom") {
t.code = ParseCustomDatatype(s, &scan);
} else {
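The two new codes cover NVIDIA's fp8 variants: e4m3 (1 sign, 4 exponent, 3 mantissa bits, bias 7) and e5m2 (1 sign, 5 exponent, 2 mantissa bits, bias 15). A minimal sketch of the layouts using the `ml_dtypes` package this PR adds as a dependency; the sample value is chosen here for illustration, not taken from the PR:

```python
import numpy as np
import ml_dtypes

# 1.5 encodes as exponent 0 (field == bias) plus one mantissa bit in both formats.
e4m3 = np.array([1.5], dtype=ml_dtypes.float8_e4m3fn).view(np.uint8)[0]
e5m2 = np.array([1.5], dtype=ml_dtypes.float8_e5m2).view(np.uint8)[0]
print(f"{int(e4m3):08b}")  # 00111100 -> 0 | 0111 | 100 (sign, exp=7, mantissa)
print(f"{int(e5m2):08b}")  # 00111110 -> 0 | 01111 | 10
```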
13 changes: 13 additions & 0 deletions include/tvm/tir/transform.h
@@ -356,12 +356,25 @@ TVM_DLL Pass NarrowDataType(int target_bits);
*/
TVM_DLL Pass BF16ComputeLegalize();

/*!
* \brief Legalize fp8 compute ops: cast fp8 operands to fp16/fp32
* before each op, then cast the result back to fp8.
* \return The pass.
*/
TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float32");

/*!
* \brief Legalize bf16 storage types to u16.
* \return The pass.
*/
TVM_DLL Pass BF16StorageLegalize();

/*!
* \brief Legalize fp8 storage types to u8.
* \return The pass.
*/
TVM_DLL Pass FP8StorageLegalize();

/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
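The contract of the two new passes is the usual promote-compute-demote pattern. In NumPy terms, with the `ml_dtypes` types added above (an illustration of the semantics, not the pass itself):

```python
import numpy as np
import ml_dtypes

a = np.array([0.5, 1.5], dtype=ml_dtypes.float8_e4m3fn)
b = np.array([0.25, 2.0], dtype=ml_dtypes.float8_e4m3fn)

# What the legalized TIR computes: widen fp8 operands to the promote dtype,
# do the arithmetic there, then narrow the result back to fp8.
out = (a.astype(np.float32) + b.astype(np.float32)).astype(ml_dtypes.float8_e4m3fn)
```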
1 change: 1 addition & 0 deletions python/gen_requirements.py
@@ -71,6 +71,7 @@
"psutil",
"scipy",
"tornado",
"ml_dtypes",
],
),
),
17 changes: 17 additions & 0 deletions python/tvm/_ffi/runtime_ctypes.py
@@ -19,6 +19,7 @@
import ctypes
import json
import numpy as np
import ml_dtypes
from .base import _LIB, check_call

tvm_shape_index_t = ctypes.c_int64
@@ -59,6 +60,8 @@ class DataTypeCode(object):
FLOAT = 2
HANDLE = 3
BFLOAT = 4
E4M3Float = 6
E5M2Float = 7


class DataType(ctypes.Structure):
@@ -71,6 +74,8 @@ class DataType(ctypes.Structure):
DataTypeCode.FLOAT: "float",
DataTypeCode.HANDLE: "handle",
DataTypeCode.BFLOAT: "bfloat",
DataTypeCode.E4M3Float: "e4m3_float",
DataTypeCode.E5M2Float: "e5m2_float",
}
NUMPY2STR = {
np.dtype(np.bool_): "bool",
@@ -86,6 +91,9 @@
np.dtype(np.float32): "float32",
np.dtype(np.float64): "float64",
np.dtype(np.float_): "float64",
np.dtype(ml_dtypes.bfloat16): "bfloat16",
np.dtype(ml_dtypes.float8_e4m3fn): "e4m3_float8",
np.dtype(ml_dtypes.float8_e5m2): "e5m2_float8",
}
STR2DTYPE = {
"bool": {"type_code": DataTypeCode.UINT, "bits": 1, "lanes": 1},
@@ -97,6 +105,9 @@
"uint16": {"type_code": DataTypeCode.UINT, "bits": 16, "lanes": 1},
"uint32": {"type_code": DataTypeCode.UINT, "bits": 32, "lanes": 1},
"uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
"float8": {"type_code": DataTypeCode.FLOAT, "bits": 8, "lanes": 1},
"e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, "lanes": 1},
"e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, "lanes": 1},
"float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
"float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
"float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
@@ -141,6 +152,12 @@ def __init__(self, type_str):
elif head.startswith("bfloat"):
self.type_code = DataTypeCode.BFLOAT
head = head[6:]
elif head.startswith("e4m3_float"):
self.type_code = DataTypeCode.E4M3Float
head = head[10:]
elif head.startswith("e5m2_float"):
self.type_code = DataTypeCode.E5M2Float
head = head[10:]
elif head.startswith("custom"):
# pylint: disable=import-outside-toplevel
import tvm.runtime._ffi_api
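A small sanity check of the new mappings (a sketch against the structures defined in this file):

```python
import numpy as np
import ml_dtypes
from tvm._ffi.runtime_ctypes import DataType, DataTypeCode

# The Python-side string form appends the bit width, so code 6 / 8 bits
# round-trips as "e4m3_float8" (the C++ prefix alone is "e4m3_float").
dt = DataType("e4m3_float8")
assert dt.type_code == DataTypeCode.E4M3Float and dt.bits == 8 and dt.lanes == 1

# NumPy dtypes built on ml_dtypes map to the same string names.
assert DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] == "e5m2_float8"
```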
17 changes: 17 additions & 0 deletions python/tvm/contrib/nvcc.py
@@ -404,3 +404,20 @@ def have_bf16(compute_version):
return True

return False


def have_fp8(compute_version):
"""Whether fp8 support is provided in the specified compute capability or not

Parameters
----------
compute_version : str
GPU capability
"""
major, minor = parse_compute_version(compute_version)
# fp8 is supported on Ada Lovelace (8.9) or later architectures.
if major == 8 and minor == 9:
return True
if major >= 9:
return True
return False
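For example, given the version checks above:

```python
from tvm.contrib import nvcc

assert nvcc.have_fp8("8.9")      # Ada Lovelace
assert nvcc.have_fp8("9.0")      # Hopper
assert not nvcc.have_fp8("8.6")  # Ampere predates fp8
```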
5 changes: 5 additions & 0 deletions python/tvm/runtime/ndarray.py
@@ -19,6 +19,7 @@
import ctypes
import warnings
import numpy as np
import ml_dtypes
import tvm._ffi

from tvm._ffi.base import _LIB, check_call, c_array, string_types, _FFI_MODE
@@ -220,6 +221,10 @@ def numpy(self):
dtype = "int8"
if dtype == "bfloat16":
dtype = "uint16"
if dtype == "e4m3_float8":
dtype = ml_dtypes.float8_e4m3fn
if dtype == "e5m2_float8":
dtype = ml_dtypes.float8_e5m2
np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags["C_CONTIGUOUS"]
data = np_arr.ctypes.data_as(ctypes.c_void_p)
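A round-trip sketch; it assumes the `NUMPY2STR` entries added in `runtime_ctypes.py` let `tvm.nd.array` ingest `ml_dtypes` arrays (that direction is not shown in this hunk):

```python
import numpy as np
import ml_dtypes
import tvm

src = np.arange(8, dtype=np.float32).astype(ml_dtypes.float8_e4m3fn)
arr = tvm.nd.array(src, tvm.cpu(0))  # assumption: fp8 ingest works via NUMPY2STR
assert arr.numpy().dtype == ml_dtypes.float8_e4m3fn  # numpy() maps "e4m3_float8" back
```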
28 changes: 28 additions & 0 deletions python/tvm/tir/transform/transform.py
@@ -40,6 +40,7 @@ def Apply(ftransform):
fpass : tvm.transform.Pass
The result pass
"""

# pylint: disable=unused-argument
def _transform(func, mod, ctx):
return ftransform(func)
@@ -297,6 +298,22 @@ def BF16ComputeLegalize():
return _ffi_api.BF16ComputeLegalize() # type: ignore


def FP8ComputeLegalize(promote_dtype_str: str = "float32"):
"""Legalize fp8 compute Ops.

Parameters
----------
promote_dtype_str : str
The data type we promote fp8 to, options: float16/float32.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.FP8ComputeLegalize(promote_dtype_str) # type: ignore


def BF16StorageLegalize():
"""Legalize bf16 storage types to u16.

@@ -308,6 +325,17 @@ def BF16StorageLegalize():
return _ffi_api.BF16StorageLegalize() # type: ignore


def FP8StorageLegalize():
"""Legalize fp8 storage types to u8.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.FP8StorageLegalize() # type: ignore


def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False):
"""Replace redundant computations by new variables.

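A minimal usage sketch, mirroring where the driver places the two passes (compute legalization during lowering, storage legalization in the mixed-module pipeline, as the `driver_api.cc` hunks below show):

```python
import tvm

fp8_compute = tvm.tir.transform.FP8ComputeLegalize("float16")  # promote fp8 math to fp16
fp8_storage = tvm.tir.transform.FP8StorageLegalize()           # back fp8 values with u8

# mod = fp8_compute(mod)  # `mod` is assumed to be an IRModule with fp8 compute
# mod = fp8_storage(mod)  # applied later in the mixed-module pipeline
```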
2 changes: 2 additions & 0 deletions src/driver/driver_api.cc
@@ -210,6 +210,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::FP8ComputeLegalize());
pass_list.push_back(tir::transform::BF16ComputeLegalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
@@ -586,6 +587,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
} else {
mixed_pass_list.push_back(tir::transform::MakePackedAPI());
}
mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());

17 changes: 17 additions & 0 deletions src/target/source/codegen_cuda.cc
@@ -116,6 +116,12 @@ std::string CodeGenCUDA::Finish() {
decl_stream << _cuda_bfloat16_util;
}

if (enable_fp8_) {
decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)\n";
decl_stream << "#include <cuda_fp8.h>\n";
decl_stream << "#endif\n\n";
}

if (enable_warp_shuffle_) {
decl_stream << _cuda_warp_intrinsic_util;
}
@@ -249,6 +255,17 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
fail = true;
}
if (!fail) return;
} else if (t.is_float8()) {
if (t.is_scalar()) {
os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char
} else if (lanes == 2) {
os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of unsigned short
} else if (lanes == 4) {
os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int
} else {
fail = true;
}
if (!fail) return;
} else if (t == DataType::Bool()) {
os << "bool";
return;
5 changes: 4 additions & 1 deletion src/target/source/codegen_cuda.h
@@ -42,7 +42,8 @@ class CodeGenCUDA final : public CodeGenC {
void Init(bool output_ssa);
std::string Finish();
bool need_include_path() {
return (enable_fp16_ || enable_bf16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
return (enable_fp16_ || enable_bf16_ || enable_int8_ || enable_fp8_ || need_math_constants_h_ ||
need_mma_h_);
}
// override behavior
void PrintFuncPrefix(std::ostream& os) final;
@@ -93,6 +94,8 @@
bool enable_fp16_{false};
// whether enable bf16
bool enable_bf16_{false};
// whether enable fp8
bool enable_fp8_{false};
// whether enable int8
bool enable_int8_{false};
// whether enable warp shuffle intrinsics