From 7b6bc55f7709f9dbbd3c6a5247fb32040df99016 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Tue, 30 Dec 2025 12:25:56 +0100 Subject: [PATCH 1/2] Arm backend: Support a8w4 for convolution and linear. Int4 is added as a TosaSpecialDtype. Since torch.dtype doesn't have int4, the dtype is carried in an int8, quantized with qmin=-7, qmax=7. This can then be detected when folding qparams to add the tosa_special_dtype meta. Some additional changes are needed to make sure the meta survives all the way to TOSA. Tests are added to conv2d, conv3d, depthwise conv, and linear. One conv2d test case was actually dw conv, so it was modified. Signed-off-by: Erik Lundell Change-Id: Iedcc45f2e419c261fd3205981fa632736c69524d --- .../fold_qdq_with_annotated_qparams_pass.py | 15 ++- .../arm/_passes/fuse_constant_ops_pass.py | 22 +++- .../_passes/fuse_equal_placeholders_pass.py | 6 +- backends/arm/ethosu/compile_spec.py | 4 +- backends/arm/operators/ops_identity.py | 4 +- backends/arm/process_node.py | 17 +-- backends/arm/quantizer/arm_quantizer.py | 10 +- backends/arm/test/ops/test_conv2d.py | 81 ++++++++++++- backends/arm/test/ops/test_conv3d.py | 66 ++++++++++- backends/arm/test/ops/test_depthwise_conv.py | 107 +++++++++++++++++- backends/arm/test/ops/test_linear.py | 32 +++++- backends/arm/tosa/mapping.py | 37 +++++- 12 files changed, 366 insertions(+), 35 deletions(-) diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 0f99d6cbbdf..0ecb7ff2070 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -24,6 +24,7 @@ from executorch.backends.arm._passes.remove_noop_pass import RemoveNoopPass from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo from executorch.backends.arm.constants import DQ_OPS, Q_OPS +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops @@ -32,6 +33,13 @@ from torch.fx import GraphModule, Node +def _get_special_dtype(qspec: QuantArgs) -> TosaSpecialDtype | None: + if qspec.dtype == torch.int8: + if qspec.qmax == 7 and qspec.qmin == -7: + return TosaSpecialDtype.INT4 + return None + + def get_input_qparams(node: Node) -> dict[int, QuantArgs]: """ Get the input quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'. @@ -157,6 +165,11 @@ def fold_and_annotate_arg( node.replace_input_with(n, cast(Node, n.args[0])) if len(n.users) == 0: graph_module.graph.erase_node(n) + special_dtype = _get_special_dtype(input_qparams) + if special_dtype: + node.all_input_nodes[i].meta[ + TosaSpecialDtype.meta_key() + ] = special_dtype def _handle_control_flow_node(self, node: Node, graph_module: GraphModule): """Fold outmost quant nodes inside submodule. diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index 79ce4ec8848..c29603d0b4c 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -18,6 +18,7 @@ from executorch.backends.arm._passes.fuse_equal_placeholders_pass import ( FuseEqualPlaceholdersPass, ) +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.transforms.utils import ( create_constant_placeholder, delete_constant_placeholder, @@ -52,6 +53,23 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.exported_program = exported_program + def _propagate_special_dtype(self, from_nodes, to_node, data): + """Propagate special dtype meta if it exists.""" + special_dtypes = set() + for input_node in from_nodes: + special_type = input_node.meta.get(TosaSpecialDtype.meta_key(), None) + if special_type: + special_dtypes.add(special_type) + if len(special_dtypes) > 1: + logger.warning( + "Propagating mixed special dtypes is not implemented, skipping." + ) + elif len(special_dtypes) == 1: + special_dtype = list(special_dtypes)[0] + # Make sure data is still within special dtype range. + if data.abs().max() <= special_dtype.max(): + to_node.meta[TosaSpecialDtype.meta_key()] = special_dtype + def _fuse_nodes(self, node) -> bool: """ Takes a node with only parameter inputs and replaces it with one constant tensor node with @@ -105,6 +123,8 @@ def resolve_arg(arg): persistent_buffer=persistent_buffer, ) + self._propagate_special_dtype(input_nodes, const_node, data) + node.replace_all_uses_with(const_node) return True diff --git a/backends/arm/_passes/fuse_equal_placeholders_pass.py b/backends/arm/_passes/fuse_equal_placeholders_pass.py index 6c3f9dde99e..37cac8a8c56 100644 --- a/backends/arm/_passes/fuse_equal_placeholders_pass.py +++ b/backends/arm/_passes/fuse_equal_placeholders_pass.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -53,11 +53,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # ensure we don't merge any special case int48_t tensors with int32_t tensors # since int48_t tensors needs to be instantiated separately. - is_int48 = node.meta.get(TosaSpecialDtype.meta_key(), None) + is_special_dtype = node.meta.get(TosaSpecialDtype.meta_key(), None) t_cpu = tensor.detach().cpu().contiguous() data_bytes = t_cpu.numpy().tobytes() key = ( - is_int48, + is_special_dtype, str(t_cpu.dtype), tuple(t_cpu.shape), hashlib.sha1(data_bytes, usedforsecurity=False).hexdigest(), diff --git a/backends/arm/ethosu/compile_spec.py b/backends/arm/ethosu/compile_spec.py index 8f6d6284f74..1d311cbf74c 100644 --- a/backends/arm/ethosu/compile_spec.py +++ b/backends/arm/ethosu/compile_spec.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -73,7 +73,7 @@ def __init__( compiler_flags.append(f"--memory-mode={memory_mode}") # Set TOSA version. - base_tosa_version = "TOSA-1.0+INT+int16" + base_tosa_version = "TOSA-1.0+INT+int16+int4" if "u55" in target_lower: # Add the Ethos-U55 extension marker base_tosa_version += "+u55" diff --git a/backends/arm/operators/ops_identity.py b/backends/arm/operators/ops_identity.py index a7ffd4eacca..0930d7e7997 100644 --- a/backends/arm/operators/ops_identity.py +++ b/backends/arm/operators/ops_identity.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -53,6 +53,8 @@ def define_node( supported_dtypes += [ts.DType.FP32] if self.tosa_spec.support_extension("int16"): supported_dtypes += [ts.DType.INT48] + if self.tosa_spec.support_extension("int4"): + supported_dtypes += [ts.DType.INT4] validate_valid_dtype( self.target, [inputs[0], output], diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 5a1d563ee0b..b85b1b43013 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -12,7 +12,7 @@ import torch.fx import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import NodeVisitor -from executorch.backends.arm.tosa.mapping import TosaArg, TosaSpecialDtype +from executorch.backends.arm.tosa.mapping import TosaArg from executorch.backends.arm.tosa.specification import TosaSpecification from executorch.backends.arm.tosa.utils import tosa_shape from torch._export.utils import ( @@ -116,21 +116,10 @@ def process_inputs_to_parameters( ) parameter_values = parameter_data.detach().numpy() - if tosa_arg.dtype == torch.float32: - if not tosa_spec.support_float(): - raise ValueError(f"{tosa_spec} doesn't support float operations") - - # Handle special case for INT48 tensors - special_type = node.meta.get(TosaSpecialDtype.meta_key(), None) - if isinstance(special_type, TosaSpecialDtype): - tosa_dtype = special_type.get_tosa_dtype() - else: - tosa_dtype = tosa_arg.dtype - parameter_values = np.transpose(parameter_values, tosa_arg.dim_order) tosa_graph.addConst( - parameter_values.shape, tosa_dtype, parameter_values, name=tosa_arg.name + parameter_values.shape, tosa_arg.dtype, parameter_values, name=tosa_arg.name ) diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 425fea0987b..28cef0d95ca 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -180,6 +180,14 @@ def get_symmetric_quantization_config( return quantization_config +def get_symmetric_a8w4_quantization_config( + is_per_channel: bool = True, is_qat: bool = True, is_dynamic: bool = False +): + return get_symmetric_quantization_config( + is_per_channel, is_qat, is_dynamic, weight_qmin=-7, weight_qmax=7 + ) + + @functools.lru_cache def get_symmetric_a16w8_quantization_config( is_per_channel: bool = True, diff --git a/backends/arm/test/ops/test_conv2d.py b/backends/arm/test/ops/test_conv2d.py index 2b86ea6a5c4..a8cd21058f9 100644 --- a/backends/arm/test/ops/test_conv2d.py +++ b/backends/arm/test/ops/test_conv2d.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -7,6 +7,9 @@ from typing import List, Tuple, Union import torch +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_a8w4_quantization_config, +) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, @@ -17,6 +20,7 @@ VgfPipeline, ) + aten_op = "torch.ops.aten.conv2d.default" exir_op = "executorch_exir_dialects_edge__ops_aten_convolution_default" @@ -162,8 +166,8 @@ def forward(self, x): batches=1, ) -conv2d_2x2_1x1x14x13_st2 = Conv2d( - in_channels=1, +conv2d_2x2_2x1x14x13_st2 = Conv2d( + in_channels=2, out_channels=1, kernel_size=(2, 2), stride=2, @@ -363,7 +367,7 @@ def forward(self, x): "3x3_1x3x24x24_st1": lambda: conv2d_3x3_1x3x24x24_st1, "3x3_1x3x12x12_st2_pd1": lambda: conv2d_3x3_1x3x12x12_st2_pd1, "1x1_1x2x16x16_st1": lambda: conv2d_1x1_1x2x16x16_st1, - "2x2_1x1x14x13_st2_needs_adjust_pass": lambda: conv2d_2x2_1x1x14x13_st2, + "2x2_2x1x14x13_st2_needs_adjust_pass": lambda: conv2d_2x2_2x1x14x13_st2, "5x5_1x3x14x15_st3_pd1_needs_adjust_pass": lambda: conv2d_5x5_1x3x14x15_st3_pd1, "7x7_1x3x16x16_st2_pd1_dl2_needs_adjust_pass": lambda: conv2d_7x7_1x3x16x16_st2_pd1_dl2, "7x7_1x3x15x15_st1_pd0_dl1_needs_adjust_pass": lambda: conv2d_7x7_1x3x15x15_st1_pd0_dl1, @@ -391,6 +395,15 @@ def forward(self, x): input_t = Tuple[torch.Tensor] +def _get_dtype_count(model: torch.nn.Module): + nbr_convs: int = model.nbr_convs # noqa + return { + "CONST": {"INT4": nbr_convs * 2}, # One for the weight, one for the zp. + "CONV2D": {"INT32": nbr_convs}, + "RESCALE": {"INT8": nbr_convs}, + } + + @common.parametrize("test_data", test_data_FP) def test_convolution_2d_tosa_FP(test_data): model = test_data() @@ -417,6 +430,36 @@ def test_convolution_2d_tosa_INT(test_data): pipeline.run() +@common.parametrize( + "test_data", + test_data_INT, + xfails={ + "groups,per_channel_quant=True": "Int4 not supported for grouped convolutions. MLETORCH-1726", + "groups,per_channel_quant=False": "Int4 not supported for grouped convolutions. MLETORCH-1726", + "groups_bias,per_channel_quant=True": "Int4 not supported for grouped convolutions. MLETORCH-1726", + "groups_bias,per_channel_quant=False": "Int4 not supported for grouped convolutions. MLETORCH-1726", + }, +) +def test_convolution_2d_tosa_INT_a8w4(test_data): + model, per_channel_quantization = test_data() + pipeline = TosaPipelineINT[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + tosa_extensions=["int4"], + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.add_stage_after( + "to_edge_transform_and_lower", + pipeline.tester.check_dtype_count, + _get_dtype_count(model), + ) + pipeline.run() + + @common.parametrize("test_data", test_data_INT) @common.XfailIfNoCorstone300 def test_convolution_2d_u55_INT(test_data): @@ -431,6 +474,21 @@ def test_convolution_2d_u55_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_INT) +def test_convolution_2d_u55_a8w4(test_data): + model, per_channel_quantization = test_data() + pipeline = EthosU55PipelineINT[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.run() + + @common.parametrize("test_data", test_data_INT) @common.XfailIfNoCorstone320 def test_convolution_u85_INT(test_data): @@ -445,6 +503,21 @@ def test_convolution_u85_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_INT) +def test_convolution_2d_u85_a8w4(test_data): + model, per_channel_quantization = test_data() + pipeline = EthosU85PipelineINT[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.run() + + @common.parametrize("test_data", test_data_FP) @common.SkipIfNoModelConverter def test_convolution_2d_vgf_no_quant(test_data): diff --git a/backends/arm/test/ops/test_conv3d.py b/backends/arm/test/ops/test_conv3d.py index 9c831c9ba49..e020ea0c5ac 100644 --- a/backends/arm/test/ops/test_conv3d.py +++ b/backends/arm/test/ops/test_conv3d.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -10,6 +10,7 @@ import torch from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_a16w8_quantization_config, + get_symmetric_a8w4_quantization_config, TOSAQuantizer, ) from executorch.backends.arm.test import common, conftest @@ -430,6 +431,15 @@ def forward(self, x): } +def _get_dtype_count(model: torch.nn.Module): + nbr_convs: int = model.nbr_convs # noqa + return { + "CONST": {"INT4": nbr_convs * 2}, + "CONV3D": {"INT32": nbr_convs}, + "RESCALE": {"INT8": nbr_convs}, + } + + def get_symmetric_a16w8_conv3d_quantizer(per_channel_quantization: bool = False): tosa_version = conftest.get_option("tosa_version") tosa_profiles = { @@ -474,6 +484,28 @@ def test_convolution_3d_tosa_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_INT) +def test_convolution_3d_tosa_INT_a8w4(test_data): + model, per_channel_quantization = test_data() + pipeline = TosaPipelineINT[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + tosa_extensions=["int4"], + qtol=1, + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.add_stage_after( + "to_edge_transform_and_lower", + pipeline.tester.check_dtype_count, + _get_dtype_count(model), + ) + pipeline.run() + + @common.parametrize("test_data", test_data_INT16) def test_convolution_3d_tosa_INT_a16w8(test_data): model, per_channel_quantization = test_data() @@ -543,6 +575,22 @@ def test_convolution_3d_u55_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_INT) +@pytest.mark.skip(reason="Ethos-U55 does not support CONV3D yet.") +def test_convolution_3d_u55_a8w4(test_data): + model, per_channel_quantization = test_data() + pipeline = EthosU55PipelineINT[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.run() + + @common.parametrize("test_data", test_data_INT) @pytest.mark.skip(reason="Ethos-U85 does not support CONV3D yet.") def test_convolution_3d_u85_INT(test_data): @@ -557,6 +605,22 @@ def test_convolution_3d_u85_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_INT) +@pytest.mark.skip(reason="Ethos-U85 does not support CONV3D yet.") +def test_convolution_3d_u85_a8w4(test_data): + model, per_channel_quantization = test_data() + pipeline = EthosU85PipelineINT[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.run() + + @common.parametrize("test_data", test_data_FP) @common.SkipIfNoModelConverter def test_convolution_3d_vgf_no_quant(test_data): diff --git a/backends/arm/test/ops/test_depthwise_conv.py b/backends/arm/test/ops/test_depthwise_conv.py index 017993e737b..166724ef69b 100644 --- a/backends/arm/test/ops/test_depthwise_conv.py +++ b/backends/arm/test/ops/test_depthwise_conv.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -8,6 +8,9 @@ import pytest import torch +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_a8w4_quantization_config, +) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( @@ -198,6 +201,15 @@ } +def _get_dtype_count(model: torch.nn.Module): + nbr_convs: int = model.nbr_convs # noqa + return { + "CONST": {"INT4": nbr_convs * 2}, + "DEPTHWISE_CONV2D": {"INT32": nbr_convs}, + "RESCALE": {"INT8": nbr_convs}, + } + + @common.parametrize("test_data", test_data_conv1d_FP | test_data_conv2d_FP) def test_convolution_2d_tosa_FP_depthwise(test_data: torch.nn.Module): pipeline = TosaPipelineFP[input_t]( @@ -223,6 +235,27 @@ def test_convolution_2d_tosa_INT_depthwise(test_data): pipeline.run() +@common.parametrize("test_data", test_data_conv1d_INT | test_data_conv2d_INT) +def test_convolution_2d_tosa_INT_a8w4_depthwise(test_data): + model, per_channel_quantization = test_data() + pipeline = TosaPipelineINT[input_t]( + model, + model.get_inputs(), + aten_op=[], + exir_op=exir_op, + tosa_extensions=["int4"], + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.add_stage_after( + "to_edge_transform_and_lower", + pipeline.tester.check_dtype_count, + _get_dtype_count(model), + ) + pipeline.run() + + @common.parametrize("test_data", test_data_conv1d_FP | test_data_conv2d_FP) @common.SkipIfNoModelConverter def test_convolution_2d_vgf_no_quant_depthwise(test_data: torch.nn.Module): @@ -251,7 +284,7 @@ def test_convolution_2d_vgf_quant_depthwise(test_data): pipeline.run() -@common.XfailIfNoCorstone300 # TODO: MLETORCH-516 +@common.XfailIfNoCorstone300 @common.parametrize("test_data", test_data_conv2d_INT) def test_convolution_2d_u55_INT_depthwise(test_data): model, per_channel_quantization = test_data() @@ -265,7 +298,23 @@ def test_convolution_2d_u55_INT_depthwise(test_data): pipeline.run() -@common.XfailIfNoCorstone300 # TODO: MLETORCH-516 +@common.XfailIfNoCorstone300 +@common.parametrize("test_data", test_data_conv2d_INT) +def test_convolution_2d_u55_a8w4_depthwise(test_data): + model, per_channel_quantization = test_data() + pipeline = EthosU55PipelineINT[input_t]( + model, + model.get_inputs(), + aten_ops=[], + exir_ops=exir_op, + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.run() + + +@common.XfailIfNoCorstone300 @common.parametrize("test_data", test_data_conv1d_INT) def test_convolution_1d_u55_INT_depthwise(test_data): model, per_channel_quantization = test_data() @@ -279,7 +328,23 @@ def test_convolution_1d_u55_INT_depthwise(test_data): pipeline.run() -@common.XfailIfNoCorstone320 # TODO: MLETORCH-516 +@common.XfailIfNoCorstone300 +@common.parametrize("test_data", test_data_conv1d_INT) +def test_convolution_1d_u55_a8w4_depthwise(test_data): + model, per_channel_quantization = test_data() + pipeline = EthosU55PipelineINT[input_t]( + model, + model.get_inputs(), + aten_ops=[], + exir_ops=exir_op, + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.run() + + +@common.XfailIfNoCorstone320 @common.parametrize("test_data", test_data_conv2d_INT) def test_convolution_2d_u85_INT_depthwise(test_data): model, per_channel_quantization = test_data() @@ -293,7 +358,23 @@ def test_convolution_2d_u85_INT_depthwise(test_data): pipeline.run() -@common.XfailIfNoCorstone320 # TODO: MLETORCH-516 +@common.XfailIfNoCorstone320 +@common.parametrize("test_data", test_data_conv2d_INT) +def test_convolution_2d_u85_a8w4_depthwise(test_data): + model, per_channel_quantization = test_data() + pipeline = EthosU85PipelineINT[input_t]( + model, + model.get_inputs(), + aten_ops=[], + exir_ops=exir_op, + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.run() + + +@common.XfailIfNoCorstone320 @common.parametrize("test_data", test_data_conv1d_INT) def test_convolution_1d_u85_INT_depthwise(test_data): model, per_channel_quantization = test_data() @@ -305,3 +386,19 @@ def test_convolution_1d_u85_INT_depthwise(test_data): per_channel_quantization=per_channel_quantization, ) pipeline.run() + + +@common.XfailIfNoCorstone320 +@common.parametrize("test_data", test_data_conv1d_INT) +def test_convolution_1d_u85_a8w4_depthwise(test_data): + model, per_channel_quantization = test_data() + pipeline = EthosU85PipelineINT[input_t]( + model, + model.get_inputs(), + aten_ops=[], + exir_ops=exir_op, + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py index 77e512cdf2f..7e22ad304e4 100644 --- a/backends/arm/test/ops/test_linear.py +++ b/backends/arm/test/ops/test_linear.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -11,6 +11,7 @@ import torch from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_a16w8_quantization_config, + get_symmetric_a8w4_quantization_config, TOSAQuantizer, ) from executorch.backends.arm.test import common, conftest @@ -166,6 +167,35 @@ def test_linear_tosa_INT(test_data: torch.Tensor): pipeline.run() +@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT) +def test_linear_tosa_INT_a8w4(test_data: torch.Tensor): + test_data, out_features, has_bias, per_channel_quantization = test_data() + in_features = test_data.shape[-1] + pipeline = TosaPipelineINT[input_t1]( + Linear( + in_features=in_features, + out_features=out_features, + bias=has_bias, + ), + (test_data,), + aten_op, + tosa_extensions=["int4"], + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.add_stage_after( + "to_edge_transform_and_lower", + pipeline.tester.check_dtype_count, + { + "CONST": {"INT4": 2}, + "CONV2D": {"INT32": 1}, + "RESCALE": {"INT8": 1}, + }, + ) + pipeline.run() + + @common.parametrize("test_data", test_data_rank1_INT) @common.XfailIfNoCorstone300 def test_linear_u55_INT(test_data: torch.Tensor): diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index ca83c6c09ea..c11a046cd66 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -36,6 +36,7 @@ class TosaSpecialDtype(Enum): """Special TOSA dtypes not natively expressed in PyTorch.""" INT48 = ts.DType.INT48 + INT4 = ts.DType.INT4 def get_tosa_dtype(self) -> ts.DType: """Return the underlying ``ts.DType`` enumerant. @@ -56,6 +57,24 @@ def meta_key() -> str: """ return "tosa_special_dtype" + def max(self): + match self: + case self.INT4: + return 7 + case self.INT48: + return 2**47 - 1 + case _: + raise ValueError(f"Unrecognized TosaSpecialDtype {self}.") + + def min(self): + match self: + case self.INT4: + return -7 + case self.INT48: + return -(2**47) + case _: + raise ValueError(f"Unrecognized TosaSpecialDtype {self}.") + def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any: """Map a ``torch.dtype`` to a ``ts.DType``. @@ -180,6 +199,11 @@ def __process_node(self, argument: torch.fx.Node): else: self.multiple_output_names = [] + if not self.__validate(): + raise ValueError( + f"{self.tosa_spec} doesn't support tensor {self.__repr__()}" + ) + def __process_list(self, argument): """Capture a sequence argument as ``special``. @@ -198,6 +222,17 @@ def __process_number(self, argument: float | int): """ self.number: float | int = argument + def __validate(self) -> bool: + match getattr(self, "dtype", None): + case ts.DType.FP32: + if not self.tosa_spec.support_float(): + return False + case ts.DType.INT4: + if not self.tosa_spec.support_extension("int4"): + return False + + return True + def __init__( self, argument: Any, tosa_spec: Optional[TosaSpecification] = None ) -> None: From 4d157fc664ae84c89177e43fd53b5b24cee17e4b Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Wed, 14 Jan 2026 09:50:20 +0100 Subject: [PATCH 2/2] Arm backend: Fix test names Signed-off-by: Erik Lundell Change-Id: Id0cd35dead7558346424bba9e7712b36e90f5429 --- backends/arm/test/ops/test_conv2d.py | 4 ++-- backends/arm/test/ops/test_conv3d.py | 4 ++-- backends/arm/test/ops/test_depthwise_conv.py | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/backends/arm/test/ops/test_conv2d.py b/backends/arm/test/ops/test_conv2d.py index a8cd21058f9..55eee293f95 100644 --- a/backends/arm/test/ops/test_conv2d.py +++ b/backends/arm/test/ops/test_conv2d.py @@ -475,7 +475,7 @@ def test_convolution_2d_u55_INT(test_data): @common.parametrize("test_data", test_data_INT) -def test_convolution_2d_u55_a8w4(test_data): +def test_convolution_2d_u55_INT_a8w4(test_data): model, per_channel_quantization = test_data() pipeline = EthosU55PipelineINT[input_t]( model, @@ -504,7 +504,7 @@ def test_convolution_u85_INT(test_data): @common.parametrize("test_data", test_data_INT) -def test_convolution_2d_u85_a8w4(test_data): +def test_convolution_2d_u85_INT_a8w4(test_data): model, per_channel_quantization = test_data() pipeline = EthosU85PipelineINT[input_t]( model, diff --git a/backends/arm/test/ops/test_conv3d.py b/backends/arm/test/ops/test_conv3d.py index e020ea0c5ac..f28315dcdae 100644 --- a/backends/arm/test/ops/test_conv3d.py +++ b/backends/arm/test/ops/test_conv3d.py @@ -577,7 +577,7 @@ def test_convolution_3d_u55_INT(test_data): @common.parametrize("test_data", test_data_INT) @pytest.mark.skip(reason="Ethos-U55 does not support CONV3D yet.") -def test_convolution_3d_u55_a8w4(test_data): +def test_convolution_3d_u55_INT_a8w4(test_data): model, per_channel_quantization = test_data() pipeline = EthosU55PipelineINT[input_t]( model, @@ -607,7 +607,7 @@ def test_convolution_3d_u85_INT(test_data): @common.parametrize("test_data", test_data_INT) @pytest.mark.skip(reason="Ethos-U85 does not support CONV3D yet.") -def test_convolution_3d_u85_a8w4(test_data): +def test_convolution_3d_u85_INT_a8w4(test_data): model, per_channel_quantization = test_data() pipeline = EthosU85PipelineINT[input_t]( model, diff --git a/backends/arm/test/ops/test_depthwise_conv.py b/backends/arm/test/ops/test_depthwise_conv.py index 166724ef69b..b4289f922ce 100644 --- a/backends/arm/test/ops/test_depthwise_conv.py +++ b/backends/arm/test/ops/test_depthwise_conv.py @@ -300,7 +300,7 @@ def test_convolution_2d_u55_INT_depthwise(test_data): @common.XfailIfNoCorstone300 @common.parametrize("test_data", test_data_conv2d_INT) -def test_convolution_2d_u55_a8w4_depthwise(test_data): +def test_convolution_2d_u55_INT_a8w4_depthwise(test_data): model, per_channel_quantization = test_data() pipeline = EthosU55PipelineINT[input_t]( model, @@ -330,7 +330,7 @@ def test_convolution_1d_u55_INT_depthwise(test_data): @common.XfailIfNoCorstone300 @common.parametrize("test_data", test_data_conv1d_INT) -def test_convolution_1d_u55_a8w4_depthwise(test_data): +def test_convolution_1d_u55_INT_a8w4_depthwise(test_data): model, per_channel_quantization = test_data() pipeline = EthosU55PipelineINT[input_t]( model, @@ -360,7 +360,7 @@ def test_convolution_2d_u85_INT_depthwise(test_data): @common.XfailIfNoCorstone320 @common.parametrize("test_data", test_data_conv2d_INT) -def test_convolution_2d_u85_a8w4_depthwise(test_data): +def test_convolution_2d_u85_INT_a8w4_depthwise(test_data): model, per_channel_quantization = test_data() pipeline = EthosU85PipelineINT[input_t]( model, @@ -390,7 +390,7 @@ def test_convolution_1d_u85_INT_depthwise(test_data): @common.XfailIfNoCorstone320 @common.parametrize("test_data", test_data_conv1d_INT) -def test_convolution_1d_u85_a8w4_depthwise(test_data): +def test_convolution_1d_u85_INT_a8w4_depthwise(test_data): model, per_channel_quantization = test_data() pipeline = EthosU85PipelineINT[input_t]( model,