From 58ce7d86cdecd835901411e5ffbee3352056a972 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Mon, 24 Mar 2025 16:08:39 +0100 Subject: [PATCH] Arm backend: Extend convolution support check to 3d Add conv3d tests, though most are skipped since conv3d support is not yet implemented. Signed-off-by: Erik Lundell Change-Id: I13b9756eef4188a32bb1784836d2fe96a22f1ade --- .../operator_support/convolution_support.py | 16 +- backends/arm/test/ops/test_conv2d.py | 28 +- backends/arm/test/ops/test_conv3d.py | 399 ++++++++++++++++++ 3 files changed, 424 insertions(+), 19 deletions(-) create mode 100644 backends/arm/test/ops/test_conv3d.py diff --git a/backends/arm/operator_support/convolution_support.py b/backends/arm/operator_support/convolution_support.py index b07ae82f98f..9e13babe23a 100644 --- a/backends/arm/operator_support/convolution_support.py +++ b/backends/arm/operator_support/convolution_support.py @@ -55,7 +55,7 @@ def _is_node_supported_u55(self, node: fx.Node): C_in = shape_in[1] C_out = shape_out[1] - if (C_in == group) and (C_out % C_in) == 0: + if (C_in == group) and (C_out % C_in) == 0 and len(shape_in) <= 4: # Depthwise convolution for dim in shape_in[1:]: if not 1 <= dim <= 65536: @@ -74,6 +74,7 @@ def _is_node_supported_u55(self, node: fx.Node): kernel_w = kernel[2] kernel_h = kernel[3] if len(kernel) > 3 else 1 + kernel_z = kernel[4] if len(kernel) > 4 else 1 # Kernel condition misses constraint on sum of absolute weights if not 1 <= kernel_h <= 64 or not 1 <= kernel_w * kernel_h <= 4096: self.reporter.report_reject( @@ -81,6 +82,11 @@ def _is_node_supported_u55(self, node: fx.Node): f"Convolution needs to have kernel_y<=64, kernel_x*kernel_y<=4096, got kernel ({kernel_w}, {kernel_h})", ) return False + if kernel_z != 1: + self.reporter.report_reject( + node, f"Convolution3d needs to have kernel_z==1, got {kernel_z}." + ) + return False if not self._stride_condition(node): self.reporter.report_reject( @@ -107,6 +113,14 @@ def _stride_condition(self, node: fx.Node) -> bool: if len(strides) == 1: strides = [strides[0]] * 2 + if len(strides) > 2: + stride_z = strides[2] + if stride_z > 1: + self.reporter.report_reject( + node, f"Convolution3d only supports stride_z<=1, got {stride_z}." 
+                )
+                return False
+
         for stride, dilation in zip(strides, dilations):
             stride_condition = 1 <= stride <= 3
             dilation_condition = (not has_padding) and (dilation == 1)
diff --git a/backends/arm/test/ops/test_conv2d.py b/backends/arm/test/ops/test_conv2d.py
index 8083b2ecf71..844eed97638 100644
--- a/backends/arm/test/ops/test_conv2d.py
+++ b/backends/arm/test/ops/test_conv2d.py
@@ -8,10 +8,10 @@
 
 import torch
 from executorch.backends.arm.test import common
-from executorch.backends.arm.test.tester.arm_tester import ArmTester
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineBI,
     EthosU85PipelineBI,
+    OpNotSupportedPipeline,
     TosaPipelineBI,
     TosaPipelineMI,
 )
@@ -34,9 +34,9 @@ def __init__(
         in_channels: Union[List, int, None] = None,
         out_channels: Union[List, int, None] = None,
         kernel_size: Union[List, Tuple, None] = None,
-        stride: Union[List, Tuple, None] = None,
-        padding: Union[List, Tuple, None] = None,
-        dilation: Union[List, Tuple, None] = None,
+        stride: Union[List, Tuple, int, None] = None,
+        padding: Union[List, Tuple, int, None] = None,
+        dilation: Union[List, Tuple, int, None] = None,
         groups: Union[List, int, None] = None,
         bias: Union[List, bool, None] = None,
         padding_mode: Union[List, str, None] = None,
@@ -446,17 +446,9 @@ def test_convolution_2d_u85_BI_on_fvp(test_module):
 def test_reject_convolution_2d_u55_BI(
     module: Conv2d,
 ):
-    (
-        ArmTester(
-            module,
-            example_inputs=module.get_inputs(),
-            compile_spec=common.get_u55_compile_spec(),
-        )
-        .quantize()
-        .export()
-        .check_count({"torch.ops.aten.conv2d.default": 1})
-        .check(["torch.ops.quantized_decomposed"])
-        .to_edge_transform_and_lower()
-        .check(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
-        .check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
-    )
+    OpNotSupportedPipeline(
+        module,
+        module.get_inputs(),
+        "TOSA-0.80+BI+u55",
+        {"executorch_exir_dialects_edge__ops_aten_convolution_default": 1},
+    ).run()
diff --git a/backends/arm/test/ops/test_conv3d.py b/backends/arm/test/ops/test_conv3d.py
new file mode 100644
index 00000000000..22f7e9e7f54
--- /dev/null
+++ b/backends/arm/test/ops/test_conv3d.py
@@ -0,0 +1,399 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import List, Tuple, Union
+
+import pytest
+import torch
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineBI,
+    EthosU85PipelineBI,
+    OpNotSupportedPipeline,
+    TosaPipelineBI,
+    TosaPipelineMI,
+)
+
+aten_op = "torch.ops.aten.conv3d.default"
+exir_op = "executorch_exir_dialects_edge__ops_aten_convolution_default"
+
+
+class Conv3d(torch.nn.Module):
+    """
+    Creates one or many chained 3D-convolutions. For multiple convolutions, the
+    respective parameters are provided as lists.
+ """ + + def __init__( + self, + height=8, + width=8, + depth=8, + nbr_conv=1, # Number of chained convs + in_channels: Union[List, int, None] = None, + out_channels: Union[List, int, None] = None, + kernel_size: Union[List, Tuple, None] = None, + stride: Union[List, Tuple, int, None] = None, + padding: Union[List, Tuple, int, None] = None, + dilation: Union[List, Tuple, int, None] = None, + groups: Union[List, int, None] = None, + bias: Union[List, bool, None] = None, + padding_mode: Union[List, str, None] = None, + batches=1, + dtype=torch.float, + ): + super().__init__() + self.nbr_convs = nbr_conv + + # Handle default values + in_channels = [2] * nbr_conv if in_channels is None else in_channels + out_channels = [1 * nbr_conv] if out_channels is None else out_channels + kernel_size = [(3, 3, 1)] * nbr_conv if kernel_size is None else kernel_size + stride = [(2, 2, 1)] * nbr_conv if stride is None else stride + padding = [(1, 1, 1)] * nbr_conv if padding is None else padding + dilation = [(1, 1, 1)] * nbr_conv if dilation is None else dilation + groups = [1] * nbr_conv if groups is None else groups + bias = [True] * nbr_conv if bias is None else bias + padding_mode = ["zeros"] * nbr_conv if padding_mode is None else padding_mode + + # This allows the input parameters to be either a single value or a list + # as type hint implies + if not isinstance(in_channels, List): + in_channels = [in_channels] + if not isinstance(out_channels, List): + out_channels = [out_channels] + if not isinstance(kernel_size, List): + kernel_size = [kernel_size] + if not isinstance(stride, List): + stride = [stride] + if not isinstance(padding, List): + padding = [padding] + if not isinstance(dilation, List): + dilation = [dilation] + if not isinstance(groups, List): + groups = [groups] + if not isinstance(bias, List): + bias = [bias] + if not isinstance(padding_mode, List): + padding_mode = [padding_mode] + + self.batches = batches + self.in_channels = in_channels + self.height = height + self.width = width + self.depth = depth + self.dtype = dtype + + # Build chain of convs + for i in range(self.nbr_convs): + setattr( + self, + f"conv_{i}", + torch.nn.Conv3d( + in_channels=in_channels[i], + out_channels=out_channels[i], + kernel_size=kernel_size[i], + stride=stride[i], + padding=padding[i], + dilation=dilation[i], + groups=groups[i], + bias=bias[i], + padding_mode=padding_mode[i], + ).to(dtype), + ) + + def get_inputs(self): + return ( + torch.randn( + self.batches, self.in_channels[0], self.height, self.width, self.depth + ).to(self.dtype), + ) + + def forward(self, x): + for i in range(self.nbr_convs): + conv = getattr(self, f"conv_{i}") + x = conv(x) + return x + + +conv3d_2x2_3x2x40x40_nobias = Conv3d( + in_channels=2, + out_channels=3, + kernel_size=(2, 2, 2), + stride=1, + bias=False, + padding=0, + width=40, + height=40, + batches=3, +) + +conv3d_3x3_1x3x256x256_st1 = Conv3d( + in_channels=3, + out_channels=10, + kernel_size=(3, 3, 3), + stride=1, + padding=0, + width=256, + height=256, + batches=1, +) + +conv3d_3x3_1x3x12x12_st2_pd1 = Conv3d( + in_channels=3, + out_channels=4, + kernel_size=(3, 3, 3), + stride=2, + padding=1, + width=12, + height=12, + batches=1, +) + +conv3d_1x1_1x2x128x128_st1 = Conv3d( + in_channels=2, + out_channels=1, + kernel_size=(1, 1, 1), + stride=1, + padding=0, + width=128, + height=128, + batches=1, +) + +conv3d_2x2_1x1x14x13_st2 = Conv3d( + in_channels=1, + out_channels=1, + kernel_size=(2, 2, 2), + stride=2, + padding=0, + width=14, + height=13, + batches=1, +) + 
+conv3d_5x5_3x2x128x128_st1 = Conv3d(
+    in_channels=2,
+    out_channels=3,
+    kernel_size=(5, 5, 5),
+    stride=1,
+    padding=0,
+    width=128,
+    height=128,
+    batches=3,
+)
+
+conv3d_3x3_1x3x224x224_st2_pd1 = Conv3d(
+    in_channels=3,
+    out_channels=16,
+    kernel_size=(3, 3, 3),
+    stride=2,
+    padding=1,
+    width=224,
+    height=224,
+    batches=1,
+)
+
+conv3d_5x5_1x3x14x15_st3_pd1 = Conv3d(
+    in_channels=3,
+    out_channels=16,
+    kernel_size=(5, 5, 5),
+    stride=3,
+    padding=1,
+    width=14,
+    height=15,
+    batches=1,
+)
+
+conv3d_7x7_1x3x16x16_st2_pd1_dl2 = Conv3d(
+    in_channels=3,
+    out_channels=3,
+    kernel_size=(7, 7, 7),
+    stride=2,
+    padding=1,
+    dilation=2,
+    width=16,
+    height=16,
+    batches=1,
+)
+
+conv3d_7x7_1x3x15x15_st1_pd0_dl1 = Conv3d(
+    in_channels=3,
+    out_channels=3,
+    kernel_size=(7, 7, 7),
+    stride=1,
+    padding=0,
+    dilation=1,
+    width=15,
+    height=15,
+    batches=1,
+)
+
+conv3d_5x5_1x3x14x14_st5_pd0_dl1 = Conv3d(
+    in_channels=3,
+    out_channels=3,
+    kernel_size=(5, 5, 5),
+    stride=5,
+    padding=0,
+    dilation=1,
+    width=14,
+    height=14,
+    batches=1,
+)
+
+conv3d_5x5_1x3x9x9_st5_pd0_dl1 = Conv3d(
+    in_channels=3,
+    out_channels=3,
+    kernel_size=(5, 5, 5),
+    stride=5,
+    padding=0,
+    dilation=1,
+    width=9,
+    height=9,
+    batches=1,
+)
+
+conv3d_3x3_1x3x8x9_st3_pd0_dl1 = Conv3d(
+    in_channels=3,
+    out_channels=3,
+    kernel_size=(3, 3, 3),
+    stride=3,
+    padding=0,
+    dilation=1,
+    width=8,
+    height=9,
+    batches=1,
+)
+
+conv3d_3x3_1x3x9x8_st3_pd0_dl1 = Conv3d(
+    in_channels=3,
+    out_channels=3,
+    kernel_size=(3, 3, 3),
+    stride=3,
+    padding=0,
+    dilation=1,
+    width=9,
+    height=8,
+    batches=1,
+)
+
+conv3d_3x4_1x3x7x7_st3_pd0_dl1 = Conv3d(
+    in_channels=3,
+    out_channels=3,
+    kernel_size=(3, 4, 3),
+    stride=3,
+    padding=0,
+    dilation=1,
+    width=7,
+    height=7,
+    batches=1,
+)
+
+conv3d_4x3_1x3x7x7_st3_pd0_dl1 = Conv3d(
+    in_channels=3,
+    out_channels=3,
+    kernel_size=(4, 3, 3),
+    stride=3,
+    padding=0,
+    dilation=1,
+    width=7,
+    height=7,
+    batches=1,
+)
+
+test_modules = {
+    "2x2_3x2x40x40_nobias": conv3d_2x2_3x2x40x40_nobias,
+    "3x3_1x3x256x256_st1": conv3d_3x3_1x3x256x256_st1,
+    "3x3_1x3x12x12_st2_pd1": conv3d_3x3_1x3x12x12_st2_pd1,
+    "1x1_1x2x128x128_st1": conv3d_1x1_1x2x128x128_st1,
+    "2x2_1x1x14x13_st2_needs_adjust_pass": conv3d_2x2_1x1x14x13_st2,
+    "5x5_1x3x14x15_st3_pd1_needs_adjust_pass": conv3d_5x5_1x3x14x15_st3_pd1,
+    "7x7_1x3x16x16_st2_pd1_dl2_needs_adjust_pass": conv3d_7x7_1x3x16x16_st2_pd1_dl2,
+    "7x7_1x3x15x15_st1_pd0_dl1_needs_adjust_pass": conv3d_7x7_1x3x15x15_st1_pd0_dl1,
+    "5x5_1x3x14x14_st5_pd0_dl1_needs_adjust_pass": conv3d_5x5_1x3x14x14_st5_pd0_dl1,
+    "5x5_1x3x9x9_st5_pd0_dl1_needs_adjust_pass": conv3d_5x5_1x3x9x9_st5_pd0_dl1,
+    "3x3_1x3x9x8_st3_pd0_dl1_needs_adjust_pass": conv3d_3x3_1x3x9x8_st3_pd0_dl1,
+    "3x3_1x3x8x9_st3_pd0_dl1_needs_adjust_pass": conv3d_3x3_1x3x8x9_st3_pd0_dl1,
+    "3x4_1x3x7x7_st3_pd0_dl1_needs_adjust_pass": conv3d_3x4_1x3x7x7_st3_pd0_dl1,
+    "4x3_1x3x7x7_st3_pd0_dl1_needs_adjust_pass": conv3d_4x3_1x3x7x7_st3_pd0_dl1,
+    "5x5_3x2x128x128_st1": conv3d_5x5_3x2x128x128_st1,
+    "3x3_1x3x224x224_st2_pd1": conv3d_3x3_1x3x224x224_st2_pd1,
+}
+
+input_t = Tuple[torch.Tensor]
+
+
+@common.parametrize("test_module", test_modules)
+@pytest.mark.skip # Not implemented, skip until it is.
+def test_convolution_3d_tosa_MI(test_module):
+    pipeline = TosaPipelineMI[input_t](
+        test_module, test_module.get_inputs(), aten_op, exir_op
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_modules)
+@pytest.mark.skip # Not implemented, skip until it is.
+def test_convolution_3d_tosa_BI(test_module): + pipeline = TosaPipelineBI[input_t]( + test_module, test_module.get_inputs(), aten_op, exir_op + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +@pytest.mark.skip # Not implemented, skip until it is. +def test_convolution_3d_u55_BI(test_module): + pipeline = EthosU55PipelineBI[input_t]( + test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=True + ) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +@pytest.mark.skip # Not implemented, skip until it is. +def test_convolution_3d_u85_BI(test_module): + pipeline = EthosU85PipelineBI[input_t]( + test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=True + ) + pipeline.run() + + +reject_suite = { + "large_stride": Conv3d( + in_channels=1, + out_channels=1, + kernel_size=(2, 2, 1), + stride=(2, 4, 2), + padding=1, + width=10, + height=14, + batches=1, + ), + "large_kernel_z": Conv3d( + in_channels=1, + out_channels=1, + kernel_size=(2, 2, 2), + stride=1, + padding=0, + width=80, + height=80, + batches=1, + ), +} + + +@common.parametrize("module", reject_suite) +def test_reject_convolution_3d_u55_BI( + module: Conv3d, +): + OpNotSupportedPipeline( + module, + module.get_inputs(), + "TOSA-0.80+BI+u55", + {"executorch_exir_dialects_edge__ops_aten_convolution_default": 1}, + ).run()
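
As its docstring notes, the Conv3d helper can also build a chain of convolutions when
list-valued arguments are passed, although none of the test modules in this patch
exercise that path. Below is a minimal usage sketch, not part of the patch, assuming
the patch has been applied so that test_conv3d.py is importable; the channel and
kernel values are illustrative only.

    # Illustrative only: two chained 3D convolutions built by the Conv3d test helper.
    from executorch.backends.arm.test.ops.test_conv3d import Conv3d

    chained = Conv3d(
        nbr_conv=2,                          # builds conv_0 and conv_1
        in_channels=[2, 4],                  # conv_1 consumes conv_0's 4 output channels
        out_channels=[4, 3],
        kernel_size=[(3, 3, 1), (3, 3, 1)],  # stride/padding/dilation use the per-conv defaults
    )
    (x,) = chained.get_inputs()              # random input of shape (1, 2, 8, 8, 8)
    y = chained(x)                           # forward runs conv_0 followed by conv_1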