From cb3142059a1a92ef3dad9c4f180b80d9de3cbe48 Mon Sep 17 00:00:00 2001
From: Max Ren
Date: Fri, 28 Mar 2025 16:49:59 -0700
Subject: [PATCH] [XNNPACK] Serialize fp16 weights as fp16

---
 backends/xnnpack/operators/node_visitor.py | 20 +++++++++-----------
 backends/xnnpack/operators/op_conv2d.py    | 10 +++++++---
 backends/xnnpack/operators/op_linear.py    |  9 +++++++--
 backends/xnnpack/test/ops/test_linear.py   |  8 +++-----
 4 files changed, 26 insertions(+), 21 deletions(-)

diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py
index 0185f18d249..50781eade4d 100644
--- a/backends/xnnpack/operators/node_visitor.py
+++ b/backends/xnnpack/operators/node_visitor.py
@@ -210,7 +210,7 @@ def get_serialized_dtype(
         self,
         quant_params: Optional[QuantParams],
         node: torch.fx.Node,
-        fp32_static_weight: bool = False,
+        force_fp32: bool = False,
     ) -> XNNDatatype:
         # Default initialization
         dtype = XNNDatatype.xnn_datatype_fp32
@@ -267,7 +267,7 @@ def get_per_channel_dtype(
         if node_dtype is not None and node_dtype == torch.float16:
             dtype = (
                 XNNDatatype.xnn_datatype_fp32
-                if fp32_static_weight
+                if force_fp32
                 else XNNDatatype.xnn_datatype_fp16
             )
@@ -348,7 +348,7 @@ def define_tensor(  # noqa: C901
         convert_to_nhwc: bool = False,
         swap_in_out_for_weights: bool = False,
         quant_params: Optional[QuantParams] = None,
-        fp32_static_weights: bool = False,
+        force_fp32: bool = False,
         groups: int = 1,
     ) -> None:
         """
@@ -368,7 +368,7 @@
                 constant data. If used along with convert_to_nhwc, this
                 swap will happen before converting to nhwc.
             quant_params: Quantization meta data for this tensor, None if it is not quantized
-            fp32_static_weights: XNN_FLAG_FP32_STATIC_WEIGHTS for fp16 conv
+            force_fp32: forces the tensor to be serialized as fp32, used for the bias of dynamically quantized ops
             groups: number of groups for swap_in_out_for_weights
         """
@@ -405,7 +405,7 @@
             convert_to_nhwc,
             swap_in_out_for_weights,
             quant_params,
-            fp32_static_weights,
+            force_fp32,
             groups,
         )
@@ -417,9 +417,7 @@
             check_or_raise(len(dims) == 4, "Converting to nhwc requires 4d tensor")
             dims = [dims[i] for i in PERM_NCHW_TO_NHWC]
 
-        dtype = self.get_serialized_dtype(
-            quant_params, tensor, fp32_static_weight=fp32_static_weights
-        )
+        dtype = self.get_serialized_dtype(quant_params, tensor, force_fp32=force_fp32)
 
         tvalue = XNNTensorValue(
             datatype=dtype,
@@ -504,7 +502,7 @@
         convert_to_nhwc: bool,
         swap_in_out_for_weights: bool,
         quant_params: Optional[QuantParams],
-        fp32_static_weights: bool = False,
+        force_fp32: bool = False,
         groups: int = 1,
     ) -> int:
         """
@@ -525,7 +523,7 @@
                 constant data. If used along with convert_to_nhwc, this
                 swap will happen before converting to nhwc.
             quant_params: Quantization meta data for this tensor, None if it is not quantized
-            fp32_static_weights: bool to indicate whether tensor is fp32 static weights
+            force_fp32: bool to indicate whether the tensor should be serialized as fp32
             groups: groups for swap_in_out_for_weights
 
         Returns:
@@ -554,7 +552,7 @@ def get_serialized_buffer_index(
         # Quantize buffer if static data is indeed quantized
         if quant_params is not None and not quant_params.is_dynamic:
             const_val = quant_params.quantize_tensor(const_val).contiguous()
-        elif const_val.dtype != torch.float16 or fp32_static_weights:
+        elif const_val.dtype != torch.float16 or force_fp32:
             # ensure that the const is fp32
             const_val = const_val.to(dtype=torch.float32).contiguous()

diff --git a/backends/xnnpack/operators/op_conv2d.py b/backends/xnnpack/operators/op_conv2d.py
index 1272f1b5250..8cecb3c62ad 100644
--- a/backends/xnnpack/operators/op_conv2d.py
+++ b/backends/xnnpack/operators/op_conv2d.py
@@ -82,7 +82,6 @@ def define_node(
         weight_quant_params = QuantParams.from_weights(
             kernel_node, self._exported_program
         )
-        fp32_static_weights = kernel_node.meta["val"].dtype == torch.float16
 
         if weight_quant_params is not None and weight_quant_params.per_channel:
             if is_transpose:
@@ -102,8 +101,7 @@ def define_node(
             convert_to_nhwc=True,
             swap_in_out_for_weights=is_depthwise_conv or is_transpose,
             quant_params=weight_quant_params,
-            fp32_static_weights=fp32_static_weights,
             groups=groups if is_transpose else 1,
         )
 
         kwargs["filter_id"] = vals_to_ids[get_input_node(node, 1)]
@@ -127,13 +125,19 @@ def define_node(
         bias_quant_params = QuantParams.from_bias(
             bias_node, weight_quant_params, input_quant_params
         )
+        # For dynamic quantization, there are no kernels with fp16 bias,
+        # so we need to force the fp16 bias to fp32
+        force_fp32 = False
+        if input_quant_params is not None and input_quant_params.is_dynamic:
+            force_fp32 = True
+
         self.define_tensor(
             get_input_node(node, 2),
             xnn_graph,
             vals_to_ids,
             convert_to_nhwc=False,
             quant_params=bias_quant_params,
-            fp32_static_weights=fp32_static_weights,
+            force_fp32=force_fp32,
         )
 
         kwargs["bias_id"] = vals_to_ids[get_input_node(node, 2)]

diff --git a/backends/xnnpack/operators/op_linear.py b/backends/xnnpack/operators/op_linear.py
index 560f7d1a516..dda1d3e53ef 100644
--- a/backends/xnnpack/operators/op_linear.py
+++ b/backends/xnnpack/operators/op_linear.py
@@ -59,7 +59,6 @@ def define_node(
             xnn_graph,
             vals_to_ids,
             quant_params=weight_quant_params,
-            fp32_static_weights=True,
         )
 
         filter_id = vals_to_ids[weight_node]
@@ -69,12 +68,18 @@ def define_node(
             bias_quant_params = QuantParams.from_bias(
                 bias_node, weight_quant_params, input_quant_params
             )
+            # For dynamic quantization, there are no kernels with fp16 bias,
+            # so we need to force the fp16 bias to fp32
+            force_fp32 = False
+            if input_quant_params is not None and input_quant_params.is_dynamic:
+                force_fp32 = True
+
             self.define_tensor(
                 get_input_node(node, 2),
                 xnn_graph,
                 vals_to_ids,
                 quant_params=bias_quant_params,
-                fp32_static_weights=True,
+                force_fp32=force_fp32,
             )
             bias_id = vals_to_ids[bias_node]
         else:

diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py
index 849a1b237e8..bcaf2e82a08 100644
--- a/backends/xnnpack/test/ops/test_linear.py
+++ b/backends/xnnpack/test/ops/test_linear.py
@@ -605,9 +605,7 @@ def _test_qd8_linear_per_tensor_unsupported(self, dtype: torch.dtype = torch.flo
 
         if legacy_partitioner:
             tester.to_edge()
-            tester.partition(
-                Partition(DynamicallyQuantizedPartitioner)
-            ).dump_artifact()
+            tester.partition(Partition(DynamicallyQuantizedPartitioner))
             # should have [add]mm node
             if uses_bias:
                 tester.check(
@@ -624,7 +622,7 @@ def _test_qd8_linear_per_tensor_unsupported(self, dtype: torch.dtype = torch.flo
         else:
             tester.to_edge_transform_and_lower(
                 ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
-            ).dump_artifact()
+            )
             # should not have a delegate node
             tester.check_not(
                 [
@@ -717,7 +715,7 @@ def test_fp16_linear(self):
             num_batch_dims=num_batch_dims,
             uses_bias=use_bias,
             dtype=torch.float16,
-            atol=5e-3,
+            atol=5e-2,  # TODO(T212995726): Investigate right atol for rand[n] inputs
         )
 
     def test_fp32_linear(self):
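
Note for reviewers: the net effect of the patch on constant-data serialization can be sketched as below. This is an illustrative stand-alone snippet, not code from the tree; `serialized_dtype` is a hypothetical helper that mirrors the `force_fp32` checks in node_visitor.py.

    import torch

    def serialized_dtype(const_dtype: torch.dtype, force_fp32: bool) -> torch.dtype:
        # fp16 constants now stay fp16 unless a consumer forces fp32
        # (e.g. the bias of a dynamically quantized linear/conv, for
        # which XNNPACK has no fp16-bias kernels).
        if const_dtype == torch.float16 and not force_fp32:
            return torch.float16
        return torch.float32

    # fp16 weights are serialized as fp16 (the point of this patch) ...
    assert serialized_dtype(torch.float16, force_fp32=False) == torch.float16
    # ... while a dynamically quantized op's fp16 bias is promoted to fp32.
    assert serialized_dtype(torch.float16, force_fp32=True) == torch.float32
    # fp32 constants are unaffected either way.
    assert serialized_dtype(torch.float32, force_fp32=False) == torch.float32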