Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions backend/src/nodes/impl/tensorrt/engine_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
class BuildConfig:
"""Configuration for TensorRT engine building."""

precision: Literal["fp32", "fp16"]
precision: Literal["fp32", "fp16", "bf16"]
workspace_size_gb: float
min_shape: tuple[int, int] # (height, width)
opt_shape: tuple[int, int] # (height, width)
Expand Down Expand Up @@ -93,6 +93,14 @@ def configure_builder_config(
logger.info("FP16 mode enabled")
else:
logger.warning("FP16 not supported on this platform, falling back to FP32")
elif config.precision == "bf16":
if hasattr(trt.BuilderFlag, "BF16"):
builder_config.set_flag(trt.BuilderFlag.BF16)
logger.info("BF16 mode enabled")
else:
logger.warning(
"BF16 not supported by this TensorRT version, falling back to FP32"
)

return builder_config

Expand Down Expand Up @@ -165,6 +173,12 @@ def build_engine_from_onnx(
logger.info("FP16 mode enabled")
else:
logger.warning("FP16 not supported on this platform, using FP32")
elif config.precision == "bf16":
if hasattr(trt.BuilderFlag, "BF16"):
builder_config.set_flag(trt.BuilderFlag.BF16)
logger.info("BF16 mode enabled")
else:
logger.warning("BF16 not supported by this TensorRT version, using FP32")

# Configure dynamic shapes if needed
has_dynamic = any(d == -1 for d in input_shape)
Expand Down Expand Up @@ -231,13 +245,13 @@ def build_engine_from_onnx(
gpu_architecture=gpu_arch,
tensorrt_version=trt.__version__,
has_dynamic_shapes=has_dynamic or config.use_dynamic_shapes,
min_shape=(config.min_shape[1], config.min_shape[0])
min_shape=(1, input_channels, config.min_shape[0], config.min_shape[1])
if config.use_dynamic_shapes
else None,
opt_shape=(config.opt_shape[1], config.opt_shape[0])
opt_shape=(1, input_channels, config.opt_shape[0], config.opt_shape[1])
if config.use_dynamic_shapes
else None,
max_shape=(config.max_shape[1], config.max_shape[0])
max_shape=(1, input_channels, config.max_shape[0], config.max_shape[1])
if config.use_dynamic_shapes
else None,
)
Expand Down
4 changes: 2 additions & 2 deletions backend/src/nodes/impl/tensorrt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class TensorRTEngineInfo:
"""Metadata about a TensorRT engine."""

precision: Literal["fp32", "fp16", "int8"]
precision: Literal["fp32", "fp16", "bf16", "int8"]
input_channels: int
output_channels: int
scale: int | None
Expand All @@ -36,7 +36,7 @@ def __init__(self, engine_bytes: bytes, info: TensorRTEngineInfo):
self.info: TensorRTEngineInfo = info

@property
def precision(self) -> Literal["fp32", "fp16", "int8"]:
def precision(self) -> Literal["fp32", "fp16", "bf16", "int8"]:
return self.info.precision

@property
Expand Down
5 changes: 5 additions & 0 deletions backend/src/nodes/properties/inputs/tensorrt_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def TensorRTPrecisionDropdown() -> DropDownInput:
"value": "fp16",
"type": "TrtPrecision::fp16",
},
{
"option": "BF16",
"value": "bf16",
"type": "TrtPrecision::bf16",
},
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,13 @@ def load_engine_node(
has_dynamic = any(d == -1 for d in input_shape)

# Detect precision from the engine
precision = (
"fp16"
if engine.get_tensor_dtype(input_name) == trt.DataType.HALF
else "fp32"
)
input_dtype = engine.get_tensor_dtype(input_name)
if input_dtype == trt.DataType.HALF:
precision = "fp16"
elif input_dtype == trt.DataType.BF16:
precision = "bf16"
else:
precision = "fp32"

tensor_profile_name = input_name
min_shape, opt_shape, max_shape = engine.get_tensor_profile_shape(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
class Precision(Enum):
    """Numeric precision options offered for TensorRT engine building."""

    FP32 = "fp32"  # full 32-bit float: highest precision, slowest
    FP16 = "fp16"  # half precision: faster, but does not work with certain models
    BF16 = "bf16"  # bfloat16: FP32 exponent range with a reduced mantissa


# Human-readable labels for each Precision member; used as the
# option_labels of the precision EnumInput dropdown below.
PRECISION_LABELS = {
    Precision.FP32: "FP32 (Higher Precision)",
    Precision.FP16: "FP16 (Faster)",
    Precision.BF16: "BF16 (Balanced)",
}


Expand Down Expand Up @@ -62,6 +64,7 @@ class ShapeMode(Enum):
option_labels=PRECISION_LABELS,
).with_docs(
"FP16: lower precision but faster and uses less memory, especially on RTX GPUs. FP16 also does not work with certain models.",
"BF16: same exponent range as FP32 with reduced mantissa. Better numerical stability than FP16 while still being faster than FP32. Good for models that produce NaN/artifacts with FP16.",
"FP32: higher precision but slower. Use especially if FP16 fails.",
),
EnumInput(
Expand Down
2 changes: 1 addition & 1 deletion src/common/types/chainner-scope.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ struct TensorRTEngine {
maxWidth: int(1..) | null,
}

enum TrtPrecision { fp32, fp16 }
enum TrtPrecision { fp32, fp16, bf16 }
enum TrtShapeMode { fixed, dynamic }

def pytorchToOnnx(model: PyTorchModel): OnnxModel {
Expand Down