From 6adb52818f112c1ceacc801e2908fde731a2c850 Mon Sep 17 00:00:00 2001 From: Joey Ballentine Date: Sun, 22 Feb 2026 18:36:25 -0600 Subject: [PATCH 1/2] Add support for BF16 with TensorRT --- .../src/nodes/impl/tensorrt/engine_builder.py | 24 +++++++++++++++---- backend/src/nodes/impl/tensorrt/model.py | 4 ++-- .../properties/inputs/tensorrt_inputs.py | 5 ++++ .../tensorrt/io/load_engine.py | 12 ++++++---- .../tensorrt/utility/build_engine.py | 3 +++ src/common/types/chainner-scope.ts | 2 +- 6 files changed, 38 insertions(+), 12 deletions(-) diff --git a/backend/src/nodes/impl/tensorrt/engine_builder.py b/backend/src/nodes/impl/tensorrt/engine_builder.py index 10b52ae29..6c9a757dd 100644 --- a/backend/src/nodes/impl/tensorrt/engine_builder.py +++ b/backend/src/nodes/impl/tensorrt/engine_builder.py @@ -18,7 +18,7 @@ class BuildConfig: """Configuration for TensorRT engine building.""" - precision: Literal["fp32", "fp16"] + precision: Literal["fp32", "fp16", "bf16"] workspace_size_gb: float min_shape: tuple[int, int] # (height, width) opt_shape: tuple[int, int] # (height, width) @@ -93,6 +93,14 @@ def configure_builder_config( logger.info("FP16 mode enabled") else: logger.warning("FP16 not supported on this platform, falling back to FP32") + elif config.precision == "bf16": + if hasattr(trt.BuilderFlag, "BF16"): + builder_config.set_flag(trt.BuilderFlag.BF16) + logger.info("BF16 mode enabled") + else: + logger.warning( + "BF16 not supported by this TensorRT version, falling back to FP32" + ) return builder_config @@ -165,6 +173,14 @@ def build_engine_from_onnx( logger.info("FP16 mode enabled") else: logger.warning("FP16 not supported on this platform, using FP32") + elif config.precision == "bf16": + if hasattr(trt.BuilderFlag, "BF16"): + builder_config.set_flag(trt.BuilderFlag.BF16) + logger.info("BF16 mode enabled") + else: + logger.warning( + "BF16 not supported by this TensorRT version, using FP32" + ) # Configure dynamic shapes if needed has_dynamic = any(d == -1 for d in input_shape) @@ -231,13 +247,13 @@ def build_engine_from_onnx( gpu_architecture=gpu_arch, tensorrt_version=trt.__version__, has_dynamic_shapes=has_dynamic or config.use_dynamic_shapes, - min_shape=(config.min_shape[1], config.min_shape[0]) + min_shape=(1, input_channels, config.min_shape[0], config.min_shape[1]) if config.use_dynamic_shapes else None, - opt_shape=(config.opt_shape[1], config.opt_shape[0]) + opt_shape=(1, input_channels, config.opt_shape[0], config.opt_shape[1]) if config.use_dynamic_shapes else None, - max_shape=(config.max_shape[1], config.max_shape[0]) + max_shape=(1, input_channels, config.max_shape[0], config.max_shape[1]) if config.use_dynamic_shapes else None, ) diff --git a/backend/src/nodes/impl/tensorrt/model.py b/backend/src/nodes/impl/tensorrt/model.py index 1a9b5ca3d..82ab26ddd 100644 --- a/backend/src/nodes/impl/tensorrt/model.py +++ b/backend/src/nodes/impl/tensorrt/model.py @@ -10,7 +10,7 @@ class TensorRTEngineInfo: """Metadata about a TensorRT engine.""" - precision: Literal["fp32", "fp16", "int8"] + precision: Literal["fp32", "fp16", "bf16", "int8"] input_channels: int output_channels: int scale: int | None @@ -36,7 +36,7 @@ def __init__(self, engine_bytes: bytes, info: TensorRTEngineInfo): self.info: TensorRTEngineInfo = info @property - def precision(self) -> Literal["fp32", "fp16", "int8"]: + def precision(self) -> Literal["fp32", "fp16", "bf16", "int8"]: return self.info.precision @property diff --git a/backend/src/nodes/properties/inputs/tensorrt_inputs.py b/backend/src/nodes/properties/inputs/tensorrt_inputs.py index 11cc44061..5dbc7b428 100644 --- a/backend/src/nodes/properties/inputs/tensorrt_inputs.py +++ b/backend/src/nodes/properties/inputs/tensorrt_inputs.py @@ -33,6 +33,11 @@ def TensorRTPrecisionDropdown() -> DropDownInput: "value": "fp16", "type": "TrtPrecision::fp16", }, + { + "option": "BF16", + "value": "bf16", + "type": "TrtPrecision::bf16", + }, ], ) diff --git a/backend/src/packages/chaiNNer_tensorrt/tensorrt/io/load_engine.py b/backend/src/packages/chaiNNer_tensorrt/tensorrt/io/load_engine.py index 4fcff48e9..f105fea06 100644 --- a/backend/src/packages/chaiNNer_tensorrt/tensorrt/io/load_engine.py +++ b/backend/src/packages/chaiNNer_tensorrt/tensorrt/io/load_engine.py @@ -99,11 +99,13 @@ def load_engine_node( has_dynamic = any(d == -1 for d in input_shape) # Detect precision from the engine - precision = ( - "fp16" - if engine.get_tensor_dtype(input_name) == trt.DataType.HALF - else "fp32" - ) + input_dtype = engine.get_tensor_dtype(input_name) + if input_dtype == trt.DataType.HALF: + precision = "fp16" + elif input_dtype == trt.DataType.BF16: + precision = "bf16" + else: + precision = "fp32" tensor_profile_name = input_name min_shape, opt_shape, max_shape = engine.get_tensor_profile_shape( diff --git a/backend/src/packages/chaiNNer_tensorrt/tensorrt/utility/build_engine.py b/backend/src/packages/chaiNNer_tensorrt/tensorrt/utility/build_engine.py index 391b2f7f3..ccbf3aba8 100644 --- a/backend/src/packages/chaiNNer_tensorrt/tensorrt/utility/build_engine.py +++ b/backend/src/packages/chaiNNer_tensorrt/tensorrt/utility/build_engine.py @@ -22,11 +22,13 @@ class Precision(Enum): FP32 = "fp32" FP16 = "fp16" + BF16 = "bf16" PRECISION_LABELS = { Precision.FP32: "FP32 (Higher Precision)", Precision.FP16: "FP16 (Faster)", + Precision.BF16: "BF16 (Balanced)", } @@ -62,6 +64,7 @@ class ShapeMode(Enum): option_labels=PRECISION_LABELS, ).with_docs( "FP16: lower precision but faster and uses less memory, especially on RTX GPUs. FP16 also does not work with certain models.", + "BF16: same exponent range as FP32 with reduced mantissa. Better numerical stability than FP16 while still being faster than FP32. Good for models that produce NaN/artifacts with FP16.", "FP32: higher precision but slower. Use especially if FP16 fails.", ), EnumInput( diff --git a/src/common/types/chainner-scope.ts b/src/common/types/chainner-scope.ts index 8145175a7..2acc4eb44 100644 --- a/src/common/types/chainner-scope.ts +++ b/src/common/types/chainner-scope.ts @@ -105,7 +105,7 @@ struct TensorRTEngine { maxWidth: int(1..) | null, } -enum TrtPrecision { fp32, fp16 } +enum TrtPrecision { fp32, fp16, bf16 } enum TrtShapeMode { fixed, dynamic } def pytorchToOnnx(model: PyTorchModel): OnnxModel { From 2ca3b7939f5b68880dc48275314c82bb8bcc2be7 Mon Sep 17 00:00:00 2001 From: Joey Ballentine Date: Sun, 22 Feb 2026 18:39:14 -0600 Subject: [PATCH 2/2] format --- backend/src/nodes/impl/tensorrt/engine_builder.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/backend/src/nodes/impl/tensorrt/engine_builder.py b/backend/src/nodes/impl/tensorrt/engine_builder.py index 6c9a757dd..a4da3df26 100644 --- a/backend/src/nodes/impl/tensorrt/engine_builder.py +++ b/backend/src/nodes/impl/tensorrt/engine_builder.py @@ -178,9 +178,7 @@ def build_engine_from_onnx( builder_config.set_flag(trt.BuilderFlag.BF16) logger.info("BF16 mode enabled") else: - logger.warning( - "BF16 not supported by this TensorRT version, using FP32" - ) + logger.warning("BF16 not supported by this TensorRT version, using FP32") # Configure dynamic shapes if needed has_dynamic = any(d == -1 for d in input_shape)