diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index f8f9cc36271..7ce411061e3 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -51,7 +51,6 @@
 from .decompose_int16_activation_conv_pass import (  # noqa
     DecomposeConvWithInt16ActivationPass,
 )
-from .decompose_int32_clamp_pass import DecomposeInt32ClampPass  # noqa
 from .decompose_int_pow_pass import DecomposeIntPowPass  # noqa
 from .decompose_layernorm_pass import DecomposeLayerNormPass  # noqa
 from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass  # noqa
@@ -75,6 +74,9 @@
 from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass  # noqa
 from .decompose_sqrt_pass import DecomposeSqrtPass  # noqa
 from .decompose_sum_pass import DecomposeSumPass  # noqa
+from .decompose_tosa_unsupported_clamp_pass import (  # noqa
+    DecomposeTOSAUnsupportedClampPass,
+)
 from .decompose_var_pass import DecomposeVarPass  # noqa
 from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass  # noqa
 from .fold_qdq_with_annotated_qparams_pass import (  # noqa
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index be894b27787..47b1e4e5558 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -54,7 +54,6 @@
     DecomposeGluPass,
     DecomposeGroupedConvPass,
     DecomposeGroupNormPass,
-    DecomposeInt32ClampPass,
     DecomposeIntPowPass,
     DecomposeLayerNormPass,
     DecomposeLeakyReLUPass,
@@ -78,6 +77,7 @@
     DecomposeSoftmaxUnstablePass,
     DecomposeSqrtPass,
     DecomposeSumPass,
+    DecomposeTOSAUnsupportedClampPass,
     DecomposeVarPass,
     DecorateFp32toInt32CastingPass,
     FoldAndAnnotateQParamsPass,
@@ -220,7 +220,7 @@ def _tosa_pipeline(
         [
             FuseQuantizedActivationPass(),
             ConvertToClampPass(),
-            DecomposeInt32ClampPass(),
+            DecomposeTOSAUnsupportedClampPass(),
             DecomposeGroupNormPass(),
             DecomposeLayerNormPass(),
             DecomposeVarPass(),
diff --git a/backends/arm/_passes/decompose_int32_clamp_pass.py b/backends/arm/_passes/decompose_tosa_unsupported_clamp_pass.py
similarity index 58%
rename from backends/arm/_passes/decompose_int32_clamp_pass.py
rename to backends/arm/_passes/decompose_tosa_unsupported_clamp_pass.py
index 5574875b6b8..b467f6795b3 100644
--- a/backends/arm/_passes/decompose_int32_clamp_pass.py
+++ b/backends/arm/_passes/decompose_tosa_unsupported_clamp_pass.py
@@ -11,13 +11,16 @@
 from executorch.exir.pass_base import ExportPass


-class DecomposeInt32ClampPass(ArmPass):
-    """Rewrite int32 clamp into min/max chain since TOSA lacks int32 clamp support."""
+class DecomposeTOSAUnsupportedClampPass(ArmPass):
+    """Rewrite clamp ops that TOSA cannot express directly into a maximum/minimum
+    chain: TOSA CLAMP has no int32 support and accepts only scalar min/max values."""

     _passes_required_after: Set[Type[ExportPass]] = set()
     _supported_ops = {
         exir_ops.edge.aten.clamp.default,
+        exir_ops.edge.aten.clamp.Tensor,
         torch.ops.aten.clamp.default,
+        torch.ops.aten.clamp.Tensor,
     }

     def _ensure_tensor(
@@ -40,31 +43,53 @@ def call_operator(self, op, args, kwargs, meta):
         val = meta["val"]
-        if op not in self._supported_ops or val.dtype != torch.int32:
+
+        is_scalar_clamp = op in {
+            exir_ops.edge.aten.clamp.default,
+            torch.ops.aten.clamp.default,
+        }
+        is_tensor_clamp = op in {
+            exir_ops.edge.aten.clamp.Tensor,
+            torch.ops.aten.clamp.Tensor,
+        }
+
+        if op not in self._supported_ops:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Scalar clamp only needs decomposition for int32; other dtypes map to TOSA CLAMP
+        if is_scalar_clamp and val.dtype != torch.int32:
             return super().call_operator(op, args, kwargs, meta)

         input_tensor = args[0]
-        min_arg = args[1] if len(args) > 1 else None
-        max_arg = args[2] if len(args) > 2 else None
         dtype = val.dtype
         rank = len(val.shape)
+        min_arg = args[1] if len(args) > 1 else None
+        max_arg = args[2] if len(args) > 2 else None

-        min_arg = self._ensure_tensor(min_arg, input_tensor, dtype, rank, meta)
-        max_arg = self._ensure_tensor(max_arg, input_tensor, dtype, rank, meta)
+        if is_scalar_clamp:
+            # Scalar min/max -> make them tensors for min/max ops
+            min_arg = self._ensure_tensor(min_arg, input_tensor, dtype, rank, meta)
+            max_arg = self._ensure_tensor(max_arg, input_tensor, dtype, rank, meta)
+        else:
+            # Tensor variant: arguments are already tensors; nothing extra to do
+            if not is_tensor_clamp:
+                raise RuntimeError(
+                    f"DecomposeTOSAUnsupportedClampPass: unexpected op {op} in tensor clamp branch"
+                )

         current = input_tensor
-        if max_arg is not None:
+        if min_arg is not None:
             current = super().call_operator(
-                exir_ops.edge.aten.minimum.default,
-                (current, max_arg),
+                exir_ops.edge.aten.maximum.default,
+                (current, min_arg),
                 {},
                 meta,
                 updated=True,
             )
-        if min_arg is not None:
+        if max_arg is not None:
             current = super().call_operator(
-                exir_ops.edge.aten.maximum.default,
-                (current, min_arg),
+                exir_ops.edge.aten.minimum.default,
+                (current, max_arg),
                 {},
                 meta,
                 updated=True,
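Note the reordering in the final hunk: the lower bound (maximum) is now applied before the upper bound (minimum), matching torch.clamp semantics, where min > max resolves to max. A minimal eager-mode sketch of the decomposition this pass performs (the helper name is illustrative, not code from this patch):

import torch

def clamp_via_min_max(x, min_t=None, max_t=None):
    # Lower bound first, then upper bound, so that min > max resolves
    # to max, matching torch.clamp semantics.
    out = x
    if min_t is not None:
        out = torch.maximum(out, min_t)
    if max_t is not None:
        out = torch.minimum(out, max_t)
    return out

x = torch.randint(-50, 50, (2, 3), dtype=torch.int32)
lo = torch.tensor(10, dtype=torch.int32)
hi = torch.tensor(-10, dtype=torch.int32)  # deliberately min > max
assert torch.equal(clamp_via_min_max(x, lo, hi), torch.clamp(x, lo, hi))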
diff --git a/backends/arm/operator_support/tosa_profile_supported_op_lists.py b/backends/arm/operator_support/tosa_profile_supported_op_lists.py
index 351acab13d8..f4f72690345 100644
--- a/backends/arm/operator_support/tosa_profile_supported_op_lists.py
+++ b/backends/arm/operator_support/tosa_profile_supported_op_lists.py
@@ -41,6 +41,7 @@
     exir_ops.edge.aten.cat.default,
     exir_ops.edge.aten.ceil.default,
     exir_ops.edge.aten.clamp.default,
+    exir_ops.edge.aten.clamp.Tensor,
     exir_ops.edge.aten.cumsum.default,
     exir_ops.edge.aten.bmm.default,
     exir_ops.edge.aten.permute_copy.default,
@@ -138,6 +139,7 @@
     exir_ops.edge.aten.cat.default,
     exir_ops.edge.aten.ceil.default,
     exir_ops.edge.aten.clamp.default,
+    exir_ops.edge.aten.clamp.Tensor,
     exir_ops.edge.aten.cos.default,
     exir_ops.edge.aten.cumsum.default,
     exir_ops.edge.aten.bmm.default,
diff --git a/backends/arm/scripts/parse_test_names.py b/backends/arm/scripts/parse_test_names.py
index 9a7c6564e9c..f570fd222eb 100644
--- a/backends/arm/scripts/parse_test_names.py
+++ b/backends/arm/scripts/parse_test_names.py
@@ -34,6 +34,7 @@
     "pixel_shuffle.default",
     "pixel_unshuffle.default",
     "while_loop.default",
+    "clamp.Tensor",
 ]

 ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS
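Background for the two overloads registered above: with Python-scalar bounds, torch.clamp dispatches to aten.clamp.default; with tensor bounds it dispatches to aten.clamp.Tensor, whose bounds broadcast against the input. A short eager-mode illustration:

import torch

x = torch.randint(-50, 50, (2, 3), dtype=torch.int32)

# Scalar bounds -> aten.clamp.default
y_scalar = torch.clamp(x, -10, 10)

# Tensor bounds -> aten.clamp.Tensor (bounds broadcast against x)
lo = torch.tensor(-10, dtype=torch.int32)
hi = torch.tensor(10, dtype=torch.int32)
y_tensor = torch.clamp(x, lo, hi)

assert torch.equal(y_scalar, y_tensor)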
diff --git a/backends/arm/test/ops/test_clamp.py b/backends/arm/test/ops/test_clamp.py
index 34513305563..60477c6cbe4 100644
--- a/backends/arm/test/ops/test_clamp.py
+++ b/backends/arm/test/ops/test_clamp.py
@@ -96,8 +96,6 @@ def test_clamp_tosa_INT(test_data):
         aten_op,
         exir_op,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
-
     pipeline.run()


@@ -112,7 +110,6 @@ def test_clamp_tosa_INT_int32_inputs(test_data):
         aten_op,
         exir_op,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.pop_stage("quantize")
     pipeline.run()
@@ -129,7 +126,6 @@ def test_clamp_tosa_INT_a16w8(test_data):
         exir_op,
         tosa_extensions=["int16"],
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()
@@ -145,14 +141,12 @@ def test_clamp_u55_INT(test_data):
         aten_op,
         exir_op,
     )
-
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()


 @common.parametrize("test_data", test_data_suite)
 @common.XfailIfNoCorstone300
-def test_clamp_16a8w_u55_INT16(test_data):
+def test_clamp_16a8w_u55_INT(test_data):
     """Test clamp operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
     input_tensor, min_val, max_val = test_data()
     model = Clamp(min_val, max_val)
@@ -165,7 +159,6 @@ def test_clamp_16a8w_u55_INT(test_data):
         a16w8_quantization=True,
         use_to_edge_transform_and_lower=True,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()
@@ -181,14 +174,12 @@ def test_clamp_u85_INT(test_data):
         aten_op,
         exir_op,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
-
     pipeline.run()


 @common.parametrize("test_data", test_data_suite)
 @common.XfailIfNoCorstone320
-def test_clamp_16a8w_u85_INT16(test_data):
+def test_clamp_16a8w_u85_INT(test_data):
     """Test clamp operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
     input_tensor, min_val, max_val = test_data()
     model = Clamp(min_val, max_val)
@@ -201,7 +192,6 @@ def test_clamp_16a8w_u85_INT(test_data):
         a16w8_quantization=True,
         use_to_edge_transform_and_lower=True,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
     pipeline.run()
@@ -233,3 +223,252 @@ def test_clamp_vgf_quant(test_data):
         quantize=True,
     )
     pipeline.run()
+
+
+aten_op_tensor = "torch.ops.aten.clamp.Tensor"
+exir_op_tensor = "executorch_exir_dialects_edge__ops_aten_clamp_Tensor"
+
+test_data_suite_tensor_FP = {
+    # test_name: (test_data, min, max)
+    "rank_1": lambda: (torch.rand(10) * 2, torch.tensor(-1.0), torch.tensor(1.0)),
+    "rank_2": lambda: (torch.rand(1, 35), torch.tensor(0.5), torch.tensor(0.8)),
+    "rank_3": lambda: (
+        torch.ones(1, 10, 10),
+        torch.rand(1, 10, 10) * 0.5,
+        torch.rand(1, 10, 10) * -0.5,
+    ),
+    "rank_4": lambda: (
+        torch.rand(1, 10, 10, 1) * 2,
+        torch.tensor(-0.1),
+        torch.tensor(2.0),
+    ),
+    "rank_4_no_max": lambda: (
+        torch.rand(10, 20, 30, 40) - 3,
+        torch.rand(30, 40) - 3.3,
+        None,
+    ),
+    "rank_4_no_min": lambda: (
+        torch.rand(10, 20, 30, 40) * 10,
+        None,
+        torch.rand(10, 20, 30, 40) * 5.0,
+    ),
+}
+
+test_data_suite_tensor_INT32 = {
+    "int32_rank2": lambda: (
+        torch.randint(-50, 50, (2, 3), dtype=torch.int32),
+        torch.tensor(-10),
+        torch.tensor(10),
+    ),
+    "int32_rank3_no_min_broadcast_1_3": lambda: (
+        torch.randint(0, 100, (1, 3, 3), dtype=torch.int32) + 10,
+        None,
+        torch.tensor([[3, 5, 7]], dtype=torch.int32),  # torch.Size([1, 3])
+    ),
+    "int32_rank3_no_max_broadcast_3_1": lambda: (
+        torch.randint(-100, 100, (1, 3, 3), dtype=torch.int32),
+        torch.tensor([[3], [5], [7]], dtype=torch.int32),  # torch.Size([3, 1])
+        None,
+    ),
+    "int32_rank4_large_range": lambda: (
+        torch.randint(-200, 200, (1, 2, 4, 4), dtype=torch.int32),
+        torch.tensor(torch.iinfo(torch.int32).min),
+        torch.tensor(torch.iinfo(torch.int32).max),
+    ),
+    "int32_rank4_broadcast_1_2": lambda: (
+        torch.ones(1, 2, 4, 4, dtype=torch.int32) * 100,
+        torch.randint(-10, 10, (4,), dtype=torch.int32),  # torch.Size([4])
+        torch.randint(-10, 10, (4, 4), dtype=torch.int32),  # torch.Size([4, 4])
+    ),
+    "int32_rank4_broadcast_3_4": lambda: (
+        torch.ones(1, 2, 4, 4, dtype=torch.int32) * 100,
+        torch.randint(-10, 10, (1, 4, 4), dtype=torch.int32),  # torch.Size([1, 4, 4])
+        torch.randint(-10, 10, (1, 1, 4, 4), dtype=torch.int32),  # torch.Size([1, 1, 4, 4])
+    ),
+}
+
+test_data_suite_tensor_INT64 = {
+    "int64_rank_3": lambda: (
+        torch.ones(1, 10, 10, dtype=torch.int64),
+        torch.tensor(-1),
+        torch.tensor(-1),
+    ),
+    "int64_rank_4": lambda: (
+        torch.randint(-100, 100, (1, 3, 3)),
+        torch.tensor(-10),
+        torch.tensor(20),
+    ),
+}
+
+
+@common.parametrize("test_data", test_data_suite_tensor_FP)
+def test_clamp_tensor_tosa_FP(test_data):
+    input_tensor, min_val, max_val = test_data()
+    model = Clamp(min_val, max_val)
+
+    pipeline = TosaPipelineFP[input_t](
+        model,
+        (input_tensor,),
+        aten_op_tensor,
+        exir_op_tensor,
+    )
+
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_data", test_data_suite_tensor_INT32 | test_data_suite_tensor_INT64
+)
+def test_clamp_tensor_tosa_INT(test_data):
+    input_tensor, min_val, max_val = test_data()
+    model = Clamp(min_val, max_val)
+
+    pipeline = TosaPipelineINT[input_t](
+        model,
+        (input_tensor,),
+        aten_op_tensor,
+        exir_op_tensor,
+    )
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_data", test_data_suite_tensor_INT32 | test_data_suite_tensor_INT64
+)
+def test_clamp_tensor_tosa_INT_a16w8(test_data):
+    """Test clamp operation with int16 I/O quantization for TOSA INT."""
+    input_tensor, min_val, max_val = test_data()
+    model = Clamp(min_val, max_val)
+    pipeline = TosaPipelineINT[input_t](
+        model,
+        (input_tensor,),
+        aten_op_tensor,
+        exir_op_tensor,
+        tosa_extensions=["int16"],
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite_tensor_INT32)
+@common.XfailIfNoCorstone300
+def test_clamp_tensor_u55_INT(test_data):
+    input_tensor, min_val, max_val = test_data()
+    model = Clamp(min_val, max_val)
+
+    pipeline = EthosU55PipelineINT[input_t](
+        model,
+        (input_tensor,),
+        aten_op_tensor,
+        exir_op_tensor,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite_tensor_INT32)
+@common.XfailIfNoCorstone300
+def test_clamp_tensor_16a8w_u55_INT(test_data):
+    """Test clamp operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
+    input_tensor, min_val, max_val = test_data()
+    model = Clamp(min_val, max_val)
+    pipeline = EthosU55PipelineINT[input_t](
+        model,
+        (input_tensor,),
+        aten_op_tensor,
+        exir_op_tensor,
+        per_channel_quantization=False,
+        a16w8_quantization=True,
+        use_to_edge_transform_and_lower=True,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite_tensor_INT32)
+@common.XfailIfNoCorstone320
+def test_clamp_tensor_u85_INT(test_data):
+    input_tensor, min_val, max_val = test_data()
+    model = Clamp(min_val, max_val)
+
+    pipeline = EthosU85PipelineINT[input_t](
+        model,
+        (input_tensor,),
+        aten_op_tensor,
+        exir_op_tensor,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite_tensor_INT32)
+@common.XfailIfNoCorstone320
+def test_clamp_tensor_16a8w_u85_INT(test_data):
+    """Test clamp operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
+    input_tensor, min_val, max_val = test_data()
+    model = Clamp(min_val, max_val)
+    pipeline = EthosU85PipelineINT[input_t](
+        model,
+        (input_tensor,),
+        aten_op_tensor,
+        exir_op_tensor,
+        per_channel_quantization=False,
+        a16w8_quantization=True,
+        use_to_edge_transform_and_lower=True,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite_tensor_FP)
+@common.SkipIfNoModelConverter
+def test_clamp_tensor_vgf_no_quant(test_data):
+    input_tensor, min_val, max_val = test_data()
+    model = Clamp(min_val, max_val)
+    pipeline = VgfPipeline[input_t](
+        model,
+        (input_tensor,),
+        aten_op_tensor,
+        exir_op_tensor,
+        quantize=False,
+    )
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_data", test_data_suite_tensor_INT32 | test_data_suite_tensor_INT64
+)
+@common.SkipIfNoModelConverter
+def test_clamp_tensor_vgf_quant(test_data):
+    input_tensor, min_val, max_val = test_data()
+    model = Clamp(min_val, max_val)
+    pipeline = VgfPipeline[input_t](
+        model,
+        (input_tensor,),
+        aten_op_tensor,
+        exir_op_tensor,
+        quantize=True,
+    )
+    pipeline.run()
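The broadcast entries in the new test suites lean on clamp.Tensor's standard broadcasting: each bound is broadcast against the input before comparison, so the decomposed maximum/minimum chain sees the same shapes. A quick eager-mode sanity check for the [3, 1]-bound case from int32_rank3_no_max_broadcast_3_1 (illustrative only):

import torch

x = torch.randint(-100, 100, (1, 3, 3), dtype=torch.int32)
lo = torch.tensor([[3], [5], [7]], dtype=torch.int32)  # shape [3, 1]

# With no upper bound, clamp reduces to an elementwise maximum with the
# broadcast lower bound -- exactly the single op the pass emits.
assert torch.equal(torch.clamp(x, min=lo), torch.maximum(x, lo))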
diff --git a/backends/arm/test/passes/test_convert_int32_clamp_to_minmax_pass.py b/backends/arm/test/passes/test_decompose_tosa_unsupported_clamp_pass.py
similarity index 52%
rename from backends/arm/test/passes/test_convert_int32_clamp_to_minmax_pass.py
rename to backends/arm/test/passes/test_decompose_tosa_unsupported_clamp_pass.py
index f05d46315c7..9ceeb1b93be 100644
--- a/backends/arm/test/passes/test_convert_int32_clamp_to_minmax_pass.py
+++ b/backends/arm/test/passes/test_decompose_tosa_unsupported_clamp_pass.py
@@ -6,8 +6,8 @@
 from typing import Tuple

 import torch
-from executorch.backends.arm._passes.decompose_int32_clamp_pass import (
-    DecomposeInt32ClampPass,
+from executorch.backends.arm._passes.decompose_tosa_unsupported_clamp_pass import (
+    DecomposeTOSAUnsupportedClampPass,
 )
 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
@@ -39,6 +39,35 @@ def test_decompose_int32_clamp_pass(test_data: input_t):
         ops_not_after_pass=[
             "executorch_exir_dialects_edge__ops_aten_clamp_default",
         ],
-        pass_list=[DecomposeInt32ClampPass],
+        pass_list=[DecomposeTOSAUnsupportedClampPass],
     )
     pipeline.run()
+
+
+class ClampTensorInt32(torch.nn.Module):
+    test_data = {"rand": (torch.randint(-50, 50, (2, 3), dtype=torch.int32),)}
+
+    def forward(self, x: torch.Tensor):
+        return torch.clamp(x, torch.tensor(-10), torch.tensor(5))
+
+
+@common.parametrize("test_data", ClampTensorInt32.test_data)
+def test_decompose_int32_clamp_tensor_pass(test_data: input_t):
+    module = ClampTensorInt32()
+    pipeline = PassPipeline[input_t](
+        module,
+        test_data,
+        quantize=False,
+        ops_before_pass={
+            "executorch_exir_dialects_edge__ops_aten_clamp_Tensor": 1,
+        },
+        ops_after_pass={
+            "executorch_exir_dialects_edge__ops_aten_minimum_default": 1,
+            "executorch_exir_dialects_edge__ops_aten_maximum_default": 1,
+        },
+        ops_not_after_pass=[
+            "executorch_exir_dialects_edge__ops_aten_clamp_Tensor",
+        ],
+        pass_list=[DecomposeTOSAUnsupportedClampPass],
+    )
+    pipeline.run()
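The op counts asserted by the new pass test pin down the expected shape of the rewrite: one clamp.Tensor becomes exactly one maximum followed by one minimum. An eager-mode restatement of that equivalence, mirroring ClampTensorInt32.forward (a sketch, not test code):

import torch

x = torch.randint(-50, 50, (2, 3), dtype=torch.int32)
lo, hi = torch.tensor(-10), torch.tensor(5)

# Lower bound first, then upper bound -- the chain the pass produces.
decomposed = torch.minimum(torch.maximum(x, lo), hi)
assert torch.equal(decomposed, torch.clamp(x, lo, hi))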