Commit a9d8230
Arm backend: Support a8w4 for convolution and linear. (#16577)
Int4 is added as a TosaSpecialDtype. Since torch.dtype has no int4, the dtype is carried in an int8 tensor quantized with qmin=-7, qmax=7. That range is then detected when folding the qparams, where the tosa_special_dtype meta is added. Some additional changes are needed to make sure the meta survives all the way to TOSA. Tests are added for conv2d, conv3d, depthwise conv, and linear. One conv2d test case was actually a depthwise conv, so it was modified.

Signed-off-by: Erik Lundell <erik.lundell@arm.com>
1 parent cccf977 commit a9d8230

12 files changed

Lines changed: 365 additions & 34 deletions
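
As a usage sketch (not part of the diff): the a8w4 config plugs into the existing symmetric-config machinery by narrowing the weight range to [-7, 7] while activations stay int8. The snippet below assumes a TOSAQuantizer built from a TOSA spec string with the int4 extension; the quantizer construction and spec string are illustrative assumptions, while get_symmetric_a8w4_quantization_config is the function added by this commit.

from executorch.backends.arm.quantizer.arm_quantizer import (
    TOSAQuantizer,
    get_symmetric_a8w4_quantization_config,
)
from executorch.backends.arm.tosa.specification import TosaSpecification

# Assumed spec string; the Ethos-U compile spec in this commit appends "+int4"
# to its base TOSA version ("TOSA-1.0+INT+int16+int4").
tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT+int4")

# Quantizer construction is illustrative; set_global applies the a8w4 config,
# i.e. int8 activations and weights quantized with qmin=-7 / qmax=7 in an int8
# container, which the fold pass later tags as TosaSpecialDtype.INT4.
quantizer = TOSAQuantizer(tosa_spec)
quantizer.set_global(get_symmetric_a8w4_quantization_config(is_per_channel=True))

The new conv2d tests below apply the same config through pipeline.quantizer.set_global(...).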

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 14 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
+# Copyright 2024-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -24,6 +24,7 @@
 from executorch.backends.arm._passes.remove_noop_pass import RemoveNoopPass
 from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
 from executorch.backends.arm.constants import DQ_OPS, Q_OPS
+from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
 from executorch.exir import ExportedProgram

 from executorch.exir.dialects._ops import ops as exir_ops
@@ -32,6 +33,13 @@
 from torch.fx import GraphModule, Node


+def _get_special_dtype(qspec: QuantArgs) -> TosaSpecialDtype | None:
+    if qspec.dtype == torch.int8:
+        if qspec.qmax == 7 and qspec.qmin == -7:
+            return TosaSpecialDtype.INT4
+    return None
+
+
 def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
     """
     Get the input quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
@@ -157,6 +165,11 @@ def fold_and_annotate_arg(
                 node.replace_input_with(n, cast(Node, n.args[0]))
                 if len(n.users) == 0:
                     graph_module.graph.erase_node(n)
+                special_dtype = _get_special_dtype(input_qparams)
+                if special_dtype:
+                    node.all_input_nodes[i].meta[
+                        TosaSpecialDtype.meta_key()
+                    ] = special_dtype

     def _handle_control_flow_node(self, node: Node, graph_module: GraphModule):
         """Fold outmost quant nodes inside submodule.

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 21 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -18,6 +18,7 @@
 from executorch.backends.arm._passes.fuse_equal_placeholders_pass import (
     FuseEqualPlaceholdersPass,
 )
+from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
 from executorch.backends.transforms.utils import (
     create_constant_placeholder,
     delete_constant_placeholder,
@@ -52,6 +53,23 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.exported_program = exported_program

+    def _propagate_special_dtype(self, from_nodes, to_node, data):
+        """Propagate special dtype meta if it exists."""
+        special_dtypes = set()
+        for input_node in from_nodes:
+            special_type = input_node.meta.get(TosaSpecialDtype.meta_key(), None)
+            if special_type:
+                special_dtypes.add(special_type)
+        if len(special_dtypes) > 1:
+            logger.warning(
+                "Propagating mixed special dtypes is not implemented, skipping."
+            )
+        elif len(special_dtypes) == 1:
+            special_dtype = list(special_dtypes)[0]
+            # Make sure data is still within special dtype range.
+            if data.abs().max() <= special_dtype.max():
+                to_node.meta[TosaSpecialDtype.meta_key()] = special_dtype
+
     def _fuse_nodes(self, node) -> bool:
         """
         Takes a node with only parameter inputs and replaces it with one constant tensor node with
@@ -105,6 +123,8 @@ def resolve_arg(arg):
             persistent_buffer=persistent_buffer,
         )

+        self._propagate_special_dtype(input_nodes, const_node, data)
+
         node.replace_all_uses_with(const_node)

         return True

backends/arm/_passes/fuse_equal_placeholders_pass.py

Lines changed: 3 additions & 3 deletions
@@ -1,4 +1,4 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -53,11 +53,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:

             # ensure we don't merge any special case int48_t tensors with int32_t tensors
             # since int48_t tensors needs to be instantiated separately.
-            is_int48 = node.meta.get(TosaSpecialDtype.meta_key(), None)
+            is_special_dtype = node.meta.get(TosaSpecialDtype.meta_key(), None)
             t_cpu = tensor.detach().cpu().contiguous()
             data_bytes = t_cpu.numpy().tobytes()
             key = (
-                is_int48,
+                is_special_dtype,
                 str(t_cpu.dtype),
                 tuple(t_cpu.shape),
                 hashlib.sha1(data_bytes, usedforsecurity=False).hexdigest(),

backends/arm/ethosu/compile_spec.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -73,7 +73,7 @@ def __init__(
             compiler_flags.append(f"--memory-mode={memory_mode}")

         # Set TOSA version.
-        base_tosa_version = "TOSA-1.0+INT+int16"
+        base_tosa_version = "TOSA-1.0+INT+int16+int4"
         if "u55" in target_lower:
             # Add the Ethos-U55 extension marker
             base_tosa_version += "+u55"

backends/arm/operators/ops_identity.py

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -53,6 +53,8 @@ def define_node(
             supported_dtypes += [ts.DType.FP32]
         if self.tosa_spec.support_extension("int16"):
             supported_dtypes += [ts.DType.INT48]
+        if self.tosa_spec.support_extension("int4"):
+            supported_dtypes += [ts.DType.INT4]
         validate_valid_dtype(
             self.target,
             [inputs[0], output],

backends/arm/process_node.py

Lines changed: 3 additions & 14 deletions
@@ -1,4 +1,4 @@
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
+# Copyright 2024-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -12,7 +12,7 @@
 import torch.fx
 import tosa_serializer as ts
 from executorch.backends.arm.operators.node_visitor import NodeVisitor
-from executorch.backends.arm.tosa.mapping import TosaArg, TosaSpecialDtype
+from executorch.backends.arm.tosa.mapping import TosaArg
 from executorch.backends.arm.tosa.specification import TosaSpecification
 from executorch.backends.arm.tosa.utils import tosa_shape
 from torch._export.utils import (
@@ -116,21 +116,10 @@ def process_inputs_to_parameters(
     )
     parameter_values = parameter_data.detach().numpy()

-    if tosa_arg.dtype == torch.float32:
-        if not tosa_spec.support_float():
-            raise ValueError(f"{tosa_spec} doesn't support float operations")
-
-    # Handle special case for INT48 tensors
-    special_type = node.meta.get(TosaSpecialDtype.meta_key(), None)
-    if isinstance(special_type, TosaSpecialDtype):
-        tosa_dtype = special_type.get_tosa_dtype()
-    else:
-        tosa_dtype = tosa_arg.dtype
-
     parameter_values = np.transpose(parameter_values, tosa_arg.dim_order)

     tosa_graph.addConst(
-        parameter_values.shape, tosa_dtype, parameter_values, name=tosa_arg.name
+        parameter_values.shape, tosa_arg.dtype, parameter_values, name=tosa_arg.name
     )

backends/arm/quantizer/arm_quantizer.py

Lines changed: 8 additions & 0 deletions
@@ -180,6 +180,14 @@ def get_symmetric_quantization_config(
     return quantization_config


+def get_symmetric_a8w4_quantization_config(
+    is_per_channel: bool = True, is_qat: bool = True, is_dynamic: bool = False
+):
+    return get_symmetric_quantization_config(
+        is_per_channel, is_qat, is_dynamic, weight_qmin=-7, weight_qmax=7
+    )
+
+
 @functools.lru_cache
 def get_symmetric_a16w8_quantization_config(
     is_per_channel: bool = True,

backends/arm/test/ops/test_conv2d.py

Lines changed: 77 additions & 4 deletions
@@ -1,4 +1,4 @@
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
+# Copyright 2024-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -7,6 +7,9 @@
 from typing import List, Tuple, Union

 import torch
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a8w4_quantization_config,
+)
 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
@@ -17,6 +20,7 @@
     VgfPipeline,
 )

+
 aten_op = "torch.ops.aten.conv2d.default"
 exir_op = "executorch_exir_dialects_edge__ops_aten_convolution_default"

@@ -162,8 +166,8 @@ def forward(self, x):
     batches=1,
 )

-conv2d_2x2_1x1x14x13_st2 = Conv2d(
-    in_channels=1,
+conv2d_2x2_2x1x14x13_st2 = Conv2d(
+    in_channels=2,
     out_channels=1,
     kernel_size=(2, 2),
     stride=2,
@@ -363,7 +367,7 @@ def forward(self, x):
     "3x3_1x3x24x24_st1": lambda: conv2d_3x3_1x3x24x24_st1,
     "3x3_1x3x12x12_st2_pd1": lambda: conv2d_3x3_1x3x12x12_st2_pd1,
     "1x1_1x2x16x16_st1": lambda: conv2d_1x1_1x2x16x16_st1,
-    "2x2_1x1x14x13_st2_needs_adjust_pass": lambda: conv2d_2x2_1x1x14x13_st2,
+    "2x2_2x1x14x13_st2_needs_adjust_pass": lambda: conv2d_2x2_2x1x14x13_st2,
     "5x5_1x3x14x15_st3_pd1_needs_adjust_pass": lambda: conv2d_5x5_1x3x14x15_st3_pd1,
     "7x7_1x3x16x16_st2_pd1_dl2_needs_adjust_pass": lambda: conv2d_7x7_1x3x16x16_st2_pd1_dl2,
     "7x7_1x3x15x15_st1_pd0_dl1_needs_adjust_pass": lambda: conv2d_7x7_1x3x15x15_st1_pd0_dl1,
@@ -391,6 +395,15 @@ def forward(self, x):
 input_t = Tuple[torch.Tensor]


+def _get_dtype_count(model: torch.nn.Module):
+    nbr_convs: int = model.nbr_convs  # noqa
+    return {
+        "CONST": {"INT4": nbr_convs * 2},  # One for the weight, one for the zp.
+        "CONV2D": {"INT32": nbr_convs},
+        "RESCALE": {"INT8": nbr_convs},
+    }
+
+
 @common.parametrize("test_data", test_data_FP)
 def test_convolution_2d_tosa_FP(test_data):
     model = test_data()
@@ -417,6 +430,36 @@ def test_convolution_2d_tosa_INT(test_data):
     pipeline.run()


+@common.parametrize(
+    "test_data",
+    test_data_INT,
+    xfails={
+        "groups,per_channel_quant=True": "Int4 not supported for grouped convolutions. MLETORCH-1726",
+        "groups,per_channel_quant=False": "Int4 not supported for grouped convolutions. MLETORCH-1726",
+        "groups_bias,per_channel_quant=True": "Int4 not supported for grouped convolutions. MLETORCH-1726",
+        "groups_bias,per_channel_quant=False": "Int4 not supported for grouped convolutions. MLETORCH-1726",
+    },
+)
+def test_convolution_2d_tosa_INT_a8w4(test_data):
+    model, per_channel_quantization = test_data()
+    pipeline = TosaPipelineINT[input_t](
+        model,
+        model.get_inputs(),
+        aten_op,
+        exir_op,
+        tosa_extensions=["int4"],
+    )
+    pipeline.quantizer.set_global(
+        get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization)
+    )
+    pipeline.add_stage_after(
+        "to_edge_transform_and_lower",
+        pipeline.tester.check_dtype_count,
+        _get_dtype_count(model),
+    )
+    pipeline.run()
+
+
 @common.parametrize("test_data", test_data_INT)
 @common.XfailIfNoCorstone300
 def test_convolution_2d_u55_INT(test_data):
@@ -431,6 +474,21 @@ def test_convolution_2d_u55_INT(test_data):
     pipeline.run()


+@common.parametrize("test_data", test_data_INT)
+def test_convolution_2d_u55_INT_a8w4(test_data):
+    model, per_channel_quantization = test_data()
+    pipeline = EthosU55PipelineINT[input_t](
+        model,
+        model.get_inputs(),
+        aten_op,
+        exir_op,
+    )
+    pipeline.quantizer.set_global(
+        get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization)
+    )
+    pipeline.run()
+
+
 @common.parametrize("test_data", test_data_INT)
 @common.XfailIfNoCorstone320
 def test_convolution_u85_INT(test_data):
@@ -445,6 +503,21 @@ def test_convolution_u85_INT(test_data):
     pipeline.run()


+@common.parametrize("test_data", test_data_INT)
+def test_convolution_2d_u85_INT_a8w4(test_data):
+    model, per_channel_quantization = test_data()
+    pipeline = EthosU85PipelineINT[input_t](
+        model,
+        model.get_inputs(),
+        aten_op,
+        exir_op,
+    )
+    pipeline.quantizer.set_global(
+        get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization)
+    )
+    pipeline.run()
+
+
 @common.parametrize("test_data", test_data_FP)
 @common.SkipIfNoModelConverter
 def test_convolution_2d_vgf_no_quant(test_data):
