diff --git a/backend/src/nodes/impl/tensorrt/engine_builder.py b/backend/src/nodes/impl/tensorrt/engine_builder.py index 10b52ae29..a4da3df26 100644 --- a/backend/src/nodes/impl/tensorrt/engine_builder.py +++ b/backend/src/nodes/impl/tensorrt/engine_builder.py @@ -18,7 +18,7 @@ class BuildConfig: """Configuration for TensorRT engine building.""" - precision: Literal["fp32", "fp16"] + precision: Literal["fp32", "fp16", "bf16"] workspace_size_gb: float min_shape: tuple[int, int] # (height, width) opt_shape: tuple[int, int] # (height, width) @@ -93,6 +93,14 @@ def configure_builder_config( logger.info("FP16 mode enabled") else: logger.warning("FP16 not supported on this platform, falling back to FP32") + elif config.precision == "bf16": + if hasattr(trt.BuilderFlag, "BF16"): + builder_config.set_flag(trt.BuilderFlag.BF16) + logger.info("BF16 mode enabled") + else: + logger.warning( + "BF16 not supported by this TensorRT version, falling back to FP32" + ) return builder_config @@ -165,6 +173,12 @@ def build_engine_from_onnx( logger.info("FP16 mode enabled") else: logger.warning("FP16 not supported on this platform, using FP32") + elif config.precision == "bf16": + if hasattr(trt.BuilderFlag, "BF16"): + builder_config.set_flag(trt.BuilderFlag.BF16) + logger.info("BF16 mode enabled") + else: + logger.warning("BF16 not supported by this TensorRT version, using FP32") # Configure dynamic shapes if needed has_dynamic = any(d == -1 for d in input_shape) @@ -231,13 +245,13 @@ def build_engine_from_onnx( gpu_architecture=gpu_arch, tensorrt_version=trt.__version__, has_dynamic_shapes=has_dynamic or config.use_dynamic_shapes, - min_shape=(config.min_shape[1], config.min_shape[0]) + min_shape=(1, input_channels, config.min_shape[0], config.min_shape[1]) if config.use_dynamic_shapes else None, - opt_shape=(config.opt_shape[1], config.opt_shape[0]) + opt_shape=(1, input_channels, config.opt_shape[0], config.opt_shape[1]) if config.use_dynamic_shapes else None, - max_shape=(config.max_shape[1], config.max_shape[0]) + max_shape=(1, input_channels, config.max_shape[0], config.max_shape[1]) if config.use_dynamic_shapes else None, ) diff --git a/backend/src/nodes/impl/tensorrt/model.py b/backend/src/nodes/impl/tensorrt/model.py index 1a9b5ca3d..82ab26ddd 100644 --- a/backend/src/nodes/impl/tensorrt/model.py +++ b/backend/src/nodes/impl/tensorrt/model.py @@ -10,7 +10,7 @@ class TensorRTEngineInfo: """Metadata about a TensorRT engine.""" - precision: Literal["fp32", "fp16", "int8"] + precision: Literal["fp32", "fp16", "bf16", "int8"] input_channels: int output_channels: int scale: int | None @@ -36,7 +36,7 @@ def __init__(self, engine_bytes: bytes, info: TensorRTEngineInfo): self.info: TensorRTEngineInfo = info @property - def precision(self) -> Literal["fp32", "fp16", "int8"]: + def precision(self) -> Literal["fp32", "fp16", "bf16", "int8"]: return self.info.precision @property diff --git a/backend/src/nodes/properties/inputs/tensorrt_inputs.py b/backend/src/nodes/properties/inputs/tensorrt_inputs.py index 11cc44061..5dbc7b428 100644 --- a/backend/src/nodes/properties/inputs/tensorrt_inputs.py +++ b/backend/src/nodes/properties/inputs/tensorrt_inputs.py @@ -33,6 +33,11 @@ def TensorRTPrecisionDropdown() -> DropDownInput: "value": "fp16", "type": "TrtPrecision::fp16", }, + { + "option": "BF16", + "value": "bf16", + "type": "TrtPrecision::bf16", + }, ], ) diff --git a/backend/src/packages/chaiNNer_tensorrt/tensorrt/io/load_engine.py b/backend/src/packages/chaiNNer_tensorrt/tensorrt/io/load_engine.py index 4fcff48e9..f105fea06 100644 --- a/backend/src/packages/chaiNNer_tensorrt/tensorrt/io/load_engine.py +++ b/backend/src/packages/chaiNNer_tensorrt/tensorrt/io/load_engine.py @@ -99,11 +99,13 @@ def load_engine_node( has_dynamic = any(d == -1 for d in input_shape) # Detect precision from the engine - precision = ( - "fp16" - if engine.get_tensor_dtype(input_name) == trt.DataType.HALF - else "fp32" - ) + input_dtype = engine.get_tensor_dtype(input_name) + if input_dtype == trt.DataType.HALF: + precision = "fp16" + elif input_dtype == trt.DataType.BF16: + precision = "bf16" + else: + precision = "fp32" tensor_profile_name = input_name min_shape, opt_shape, max_shape = engine.get_tensor_profile_shape( diff --git a/backend/src/packages/chaiNNer_tensorrt/tensorrt/utility/build_engine.py b/backend/src/packages/chaiNNer_tensorrt/tensorrt/utility/build_engine.py index 391b2f7f3..ccbf3aba8 100644 --- a/backend/src/packages/chaiNNer_tensorrt/tensorrt/utility/build_engine.py +++ b/backend/src/packages/chaiNNer_tensorrt/tensorrt/utility/build_engine.py @@ -22,11 +22,13 @@ class Precision(Enum): FP32 = "fp32" FP16 = "fp16" + BF16 = "bf16" PRECISION_LABELS = { Precision.FP32: "FP32 (Higher Precision)", Precision.FP16: "FP16 (Faster)", + Precision.BF16: "BF16 (Balanced)", } @@ -62,6 +64,7 @@ class ShapeMode(Enum): option_labels=PRECISION_LABELS, ).with_docs( "FP16: lower precision but faster and uses less memory, especially on RTX GPUs. FP16 also does not work with certain models.", + "BF16: same exponent range as FP32 with reduced mantissa. Better numerical stability than FP16 while still being faster than FP32. Good for models that produce NaN/artifacts with FP16.", "FP32: higher precision but slower. Use especially if FP16 fails.", ), EnumInput( diff --git a/src/common/types/chainner-scope.ts b/src/common/types/chainner-scope.ts index 8145175a7..2acc4eb44 100644 --- a/src/common/types/chainner-scope.ts +++ b/src/common/types/chainner-scope.ts @@ -105,7 +105,7 @@ struct TensorRTEngine { maxWidth: int(1..) | null, } -enum TrtPrecision { fp32, fp16 } +enum TrtPrecision { fp32, fp16, bf16 } enum TrtShapeMode { fixed, dynamic } def pytorchToOnnx(model: PyTorchModel): OnnxModel {