Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from .decompose_round_pass import DecomposeRoundPass # noqa
from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass # noqa
from .decompose_select import DecomposeSelectPass # noqa
from .decompose_select_scatter_pass import DecomposeSelectScatterPass # noqa
from .decompose_sign_pass import DecomposeSignPass # noqa
from .decompose_silu_pass import DecomposeSiluPass # noqa
from .decompose_sinh_pass import DecomposeSinhPass # noqa
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
DecomposeRoundPass,
DecomposeScaledDotProductAttentionPass,
DecomposeSelectPass,
DecomposeSelectScatterPass,
DecomposeSignPass,
DecomposeSiluPass,
DecomposeSinhPass,
Expand Down Expand Up @@ -330,6 +331,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
# Transformation passes (pre scalar -> tensor)
self.add_passes(
[
DecomposeSelectScatterPass(),
ConvertInt64ConstOpsToInt32Pass(),
ConvertInt64OutputOpsToInt32Pass(),
InsertInt32CastsAfterInt64PlaceholdersPass(),
Expand Down
143 changes: 143 additions & 0 deletions backends/arm/_passes/decompose_select_scatter_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

import torch

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.convert_int64_const_ops_to_int32 import (
ConvertInt64ConstOpsToInt32Pass,
)
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
ReplaceScalarWithTensorByProfilePass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

edge_scatter_ops = (exir_ops.edge.aten.select_scatter.default,)
aten_scatter_ops = (torch.ops.aten.select_scatter.default,)


def get_select_scatter_decomposition(op) -> tuple:
    """Return the six ops used to decompose ``select_scatter``.

    The ops are taken from the same dialect as *op* (edge or aten) and are
    returned in the order (arange, eq, where, expand, unsqueeze, view).

    Raises:
        RuntimeError: if *op* is not a recognized select_scatter overload.
    """
    if op in edge_scatter_ops:
        ns = exir_ops.edge.aten
    elif op in aten_scatter_ops:
        ns = torch.ops.aten
    else:
        raise RuntimeError(f"Can't get select_scatter decomposition for op {op}")

    # Both dialects expose the decomposition ops under the same attribute
    # paths, so the tuple can be built from a single shared namespace.
    return (
        ns.arange.start_step,
        ns.eq.Scalar,
        ns.where.self,
        ns.expand_copy.default,
        ns.unsqueeze_copy.default,
        ns.view_copy.default,
    )


class DecomposeSelectScatterPass(ArmPass):
    """Decompose select_scatter into arange/view/eq/unsqueeze/expand/where.

    select_scatter is decomposed into other ops during export, however this is
    only supported for the fp profile; for the int profile the decomposition
    happens here.

    The decomposition is as follows:
    - Build a boolean mask the size of x
        eq(view(arange(0, dim_size), mask_shape), index)
    - Broadcast source to x
        expand(unsqueeze(source, dim), shape)
    - Route the updated slice while keeping the untouched lanes
        where(mask, expanded_source, x)

    This reflects the decomposition for the fp profile implemented in
    torch._refs.
    """

    # eq is emitted with a scalar operand and arange with integer bounds, so
    # the scalar-replacement and int64-constant passes must still run after
    # this one.
    _passes_required_after: Set[Type[ExportPass]] = {
        ReplaceScalarWithTensorByProfilePass,
        ConvertInt64ConstOpsToInt32Pass,
    }

    def call_operator(self, op, args, kwargs, meta):
        # Anything that is not a select_scatter passes through untouched.
        if op not in (edge_scatter_ops + aten_scatter_ops):
            return super().call_operator(op, args, kwargs, meta, updated=False)

        (
            arange_op,
            eq_op,
            where_op,
            expand_op,
            unsqueeze_op,
            view_op,
        ) = get_select_scatter_decomposition(op)

        x, src = args[0], args[1]
        dim, index = int(args[2]), int(args[3])

        x_shape = x.data.size()
        rank = len(x_shape)
        # Normalize negative dim/index to their non-negative equivalents.
        if dim < 0:
            dim = dim % rank
        dim_size = x_shape[dim]
        if index < 0:
            index += dim_size

        # Shape that lines a dim_size-long vector up along `dim` only, so it
        # broadcasts against x in the eq/where below.
        bcast_shape = [1] * rank
        bcast_shape[dim] = -1

        emit = super().call_operator
        positions = emit(arange_op, (0, dim_size, 1), {}, meta, updated=False)
        positions = emit(view_op, (positions, bcast_shape), {}, meta, updated=False)
        # mask is True exactly on the selected slice.
        mask = emit(eq_op, (positions, index), {}, meta, updated=False)
        lifted_src = emit(unsqueeze_op, (src, dim), {}, meta, updated=False)
        full_src = emit(expand_op, (lifted_src, x_shape), {}, meta, updated=False)
        # Pick the broadcast source on the selected slice, x everywhere else.
        return emit(where_op, (mask, full_src, x), {}, meta, updated=True)
173 changes: 173 additions & 0 deletions backends/arm/test/ops/test_select_scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU85PipelineINT,
OpNotSupportedPipeline,
TosaPipelineFP,
TosaPipelineINT,
VgfPipeline,
)

# Each entry maps a test name to a zero-arg factory returning the
# (x, src, dim, index) tuple fed to SelectScatter.forward: x is the input
# tensor, src the slice to scatter into it, dim the dimension to select
# along, and index the position (possibly negative) within that dimension.
# src always has x's shape with dimension `dim` removed. Note that randint
# with a float dtype yields integer-valued float tensors.
test_data_suite = {
    "rank2_rand": lambda: (
        torch.randint(-30, 30, (5, 9), dtype=torch.float32),
        torch.randint(0, 9, (9,), dtype=torch.float32),
        0,
        2,
    ),
    "rank2_zeros": lambda: (
        torch.rand((3, 2), dtype=torch.float32),
        torch.randint(0, 4, (2,), dtype=torch.float32),
        0,
        0,
    ),
    "rank3_rand": lambda: (
        torch.rand((2, 4, 5), dtype=torch.float32),
        torch.randint(-5, 5, (2, 5), dtype=torch.float32),
        1,
        0,
    ),
    "rank3_ones": lambda: (
        torch.ones((2, 3, 3), dtype=torch.float32),
        torch.rand((2, 3), dtype=torch.float32),
        2,
        2,
    ),
    "rank4_rand": lambda: (
        torch.rand((1, 2, 4, 5), dtype=torch.float32),
        torch.rand((2, 4, 5), dtype=torch.float32),
        0,
        0,
    ),
    # Negative index: scatters into the last position along dim 2.
    "rank4_ones": lambda: (
        torch.ones((2, 3, 3, 2), dtype=torch.float32),
        torch.randint(-5, 5, (2, 3, 2), dtype=torch.float32),
        2,
        -1,
    ),
    "rank5_ones": lambda: (
        torch.ones((3, 4, 20, 9, 5), dtype=torch.float32),
        torch.randn((3, 4, 20, 9), dtype=torch.float32),
        4,
        1,
    ),
    "rank6_rand": lambda: (
        torch.rand((1, 2, 3, 4, 2, 1), dtype=torch.float32),
        torch.randn((2, 3, 4, 2, 1), dtype=torch.float32),
        0,
        0,
    ),
}


class SelectScatter(torch.nn.Module):
    """Minimal module wrapping select_scatter for the lowering pipelines."""

    # Operator fingerprints the pipelines check for: select_scatter survives
    # as a single op on the fp profile, while the int profile sees the
    # decomposed arange/view/unsqueeze/expand/eq/where sequence instead.
    fp_aten_op = "torch.ops.aten.select_scatter.default"
    int_aten_ops = [
        "torch.ops.aten.arange.start_step",
        "torch.ops.aten.view_copy.default",
        "torch.ops.aten.unsqueeze_copy.default",
        "torch.ops.aten.expand_copy.default",
        "torch.ops.aten.where.self",
        "torch.ops.aten.eq.Tensor",
    ]
    fp_exir_op = ["executorch_exir_dialects_edge__ops_aten_select_scatter_default"]
    int_exir_ops = [
        "executorch_exir_dialects_edge__ops_aten_eq_Tensor",
        "executorch_exir_dialects_edge__ops_aten_where_self",
        "executorch_exir_dialects_edge__ops_aten_arange_start_step",
        "executorch_exir_dialects_edge__ops_aten_view_copy_default",
        "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default",
        "executorch_exir_dialects_edge__ops_aten_expand_copy_default",
    ]
    # Ops the U55 test expects to be left outside the delegate.
    u55_not_supported = {
        "executorch_exir_dialects_edge__ops_aten_eq_Tensor": 1,
        "executorch_exir_dialects_edge__ops_aten_where_self": 1,
    }

    def forward(self, x: torch.Tensor, y: torch.Tensor, dim: int, index: int):
        """Return a copy of ``x`` with slice ``index`` along ``dim`` replaced by ``y``."""
        return torch.select_scatter(x, y, dim, index)


input_t = Tuple[torch.Tensor, torch.Tensor, int, int]


@common.parametrize("test_module", test_data_suite)
def test_select_scatter_tosa_FP(test_module: input_t):
    """select_scatter stays a single op on the TOSA FP pipeline."""
    TosaPipelineFP[input_t](
        SelectScatter(),
        test_module(),
        aten_op=SelectScatter.fp_aten_op,
        exir_op=SelectScatter.fp_exir_op,
    ).run()


@common.parametrize("test_module", test_data_suite)
def test_select_scatter_tosa_INT(test_module: input_t):
    """select_scatter is decomposed for the TOSA INT pipeline."""
    TosaPipelineINT[input_t](
        SelectScatter(),
        test_module(),
        aten_op=SelectScatter.int_aten_ops,
        exir_op=SelectScatter.int_exir_ops,
    ).run()


@common.parametrize("test_module", test_data_suite)
def test_select_scatter_u55_INT(test_module: input_t):
    """select_scatter is not supported on U55: eq/where stay un-delegated."""
    OpNotSupportedPipeline[input_t](
        SelectScatter(),
        test_module(),
        SelectScatter.u55_not_supported,
        quantize=True,
        u55_subset=True,
        n_expected_delegates=1,
    ).run()


@common.XfailIfNoCorstone320
@common.parametrize("test_module", test_data_suite)
def test_select_scatter_u85_INT(test_module: input_t):
    """Run the decomposed select_scatter on the Ethos-U85 INT pipeline."""
    EthosU85PipelineINT[input_t](
        SelectScatter(),
        test_module(),
        aten_ops=SelectScatter.int_aten_ops,
        exir_ops=SelectScatter.int_exir_ops,
    ).run()


@common.SkipIfNoModelConverter
@common.parametrize("test_module", test_data_suite)
def test_select_scatter_vgf_FP(test_module: input_t):
    """VGF FP pipeline: select_scatter is kept as a single op."""
    VgfPipeline[input_t](
        SelectScatter(),
        test_module(),
        aten_op=SelectScatter.fp_aten_op,
        exir_op=SelectScatter.fp_exir_op,
        tosa_version="TOSA-1.0+FP",
    ).run()


@common.SkipIfNoModelConverter
@common.parametrize("test_module", test_data_suite)
def test_select_scatter_vgf_INT(test_module: input_t):
    """VGF INT pipeline: select_scatter is decomposed."""
    VgfPipeline[input_t](
        SelectScatter(),
        test_module(),
        aten_op=SelectScatter.int_aten_ops,
        exir_op=SelectScatter.int_exir_ops,
        tosa_version="TOSA-1.0+INT",
    ).run()
Loading