From 79cd4d54975f2fd178b9e9e32dc0df7a5e36a375 Mon Sep 17 00:00:00 2001
From: Agrima Khare
Date: Wed, 19 Nov 2025 13:50:51 +0000
Subject: [PATCH] Arm Backend: Add support for select_scatter.default

Signed-off-by: Agrima Khare
Change-Id: Ib2ee27d940bef495d88b32b084f8bcaece78a09b
---
 backends/arm/_passes/__init__.py              |   1 +
 backends/arm/_passes/arm_pass_manager.py      |   2 +
 .../_passes/decompose_select_scatter_pass.py  | 143 +++++++++++++++
 backends/arm/test/ops/test_select_scatter.py  | 173 ++++++++++++++++++
 4 files changed, 319 insertions(+)
 create mode 100644 backends/arm/_passes/decompose_select_scatter_pass.py
 create mode 100644 backends/arm/test/ops/test_select_scatter.py

diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index 08b840de045..5de0d07344d 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -67,6 +67,7 @@
 from .decompose_round_pass import DecomposeRoundPass  # noqa
 from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass  # noqa
 from .decompose_select import DecomposeSelectPass  # noqa
+from .decompose_select_scatter_pass import DecomposeSelectScatterPass  # noqa
 from .decompose_sign_pass import DecomposeSignPass  # noqa
 from .decompose_silu_pass import DecomposeSiluPass  # noqa
 from .decompose_sinh_pass import DecomposeSinhPass  # noqa
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 2da33bfd41d..f13f88f538b 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -70,6 +70,7 @@
     DecomposeRoundPass,
     DecomposeScaledDotProductAttentionPass,
     DecomposeSelectPass,
+    DecomposeSelectScatterPass,
     DecomposeSignPass,
     DecomposeSiluPass,
     DecomposeSinhPass,
