Arm backend: Decompose conv2d with 16 bit activation
Support quantization to 16a8w. Since the resulting TOSA operator
needs to have the bias in int48, which isn't available as a type in
torch, the conv2d needs to be decomposed into a conv + add, where
the conv result is scaled down to 32 bits before the bias is added.

Signed-off-by: Per Åstrand <per.astrand@arm.com>
Change-Id: Ib8cae694035796374a55a9909e501596e983abf5
per committed Sep 22, 2025
commit 60033f9ff59ed1c8961b58bc81b06247782b44e0
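For context, a minimal standalone sketch of the rescale-and-add arithmetic the commit message describes (the function name and scale handling here are illustrative only, not code from this PR; conv_q is the convolution result quantized at conv_scale, bias_q the bias quantized at bias_scale, both int16):

def rescaled_bias_add(conv_q: int, bias_q: int,
                      conv_scale: float, bias_scale: float) -> int:
    # Bring both operands to a common scale, shifted 14 bits up into int32.
    common = max(bias_scale, conv_scale)
    shift = 14  # leaves one bit of add headroom plus the sign bit in int32
    conv32 = round(conv_q * (conv_scale / common) * (1 << shift))
    bias32 = round(bias_q * (bias_scale / common) * (1 << shift))
    sum32 = conv32 + bias32  # fits in int32 by construction
    # Rescale the sum back into the int16 output quantization domain.
    out = round(sum32 * common / (conv_scale * (1 << shift)))
    return max(-32768, min(32767, out))

When bias_scale equals conv_scale (the common case, since the quantizer change below reuses the output activation spec for the bias), both rescale factors are simply 2**14 and the scheme reduces to a plain shifted add.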
3 changes: 3 additions & 0 deletions backends/arm/_passes/__init__.py
@@ -46,6 +46,9 @@
from .decompose_glu_pass import DecomposeGluPass # noqa
from .decompose_grouped_conv import DecomposeGroupedConv # noqa
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
from .decompose_int16_activation_conv2d_pass import ( # noqa
DecomposeConv2dWithInt16ActivationPass,
)
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa
9 changes: 8 additions & 1 deletion backends/arm/_passes/arm_pass_manager.py
@@ -42,6 +42,7 @@
DecomposeAtanPass,
DecomposeAvgPool2d,
DecomposeBatchNormNoStatsPass,
DecomposeConv2dWithInt16ActivationPass,
DecomposeCoshPass,
DecomposeCosineSimilarityPass,
DecomposeCumsumPass,
@@ -183,6 +184,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(ComputeConstantOpsAOT(exported_program))

self.add_pass(DecomposeGroupedConv())

self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
@@ -196,9 +198,14 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:

self.add_pass(FuseViewCopyTransform())
self.add_pass(FuseConstantArgsPass(exported_program))
self.add_pass(InsertTableOpsPass(exported_program))
# If we have a conv2d with int16 activation, split it up into a convolution
# and an addition to work around the lack of support for int48 in torch.
Contributor @digantdesai commented Sep 15, 2025:

> and an addition, to work-around the lack of support for int48 in torch

Or can it be done by using torch.dtype.int64 instead and then detecting and lowering it as int48 downstream?

Collaborator (Author) replied:

I was starting off in that direction, but it interferes a bit with the int64->int32 handling, so I'd rather keep it separate.

Contributor replied:

yeah, given int64 is treated as radioactive :P

# This needs to happen before AddBiasPass, but after the table ops are
# inserted, to be able to validate that conv2d has the right dtype arguments.
self.add_pass(DecomposeConv2dWithInt16ActivationPass())
self.add_pass(AddBiasPass(exported_program))

self.add_pass(InsertTableOpsPass(exported_program))
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
self.add_pass(ToTosaMemoryFormatPass(exported_program))
self.add_pass(RemoveNoopPass())
145 changes: 145 additions & 0 deletions backends/arm/_passes/decompose_int16_activation_conv2d_pass.py
@@ -0,0 +1,145 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast

import torch
from executorch.backends.arm._passes.quant_args import QuantArgs

from executorch.backends.arm.tosa.specification import get_context_spec, Tosa_1_00
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class DecomposeConv2dWithInt16ActivationPass(ExportPass):
"""
This pass decomposes a convolution with input dtype int16 and bias
into a convolution without bias followed by an addition of the bias
since the TOSA op requires the bias to be int48 which is hard to represent
in torch. Instead rescale the int48 output to int16 and add the bias in int16.
"""

def call_operator(self, op, args, kwargs, meta):
if op != exir_ops.edge.aten.convolution.default:
return super().call_operator(op, args, kwargs, meta)

tosa_spec = get_context_spec()
if not tosa_spec.support_integer():
return super().call_operator(op, args, kwargs, meta)

# return if no bias
if args[2] is None:
return super().call_operator(op, args, kwargs, meta)

if args[0].data.dtype == torch.int8:
return super().call_operator(op, args, kwargs, meta)
elif args[0].data.dtype == torch.int16:
if isinstance(tosa_spec, Tosa_1_00) and not tosa_spec.support_extension(
"int16"
):
raise ValueError(
"int16 activation for convolution requires TOSA int16 extension"
)
else:
raise NotImplementedError(
"Decomposition to conv+add only implemented for activation of int16 type"
)

# Convolution with bias where the activation is int16.
# The bias is assumed to be quantized with the same quantization parameters
# as the output of the convolution.
bias = args[2]
assert (
meta.data["output_qparams"][0].dtype == bias.data.dtype
), "Bias needs to have same type as quantized output type"
no_bias_args = list(args)
no_bias_args[2] = None
# split up to convolution + bias
convolution = super().call_operator(op, tuple(no_bias_args), kwargs, meta)

# create a copy of the meta without the qparams, to be used with the new nodes
new_meta = meta.copy()
new_meta.data.pop("output_qparams", None)
new_meta.data.pop("input_qparams", None)

# reshape the tensor to the same rank as the convolution output to add the bias to the channels
channel_bias = super().call_operator(
exir_ops.edge.aten.view_copy.default,
(bias, [1, len(bias.data), 1, 1]),
{},
new_meta,
)
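# (a bias of shape [C] becomes [1, C, 1, 1] so that it broadcasts over
# N, H and W of the NCHW convolution output)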

output_dtype = meta.data["output_qparams"][0].dtype

if output_dtype == torch.int16:
# The conv output, which is int48 in TOSA, will be rescaled to int32 in the
# serialization step. To be able to add the bias we first rescale the conv
# output to int32. The resulting int32 sum then needs to be rescaled back
# to the output dtype.

# calculate common rescale factor from convolution output and bias quantization
output_qparams = cast(QuantArgs, meta.data["output_qparams"][0])
conv_output_scale = output_qparams.scale
bias_qparams = cast(QuantArgs, meta.data["input_qparams"][2])
bias_scale = bias_qparams.scale

common_scale = max(bias_scale, conv_output_scale)

# calculate how we can rescale bias and conv to a common scale and maximize the output range
bias_rescale_factor = bias_scale / common_scale
conv_rescale_factor = conv_output_scale / common_scale

# Either the conv output or the bias now covers the full int16 range, and the
# other one a smaller range. Since we are upscaling to int32 we have 16
# additional bits to work with to maximize the output range. Worst case, both
# the bias and the conv output cover the full int16 range, so we leave one bit
# of headroom for the addition and one for the sign bit.
bits_left_to_shift = 14

# update rescale factors
bias_rescale_factor *= 1 << bits_left_to_shift
conv_rescale_factor *= 1 << bits_left_to_shift
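# Worked example with illustrative (hypothetical) scales: bias_scale = 2.0e-4
# and conv_output_scale = 0.5e-4 give common_scale = 2.0e-4, so the factors
# become 1.0 and 0.25, or 16384 and 4096 after the shift; a full-range int16
# value then maps to at most ~2**29, leaving headroom for the int32 addition.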

conv_output = super().call_operator(
exir_ops.backend.tosa.RESCALE.default,
(convolution, torch.int32, conv_rescale_factor, 0, 0),
{},
new_meta,
)

bias_rescaled = super().call_operator(
exir_ops.backend.tosa.RESCALE.default,
(channel_bias, torch.int32, bias_rescale_factor, 0, 0),
{},
new_meta,
)

add = super().call_operator(
exir_ops.edge.aten.add.Tensor,
(conv_output, bias_rescaled),
{},
new_meta,
)
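# The int32 sum is now at scale common_scale / 2**14; the final RESCALE below
# multiplies by common_scale / (conv_output_scale * 2**14) to land back in
# the int16 output quantization domain.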

res_rescale = super().call_operator(
exir_ops.backend.tosa.RESCALE.default,
(
add,
output_dtype,
(common_scale / (conv_output_scale * (1 << bits_left_to_shift))),
0,
0,
),
{},
new_meta,
)

else:
raise NotImplementedError(
f"Decomposition to conv+add only implemented for activation of int16 type, not for {output_dtype}"
)

return res_rescale
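In graph terms, the rewrite performed by this pass looks roughly like the following (a schematic sketch with dtypes annotated, not literal IR):

# before: conv2d(x_i16, w_i8, bias_i16) -> y_i16
# after:  conv2d(x_i16, w_i8, bias=None) -> acc      # int48, int32 after serialization
#         RESCALE(acc) -> acc_i32                    # conv_rescale_factor
#         RESCALE(view_copy(bias, [1, C, 1, 1])) -> b_i32  # bias_rescale_factor
#         add(acc_i32, b_i32) -> sum_i32
#         RESCALE(sum_i32) -> y_i16                  # back to the output scale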
63 changes: 41 additions & 22 deletions backends/arm/quantizer/quantization_config.py
@@ -89,29 +89,48 @@ def _derive_qparams_fn(
torch.ops.aten.linear.default,
torch.ops.aten.conv2d.padding,
]:
input_act = node.args[0]
weight = node.args[1]
# If the weights are quantized per_tensor, do the same with bias
qscheme = (
torch.per_tensor_symmetric
if self.weight is None
else self.weight.qscheme
)
ch_axis = None
if self.weight is not None:
if qscheme == torch.per_channel_symmetric:
ch_axis = self.weight.ch_axis
if self.input_activation is None or self.weight is None:
raise ValueError(
"Input activation and weight QuantizationConfig must be specified."
)
if self.input_activation.dtype == self.weight.dtype == torch.int8:
# This is the default int8 quantization, which uses the derived quantization
# parameters calculated from the activation and weight scales
input_act = node.args[0]
weight = node.args[1]

quantization_spec = DerivedQuantizationSpec(
derived_from=[(input_act, node), (weight, node)], # type: ignore[list-item]
derive_qparams_fn=_derive_qparams_fn,
dtype=torch.int32,
quant_min=torch.iinfo(torch.int32).min,
quant_max=torch.iinfo(torch.int32).max - 1,
qscheme=qscheme,
ch_axis=ch_axis,
)
return quantization_spec # type: ignore[return-value]
# If the weights are quantized per_tensor, do the same with bias
qscheme = (
torch.per_tensor_symmetric
if self.weight is None
else self.weight.qscheme
)
ch_axis = None
if self.weight is not None:
if qscheme == torch.per_channel_symmetric:
ch_axis = self.weight.ch_axis

quantization_spec = DerivedQuantizationSpec(
derived_from=[(input_act, node), (weight, node)], # type: ignore[list-item]
derive_qparams_fn=_derive_qparams_fn,
dtype=torch.int32,
quant_min=torch.iinfo(torch.int32).min,
quant_max=torch.iinfo(torch.int32).max - 1,
qscheme=qscheme,
ch_axis=ch_axis,
)
return quantization_spec # type: ignore[return-value]
elif (
self.input_activation.dtype == torch.int16
and self.weight.dtype == torch.int8
):
# In case the activation is quantized to int16, the bias needs to be
# added after the convolution, so use the output quantization for this case.
return self.output_activation
else:
raise NotImplementedError(
f"Bias quantization of types: i:{self.input_activation.dtype}, w:{self.weight.dtype} not implemented"
)

if self.bias is None:
return None
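For clarity, a condensed restatement of the bias-spec dispatch introduced above (a hypothetical standalone sketch; the function name and string labels are illustrative, not part of this PR):

import torch

def pick_bias_strategy(input_dtype: torch.dtype, weight_dtype: torch.dtype) -> str:
    # Default 8-bit path: bias gets an int32 spec derived from input/weight scales.
    if input_dtype == weight_dtype == torch.int8:
        return "derived-int32"
    # 16a8w path: bias is added after the conv, so reuse the output activation spec.
    if input_dtype == torch.int16 and weight_dtype == torch.int8:
        return "output-activation"
    raise NotImplementedError(
        f"Bias quantization of types: i:{input_dtype}, w:{weight_dtype} not implemented"
    )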