@@ -325,6 +326,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         # Transformation passes (pre scalar -> tensor)
         self.add_passes(
             [
+                DecomposeSelectScatterPass(),
                 ConvertInt64ConstOpsToInt32Pass(),
                 ConvertInt64OutputOpsToInt32Pass(),
                 InsertInt32CastsAfterInt64PlaceholdersPass(),
diff --git a/backends/arm/_passes/decompose_select_scatter_pass.py b/backends/arm/_passes/decompose_select_scatter_pass.py
new file mode 100644
index 00000000000..f3c7ae5955b
--- /dev/null
+++ b/backends/arm/_passes/decompose_select_scatter_pass.py
@@ -0,0 +1,143 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Set, Type
+
+import torch
+
+from executorch.backends.arm._passes import ArmPass
+from executorch.backends.arm._passes.convert_int64_const_ops_to_int32 import (
+    ConvertInt64ConstOpsToInt32Pass,
+)
+from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
+    ReplaceScalarWithTensorByProfilePass,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+edge_scatter_ops = (exir_ops.edge.aten.select_scatter.default,)
+aten_scatter_ops = (torch.ops.aten.select_scatter.default,)
+
+
+def get_select_scatter_decomposition(op) -> tuple:
+    if op in edge_scatter_ops:
+        return (
+            exir_ops.edge.aten.arange.start_step,
+            exir_ops.edge.aten.eq.Scalar,
+            exir_ops.edge.aten.where.self,
+            exir_ops.edge.aten.expand_copy.default,
+            exir_ops.edge.aten.unsqueeze_copy.default,
+            exir_ops.edge.aten.view_copy.default,
+        )
+    if op in aten_scatter_ops:
+        return (
+            torch.ops.aten.arange.start_step,
+            torch.ops.aten.eq.Scalar,
+            torch.ops.aten.where.self,
+            torch.ops.aten.expand_copy.default,
+            torch.ops.aten.unsqueeze_copy.default,
+            torch.ops.aten.view_copy.default,
+        )
+
+    raise RuntimeError(f"Can't get select_scatter decomposition for op {op}")
+
+
+class DecomposeSelectScatterPass(ArmPass):
+    """select_scatter is decomposed into other ops during export, however this is only
+    supported for the fp profile and for the int profile we need to decompose it here.
+
+    The decomposition is as follows:
+    - Build a boolean mask the size of x
+        eq(view(arange(0, dim_size), mask_shape), index)
+    - Broadcast source to x
+        expand(unsqueeze(source, dim), shape)
+    - Route the updated slice while keeping the untouched lanes
+        where(mask, expanded_source, x)
+
+    This reflects the decomposition for the fp profile implemented in torch._refs
+    """
+
+    _passes_required_after: Set[Type[ExportPass]] = {
+        ReplaceScalarWithTensorByProfilePass,
+        ConvertInt64ConstOpsToInt32Pass,
+    }
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in (edge_scatter_ops + aten_scatter_ops):
+            return super().call_operator(op, args, kwargs, meta, updated=False)
+
+        (
+            arange_op,
+            eq_op,
+            where_op,
+            expand_op,
+            unsqueeze_op,
+            view_op,
+        ) = get_select_scatter_decomposition(op)
+
+        input_tensor = args[0]
+        src_tensor = args[1]
+        dim = int(args[2])
+        index = int(args[3])
+
+        shape = input_tensor.data.size()
+        rank = len(shape)
+        dim = dim % rank if dim < 0 else dim
+        dim_size = shape[dim]
+        if index < 0:
+            index = index + dim_size
+
+        mask_shape = [1] * rank
+        mask_shape[dim] = -1
+
+        arange_node = super().call_operator(
+            arange_op,
+            (0, dim_size, 1),
+            {},
+            meta,
+            updated=False,
+        )
+
+        view_node = super().call_operator(
+            view_op,
+            (arange_node, mask_shape),
+            {},
+            meta,
+            updated=False,
+        )
+
+        mask_node = super().call_operator(
+            eq_op,
+            (view_node, index),
+            {},
+            meta,
+            updated=False,
+        )
+
+        unsqueeze_node = super().call_operator(
+            unsqueeze_op,
+            (src_tensor, dim),
+            {},
+            meta,
+            updated=False,
+        )
+
+        expand_node = super().call_operator(
+            expand_op,
+            (unsqueeze_node, shape),
+            {},
+            meta,
+            updated=False,
+        )
+
+        where_node = super().call_operator(
+            where_op,
+            (mask_node, expand_node, input_tensor),
+            {},
+            meta,
+            updated=True,
+        )
+
+        return where_node
diff --git a/backends/arm/test/ops/test_select_scatter.py b/backends/arm/test/ops/test_select_scatter.py
new file mode 100644
index 00000000000..94bfc518b22
--- /dev/null
+++ b/backends/arm/test/ops/test_select_scatter.py
@@ -0,0 +1,173 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU85PipelineINT,
+    OpNotSupportedPipeline,
+    TosaPipelineFP,
+    TosaPipelineINT,
+    VgfPipeline,
+)
+
+test_data_suite = {
+    "rank2_rand": lambda: (
+        torch.randint(-30, 30, (5, 9), dtype=torch.float32),
+        torch.randint(0, 9, (9,), dtype=torch.float32),
+        0,
+        2,
+    ),
+    "rank2_zeros": lambda: (
+        torch.rand((3, 2), dtype=torch.float32),
+        torch.randint(0, 4, (2,), dtype=torch.float32),
+        0,
+        0,
+    ),
+    "rank3_rand": lambda: (
+        torch.rand((2, 4, 5), dtype=torch.float32),
+        torch.randint(-5, 5, (2, 5), dtype=torch.float32),
+        1,
+        0,
+    ),
+    "rank3_ones": lambda: (
+        torch.ones((2, 3, 3), dtype=torch.float32),
+        torch.rand((2, 3), dtype=torch.float32),
+        2,
+        2,
+    ),
+    "rank4_rand": lambda: (
+        torch.rand((1, 2, 4, 5), dtype=torch.float32),
+        torch.rand((2, 4, 5), dtype=torch.float32),
+        0,
+        0,
+    ),
+    "rank4_ones": lambda: (
+        torch.ones((2, 3, 3, 2), dtype=torch.float32),
+        torch.randint(-5, 5, (2, 3, 2), dtype=torch.float32),
+        2,
+        -1,
+    ),
+    "rank5_ones": lambda: (
+        torch.ones((3, 4, 20, 9, 5), dtype=torch.float32),
+        torch.randn((3, 4, 20, 9), dtype=torch.float32),
+        4,
+        1,
+    ),
+    "rank6_rand": lambda: (
+        torch.rand((1, 2, 3, 4, 2, 1), dtype=torch.float32),
+        torch.randn((2, 3, 4, 2, 1), dtype=torch.float32),
+        0,
+        0,
+    ),
+}
+
+
+class SelectScatter(torch.nn.Module):
+    fp_aten_op = "torch.ops.aten.select_scatter.default"
+    int_aten_ops = [
+        "torch.ops.aten.arange.start_step",
+        "torch.ops.aten.view_copy.default",
+        "torch.ops.aten.unsqueeze_copy.default",
+        "torch.ops.aten.expand_copy.default",
+        "torch.ops.aten.where.self",
+        "torch.ops.aten.eq.Tensor",
+    ]
+    fp_exir_op = ["executorch_exir_dialects_edge__ops_aten_select_scatter_default"]
+    int_exir_ops = [
+        "executorch_exir_dialects_edge__ops_aten_eq_Tensor",
+        "executorch_exir_dialects_edge__ops_aten_where_self",
+        "executorch_exir_dialects_edge__ops_aten_arange_start_step",
+        "executorch_exir_dialects_edge__ops_aten_view_copy_default",
+        "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default",
+        "executorch_exir_dialects_edge__ops_aten_expand_copy_default",
+    ]
+    u55_not_supported = {
+        "executorch_exir_dialects_edge__ops_aten_eq_Tensor": 1,
+        "executorch_exir_dialects_edge__ops_aten_where_self": 1,
+    }
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor, dim: int, index: int):
+        return x.select_scatter(y, dim, index)
+
+
+input_t = Tuple[torch.Tensor, torch.Tensor, int, int]
+
+
+@common.parametrize("test_module", test_data_suite)
+def test_select_scatter_tosa_FP(test_module: input_t):
+    pipeline = TosaPipelineFP[input_t](
+        SelectScatter(),
+        test_module(),
+        aten_op=SelectScatter.fp_aten_op,
+        exir_op=SelectScatter.fp_exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_data_suite)
+def test_select_scatter_tosa_INT(test_module: input_t):
+    pipeline = TosaPipelineINT[input_t](
+        SelectScatter(),
+        test_module(),
+        aten_op=SelectScatter.int_aten_ops,
+        exir_op=SelectScatter.int_exir_ops,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_data_suite)
+def test_select_scatter_u55_INT(test_module: input_t):
+    # select_scatter is not supported on U55
+    pipeline = OpNotSupportedPipeline[input_t](
+        SelectScatter(),
+        test_module(),
+        SelectScatter.u55_not_supported,
+        quantize=True,
+        u55_subset=True,
+        n_expected_delegates=1,
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone320
+@common.parametrize("test_module", test_data_suite)
+def test_select_scatter_u85_INT(test_module: input_t):
+    pipeline = EthosU85PipelineINT[input_t](
+        SelectScatter(),
+        test_module(),
+        aten_ops=SelectScatter.int_aten_ops,
+        exir_ops=SelectScatter.int_exir_ops,
+    )
+    pipeline.run()
+
+
+@common.SkipIfNoModelConverter
+@common.parametrize("test_module", test_data_suite)
+def test_select_scatter_vgf_FP(test_module: input_t):
+    pipeline = VgfPipeline[input_t](
+        SelectScatter(),
+        test_module(),
+        aten_op=SelectScatter.fp_aten_op,
+        exir_op=SelectScatter.fp_exir_op,
+        tosa_version="TOSA-1.0+FP",
+    )
+    pipeline.run()
+
+
+@common.SkipIfNoModelConverter
+@common.parametrize("test_module", test_data_suite)
+def test_select_scatter_vgf_INT(test_module: input_t):
+    pipeline = VgfPipeline[input_t](
+        SelectScatter(),
+        test_module(),
+        aten_op=SelectScatter.int_aten_ops,
+        exir_op=SelectScatter.int_exir_ops,
+        tosa_version="TOSA-1.0+INT",
+    )
+    pipeline.run()