diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index 9304e65dc38..ecd98536f4b 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -38,6 +38,118 @@ logger.setLevel(logging.INFO) +class ArmCompileSpecBuilder: + def __init__(self): + self.compile_spec: List[CompileSpec] = [] + self.compiler_flags = [] + self.output_format = None + self.path_for_intermediates = None + self.permute_nhwc = False + + def ethosu_compile_spec( + self, + config: str, + system_config: Optional[str] = None, + memory_mode: Optional[str] = None, + extra_flags: Optional[str] = None, + config_ini: Optional[str] = "Arm/vela.ini", + ): + """ + Generate compile spec for Ethos-U NPU + + Args: + config: Ethos-U accelerator configuration, e.g. ethos-u55-128 + system_config: System configuration to select from the Vel + configuration file + memory_mode: Memory mode to select from the Vela configuration file + extra_flags: Extra flags for the Vela compiler + config_ini: Vela configuration file(s) in Python ConfigParser .ini + file format + """ + assert ( + self.output_format is None + ), f"Output format already set to f{self.output_format}" + self.output_format = "vela" + self.compiler_flags = [ + f"--accelerator-config={config}", + f"--config={config_ini}", + ] + if system_config is not None: + self.compiler_flags.append(f"--system-config={system_config}") + if memory_mode is not None: + self.compiler_flags.append(f"--memory-mode={memory_mode}") + if extra_flags is not None: + self.compiler_flags.append(extra_flags) + + return self + + def tosa_compile_spec(self): + """ + Generate compile spec for TOSA flatbuffer output + """ + assert ( + self.output_format is None + ), f"Output format already set: {self.output_format}" + self.output_format = "tosa" + return self + + def dump_intermediate_tosa(self, output_path: str): + """ + Output intermediate .tosa file + """ + self.path_for_intermediates = output_path + return self + + def set_permute_memory_format(self, set_nhwc_permutation: bool = True): + self.permute_nhwc = set_nhwc_permutation + return self + + def build(self): + """ + Generate a list of compile spec objects from the builder + """ + if self.output_format == "vela": + self.compile_spec += [ + CompileSpec("output_format", "vela".encode()), + CompileSpec("compile_flags", " ".join(self.compiler_flags).encode()), + ] + elif self.output_format == "tosa": + self.compile_spec.append(CompileSpec("output_format", "tosa".encode())) + + if self.path_for_intermediates is not None: + self.compile_spec.append( + CompileSpec("debug_tosa_path", self.path_for_intermediates.encode()) + ) + + if self.permute_nhwc: + self.compile_spec.append( + CompileSpec("permute_memory_format", "nhwc".encode()) + ) + + return self.compile_spec + + +def is_permute_memory(compile_spec: List[CompileSpec]) -> bool: + for spec in compile_spec: + if spec.key == "permute_memory_format": + return spec.value.decode() == "nhwc" + return False + + +def is_tosa(compile_spec: List[CompileSpec]) -> bool: + for spec in compile_spec: + if spec.key == "output_format": + return spec.value.decode() == "tosa" + return False + + +def get_intermediate_path(compile_spec: List[CompileSpec]) -> str: + for spec in compile_spec: + if spec.key == "debug_tosa_path": + return spec.value.decode() + return None + + def generate_ethosu_compile_spec( config: str, permute_memory_to_nhwc: Optional[bool] = None, @@ -46,45 +158,31 @@ def generate_ethosu_compile_spec( extra_flags: Optional[str] = None, config_ini: Optional[str] = "Arm/vela.ini", ) -> List[CompileSpec]: - """ - Generate compile spec for Ethos-U NPU - """ - compiler_flags = [f"--accelerator-config={config}", f"--config={config_ini}"] - if system_config is not None: - compiler_flags.append(f"--system-config={system_config}") - if memory_mode is not None: - compiler_flags.append(f"--memory-mode={memory_mode}") - if extra_flags is not None: - compiler_flags.append(extra_flags) - - compile_spec = [ - CompileSpec("output_format", "vela".encode()), - CompileSpec("compile_flags", " ".join(compiler_flags).encode()), - ] - - if permute_memory_to_nhwc: - compile_spec.append(CompileSpec("permute_memory_format", "nhwc".encode())) - - return compile_spec + return ( + ArmCompileSpecBuilder() + .ethosu_compile_spec( + config, + system_config=system_config, + memory_mode=memory_mode, + extra_flags=extra_flags, + config_ini=config_ini, + ) + .set_permute_memory_format(permute_memory_to_nhwc) + .build() + ) def generate_tosa_compile_spec( permute_memory_to_nhwc: Optional[bool] = None, output_path: Optional[str] = None, ) -> List[CompileSpec]: - """ - Generate compile spec for TOSA flatbuffer output - """ - - compile_spec = [CompileSpec("output_format", "tosa".encode())] - - if permute_memory_to_nhwc: - compile_spec.append(CompileSpec("permute_memory_format", "nhwc".encode())) - - if output_path is not None: - compile_spec.append(CompileSpec("debug_tosa_path", output_path.encode())) - - return compile_spec + return ( + ArmCompileSpecBuilder() + .tosa_compile_spec() + .set_permute_memory_format(permute_memory_to_nhwc) + .dump_intermediate_tosa(output_path) + .build() + ) @final diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index eeccfadf5b9..6a8f4b526d6 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -5,6 +5,9 @@ # LICENSE file in the root directory of this source tree. import shutil +import tempfile + +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder # TODO: fixme! These globs are a temporary workaround. Reasoning: # Running the jobs in _unittest.yml will not work since that environment doesn't @@ -13,3 +16,36 @@ # should be installed in the CI env. TOSA_REF_MODEL_INSTALLED = shutil.which("tosa_reference_model") VELA_INSTALLED = shutil.which("vela") + + +def get_tosa_compile_spec(permute_memory_to_nhwc=False, custom_path=None): + """ + Default compile spec for TOSA tests. + """ + intermediate_path = custom_path or tempfile.mkdtemp(prefix="arm_tosa_") + compile_spec = ( + ArmCompileSpecBuilder() + .tosa_compile_spec() + .set_permute_memory_format(permute_memory_to_nhwc) + .dump_intermediate_tosa(intermediate_path) + .build() + ) + return compile_spec + + +def get_u55_compile_spec(permute_memory_to_nhwc=False): + """ + Default compile spec for Ethos-U55 tests. + """ + compile_spec = ( + ArmCompileSpecBuilder() + .ethosu_compile_spec( + "ethos-u55-128", + system_config="Ethos_U55_High_End_Embedded", + memory_mode="Shared_Sram", + extra_flags=None, + ) + .set_permute_memory_format(permute_memory_to_nhwc) + .build() + ) + return compile_spec diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index 6bb70d98c59..df55e5253e2 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -10,8 +10,9 @@ import unittest import torch -from executorch.backends.arm.test.test_models import TosaProfile -from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester +from executorch.backends.arm.test import common + +from executorch.backends.arm.test.tester.arm_tester import ArmTester logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -45,8 +46,7 @@ def _tosa_MI_pipeline(self, module: torch.nn.Module, dump_file=None): ArmTester( module, inputs=module.get_inputs(), - profile=TosaProfile.MI, - backend=ArmBackendSelector.TOSA, + compile_spec=common.get_tosa_compile_spec(), ) .export() .to_edge() @@ -60,8 +60,7 @@ def _tosa_BI_pipeline(self, module: torch.nn.Module, dump_file=None): ArmTester( module, inputs=module.get_inputs(), - profile=TosaProfile.BI, - backend=ArmBackendSelector.TOSA, + compile_spec=common.get_tosa_compile_spec(), ) .quantize() .export() diff --git a/backends/arm/test/models/test_mobilenet_v2_arm.py b/backends/arm/test/models/test_mobilenet_v2_arm.py index ded3b6f6919..ee136d1b480 100644 --- a/backends/arm/test/models/test_mobilenet_v2_arm.py +++ b/backends/arm/test/models/test_mobilenet_v2_arm.py @@ -11,8 +11,8 @@ import torch import torchvision.models as models from executorch.backends.arm.test import common -from executorch.backends.arm.test.test_models import TosaProfile -from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester + +from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.backends.xnnpack.test.tester.tester import Quantize from torchvision.models.mobilenetv2 import MobileNet_V2_Weights @@ -46,9 +46,7 @@ def test_mv2_tosa_MI(self): ArmTester( self.mv2, inputs=self.model_inputs, - profile=TosaProfile.MI, - backend=ArmBackendSelector.TOSA, - permute_memory_to_nhwc=True, + compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), ) .export() .to_edge() @@ -62,9 +60,7 @@ def test_mv2_tosa_BI(self): ArmTester( self.mv2, inputs=self.model_inputs, - profile=TosaProfile.BI, - backend=ArmBackendSelector.TOSA, - permute_memory_to_nhwc=True, + compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), ) .quantize(Quantize(calibrate=False)) .export() @@ -74,7 +70,7 @@ def test_mv2_tosa_BI(self): .to_executorch() ) if common.TOSA_REF_MODEL_INSTALLED: - tester.run_method().compare_outputs() + tester.run_method_and_compare_outputs() else: logger.warning( "TOSA ref model tool not installed, skip numerical correctness tests" @@ -89,9 +85,7 @@ def test_mv2_u55_BI(self): ArmTester( self.mv2, inputs=self.model_inputs, - profile=TosaProfile.BI, - backend=ArmBackendSelector.ETHOS_U55, - permute_memory_to_nhwc=True, + compile_spec=common.get_u55_compile_spec(permute_memory_to_nhwc=True), ) .quantize(Quantize(calibrate=False)) .export() diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 091e4b27ff5..f77feec586f 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -12,8 +12,8 @@ import torch from executorch.backends.arm.test import common -from executorch.backends.arm.test.test_models import TosaProfile -from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester + +from executorch.backends.arm.test.tester.arm_tester import ArmTester from parameterized import parameterized logger = logging.getLogger(__name__) @@ -57,9 +57,7 @@ def _test_add_tosa_MI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.MI, - backend=ArmBackendSelector.TOSA, - permute_memory_to_nhwc=False, + compile_spec=common.get_tosa_compile_spec(), ) .export() .check_count({"torch.ops.aten.add.Tensor": 1}) @@ -70,7 +68,7 @@ def _test_add_tosa_MI_pipeline( .to_executorch() ) if common.TOSA_REF_MODEL_INSTALLED: - tester.run_method().compare_outputs() + tester.run_method_and_compare_outputs() else: logger.warning( "TOSA ref model tool not installed, skip numerical correctness tests" @@ -83,9 +81,7 @@ def _test_add_tosa_BI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.BI, - backend=ArmBackendSelector.TOSA, - permute_memory_to_nhwc=False, + compile_spec=common.get_tosa_compile_spec(), ) .quantize() .export() @@ -98,7 +94,7 @@ def _test_add_tosa_BI_pipeline( ) if common.TOSA_REF_MODEL_INSTALLED: - tester.run_method().compare_outputs(qtol=1) + tester.run_method_and_compare_outputs(qtol=1) else: logger.warning( "TOSA ref model tool not installed, skip numerical correctness tests" @@ -111,8 +107,7 @@ def _test_add_u55_BI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.BI, - backend=ArmBackendSelector.ETHOS_U55, + compile_spec=common.get_u55_compile_spec(), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_avg_pool.py b/backends/arm/test/ops/test_avg_pool.py index 6e0aa8cbad8..259e8e11809 100644 --- a/backends/arm/test/ops/test_avg_pool.py +++ b/backends/arm/test/ops/test_avg_pool.py @@ -12,8 +12,7 @@ import torch from executorch.backends.arm.test import common -from executorch.backends.arm.test.test_models import TosaProfile -from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester +from executorch.backends.arm.test.tester.arm_tester import ArmTester from parameterized import parameterized logger = logging.getLogger(__name__) @@ -51,9 +50,7 @@ def _test_avgpool2d_tosa_MI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.MI, - backend=ArmBackendSelector.TOSA, - permute_memory_to_nhwc=True, + compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), ) .export() .check(["torch.ops.aten.avg_pool2d.default"]) @@ -65,7 +62,7 @@ def _test_avgpool2d_tosa_MI_pipeline( .to_executorch() ) if common.TOSA_REF_MODEL_INSTALLED: - tester.run_method().compare_outputs() + tester.run_method_and_compare_outputs() else: logger.warning( "TOSA ref model tool not installed, skip numerical correctness tests" @@ -78,9 +75,7 @@ def _test_avgpool2d_tosa_BI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.BI, - backend=ArmBackendSelector.TOSA, - permute_memory_to_nhwc=True, + compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), ) .quantize() .export() @@ -93,7 +88,7 @@ def _test_avgpool2d_tosa_BI_pipeline( .to_executorch() ) if common.TOSA_REF_MODEL_INSTALLED: - tester.run_method().compare_outputs(qtol=1) + tester.run_method_and_compare_outputs(qtol=1) else: logger.warning( "TOSA ref model tool not installed, skip numerical correctness tests" @@ -106,9 +101,7 @@ def _test_avgpool2d_tosa_u55_BI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.BI, - backend=ArmBackendSelector.ETHOS_U55, - permute_memory_to_nhwc=True, + compile_spec=common.get_u55_compile_spec(permute_memory_to_nhwc=True), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_conv.py b/backends/arm/test/ops/test_conv.py index 4f94412bf25..1f0ef7bc293 100644 --- a/backends/arm/test/ops/test_conv.py +++ b/backends/arm/test/ops/test_conv.py @@ -11,8 +11,8 @@ import torch from executorch.backends.arm.test import common -from executorch.backends.arm.test.test_models import TosaProfile -from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester + +from executorch.backends.arm.test.tester.arm_tester import ArmTester from parameterized import parameterized logger = logging.getLogger(__name__) @@ -248,9 +248,7 @@ def _test_conv2d_tosa_MI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.MI, - backend=ArmBackendSelector.TOSA, - permute_memory_to_nhwc=True, + compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), ) .export() .to_edge() @@ -260,7 +258,7 @@ def _test_conv2d_tosa_MI_pipeline( .to_executorch() ) if common.TOSA_REF_MODEL_INSTALLED: - tester.run_method().compare_outputs() + tester.run_method_and_compare_outputs() else: logger.warning( "TOSA ref model tool not installed, skip numerical correctness tests" @@ -275,9 +273,7 @@ def _test_conv2d_tosa_BI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.BI, - backend=ArmBackendSelector.TOSA, - permute_memory_to_nhwc=True, + compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), ) .quantize() .export() @@ -288,7 +284,7 @@ def _test_conv2d_tosa_BI_pipeline( .to_executorch() ) if common.TOSA_REF_MODEL_INSTALLED: - tester.run_method().compare_outputs(qtol=1) + tester.run_method_and_compare_outputs(qtol=1) else: logger.warning( "TOSA ref model tool not installed, skip numerical correctness tests" @@ -301,9 +297,7 @@ def _test_conv2d_u55_BI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.BI, - backend=ArmBackendSelector.ETHOS_U55, - permute_memory_to_nhwc=True, + compile_spec=common.get_u55_compile_spec(permute_memory_to_nhwc=True), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_conv_combos.py b/backends/arm/test/ops/test_conv_combos.py index c96a2920e2e..1fd68493790 100644 --- a/backends/arm/test/ops/test_conv_combos.py +++ b/backends/arm/test/ops/test_conv_combos.py @@ -11,8 +11,7 @@ import torch from executorch.backends.arm.test import common -from executorch.backends.arm.test.test_models import TosaProfile -from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester +from executorch.backends.arm.test.tester.arm_tester import ArmTester logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -135,9 +134,7 @@ def _test_conv_combo_tosa_MI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.MI, - backend=ArmBackendSelector.TOSA, - permute_memory_to_nhwc=True, + compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), ) .export() .to_edge() @@ -147,7 +144,7 @@ def _test_conv_combo_tosa_MI_pipeline( .to_executorch() ) if common.TOSA_REF_MODEL_INSTALLED: - tester.run_method().compare_outputs() + tester.run_method_and_compare_outputs() else: logger.warning( "TOSA ref model tool not installed, skip numerical correctness tests" @@ -164,9 +161,7 @@ def _test_conv_combo_tosa_BI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.BI, - backend=ArmBackendSelector.TOSA, - permute_memory_to_nhwc=True, + compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), ) .quantize() .export() @@ -177,7 +172,7 @@ def _test_conv_combo_tosa_BI_pipeline( .to_executorch() ) if common.TOSA_REF_MODEL_INSTALLED: - tester.run_method().compare_outputs(atol=atol, rtol=rtol, qtol=1) + tester.run_method_and_compare_outputs(atol=atol, rtol=rtol, qtol=1) else: logger.warning( "TOSA ref model tool not installed, skip numerical correctness tests" @@ -190,9 +185,7 @@ def _test_conv_combo_u55_BI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.BI, - backend=ArmBackendSelector.ETHOS_U55, - permute_memory_to_nhwc=True, + compile_spec=common.get_u55_compile_spec(permute_memory_to_nhwc=True), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_depthwise_conv.py b/backends/arm/test/ops/test_depthwise_conv.py index 006a8567387..0901a49293b 100644 --- a/backends/arm/test/ops/test_depthwise_conv.py +++ b/backends/arm/test/ops/test_depthwise_conv.py @@ -12,8 +12,8 @@ import torch from executorch.backends.arm.test import common from executorch.backends.arm.test.ops.test_conv import Conv2d -from executorch.backends.arm.test.test_models import TosaProfile -from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester + +from executorch.backends.arm.test.tester.arm_tester import ArmTester from parameterized import parameterized logger = logging.getLogger(__name__) @@ -134,9 +134,7 @@ def _test_dw_conv2d_tosa_MI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.MI, - backend=ArmBackendSelector.TOSA, - permute_memory_to_nhwc=True, + compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), ) .export() .to_edge() @@ -146,7 +144,7 @@ def _test_dw_conv2d_tosa_MI_pipeline( .to_executorch() ) if common.TOSA_REF_MODEL_INSTALLED: - tester.run_method().compare_outputs() + tester.run_method_and_compare_outputs() else: logger.warning( "TOSA ref model tool not installed, skip numerical correctness tests" @@ -159,9 +157,7 @@ def _test_dw_conv2d_tosa_BI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.BI, - backend=ArmBackendSelector.TOSA, - permute_memory_to_nhwc=True, + compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), ) .quantize() .export() @@ -172,7 +168,7 @@ def _test_dw_conv2d_tosa_BI_pipeline( .to_executorch() ) if common.TOSA_REF_MODEL_INSTALLED: - tester.run_method().compare_outputs(qtol=1) + tester.run_method_and_compare_outputs(qtol=1) else: logger.warning( "TOSA ref model tool not installed, skip numerical correctness tests" @@ -185,9 +181,7 @@ def _test_dw_conv2d_u55_BI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.BI, - backend=ArmBackendSelector.ETHOS_U55, - permute_memory_to_nhwc=True, + compile_spec=common.get_u55_compile_spec(permute_memory_to_nhwc=True), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py index afb5a57b2d1..6313db3a167 100644 --- a/backends/arm/test/ops/test_linear.py +++ b/backends/arm/test/ops/test_linear.py @@ -12,8 +12,8 @@ import torch from executorch.backends.arm.test import common -from executorch.backends.arm.test.test_models import TosaProfile -from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester + +from executorch.backends.arm.test.tester.arm_tester import ArmTester from parameterized import parameterized logger = logging.getLogger(__name__) @@ -114,8 +114,7 @@ def _test_linear_tosa_MI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.MI, - backend=ArmBackendSelector.TOSA, + compile_spec=common.get_tosa_compile_spec(), ) .export() .check_count({"torch.ops.aten.addmm.default": 1}) @@ -126,7 +125,7 @@ def _test_linear_tosa_MI_pipeline( .to_executorch() ) if common.TOSA_REF_MODEL_INSTALLED: - tester.run_method().compare_outputs() + tester.run_method_and_compare_outputs() else: logger.warning( "TOSA ref model tool not installed, skip numerical correctness tests" @@ -139,8 +138,7 @@ def _test_linear_tosa_BI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.BI, - backend=ArmBackendSelector.TOSA, + compile_spec=common.get_tosa_compile_spec(), ) .quantize() .export() @@ -152,7 +150,7 @@ def _test_linear_tosa_BI_pipeline( .to_executorch() ) if common.TOSA_REF_MODEL_INSTALLED: - tester.run_method().compare_outputs(qtol=True) + tester.run_method_and_compare_outputs(qtol=True) else: logger.warning( "TOSA ref model tool not installed, skip numerical correctness tests" @@ -165,8 +163,7 @@ def _test_linear_tosa_u55_BI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.BI, - backend=ArmBackendSelector.ETHOS_U55, + compile_spec=common.get_u55_compile_spec(), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_mean_dim.py b/backends/arm/test/ops/test_mean_dim.py index 95bc8c4babc..79ba3de7dad 100644 --- a/backends/arm/test/ops/test_mean_dim.py +++ b/backends/arm/test/ops/test_mean_dim.py @@ -12,8 +12,7 @@ import torch from executorch.backends.arm.test import common -from executorch.backends.arm.test.test_models import TosaProfile -from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester +from executorch.backends.arm.test.tester.arm_tester import ArmTester from parameterized import parameterized logger = logging.getLogger(__name__) @@ -56,9 +55,7 @@ def _test_meandim_tosa_MI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.MI, - backend=ArmBackendSelector.TOSA, - permute_memory_to_nhwc=True, + compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), ) .export() .check(["torch.ops.aten.mean.dim"]) @@ -70,7 +67,7 @@ def _test_meandim_tosa_MI_pipeline( .to_executorch() ) if common.TOSA_REF_MODEL_INSTALLED: - tester.run_method().compare_outputs() + tester.run_method_and_compare_outputs() else: logger.warning( "TOSA ref model tool not installed, skip numerical correctness tests" @@ -83,9 +80,7 @@ def _test_meandim_tosa_BI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.BI, - backend=ArmBackendSelector.TOSA, - permute_memory_to_nhwc=True, + compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), ) .quantize() .export() @@ -98,7 +93,7 @@ def _test_meandim_tosa_BI_pipeline( .to_executorch() ) if common.TOSA_REF_MODEL_INSTALLED: - tester.run_method().compare_outputs(qtol=1) + tester.run_method_and_compare_outputs(qtol=1) else: logger.warning( "TOSA ref model tool not installed, skip numerical correctness tests" @@ -111,9 +106,7 @@ def _test_meandim_tosa_u55_BI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.BI, - backend=ArmBackendSelector.ETHOS_U55, - permute_memory_to_nhwc=True, + compile_spec=common.get_u55_compile_spec(permute_memory_to_nhwc=True), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_softmax.py b/backends/arm/test/ops/test_softmax.py index 5d5766f019b..32bd2253464 100644 --- a/backends/arm/test/ops/test_softmax.py +++ b/backends/arm/test/ops/test_softmax.py @@ -12,8 +12,7 @@ import torch from executorch.backends.arm.test import common -from executorch.backends.arm.test.test_models import TosaProfile -from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester +from executorch.backends.arm.test.tester.arm_tester import ArmTester from parameterized import parameterized logger = logging.getLogger(__name__) @@ -44,8 +43,7 @@ def _test_softmax_tosa_MI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.MI, - backend=ArmBackendSelector.TOSA, + compile_spec=common.get_tosa_compile_spec(), ) .export() .check(["torch.ops.aten._softmax.default"]) @@ -57,7 +55,7 @@ def _test_softmax_tosa_MI_pipeline( .to_executorch() ) if common.TOSA_REF_MODEL_INSTALLED: - tester.run_method().compare_outputs() + tester.run_method_and_compare_outputs() else: logger.warning( "TOSA ref model tool not installed, skip numerical correctness tests" @@ -68,10 +66,7 @@ def _test_softmax_tosa_BI_pipeline( ): tester = ( ArmTester( - module, - inputs=test_data, - profile=TosaProfile.BI, - backend=ArmBackendSelector.TOSA, + module, inputs=test_data, compile_spec=common.get_tosa_compile_spec() ) .quantize() .export() @@ -84,7 +79,7 @@ def _test_softmax_tosa_BI_pipeline( .to_executorch() ) if common.TOSA_REF_MODEL_INSTALLED: - tester.run_method().compare_outputs(qtol=1) + tester.run_method_and_compare_outputs(qtol=1) else: logger.warning( "TOSA ref model tool not installed, skip numerical correctness tests" @@ -97,8 +92,7 @@ def _test_softmax_tosa_u55_BI_pipeline( ArmTester( module, inputs=test_data, - profile=TosaProfile.BI, - backend=ArmBackendSelector.ETHOS_U55, + compile_spec=common.get_u55_compile_spec(), ) .quantize() .export() diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 2d0816a2943..dceae602716 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -3,17 +3,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from enum import Enum from typing import List, Optional, Tuple, Union import numpy as np import torch + from executorch.backends.arm.arm_backend import ( - generate_ethosu_compile_spec, - generate_tosa_compile_spec, + get_intermediate_path, + is_permute_memory, + is_tosa, ) - from executorch.backends.arm.arm_partitioner import ArmPartitioner from executorch.backends.arm.arm_quantizer import ( ArmQuantizer, @@ -22,7 +22,6 @@ from executorch.backends.arm.test.tosautil.tosa_test_utils import ( QuantizationParams, - TosaProfile, TosaTestUtils, ) @@ -35,12 +34,89 @@ ) from executorch.exir import EdgeCompileConfig +from executorch.exir.backend.compile_spec_schema import CompileSpec from torch.export import ExportedProgram -class ArmBackendSelector(Enum): - TOSA = "tosa" - ETHOS_U55 = "ethos-u55" +def _get_input_params( + program: ExportedProgram, is_quantized: bool +) -> Tuple[str, Union[List[QuantizationParams], List[None]]]: + """ + Get name and optionally quantization parameters for the inputs to this + model. + + Args: + program (ExportedProgram): The program to get input parameters from + Returns: + Tuple[str, Optional[QuantizationParams]]: A tuple containing the + input node names and their quantization parameters. + """ + input_names = [] + # E.g. bias and weights are 'placeholders' as well. This is used to + # get only the use inputs. + usr_inputs = program.graph_signature.user_inputs + for node in program.graph.nodes: + if node.op == "placeholder" and node.name in usr_inputs: + input_names.append(node.name) + continue + + if is_quantized: + quant_params = [] + for node in program.graph.nodes: + if ( + node.target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + and node.args[0].name in input_names + ): + qp = QuantizationParams( + node_name=node.args[0].name, scale=node.args[1], zp=node.args[2] + ) + quant_params.append(qp) + if len(quant_params) == len( + input_names + ): # break early if we have all the inputs quantized parameters + break + assert len(quant_params) != 0, "Quantization paramerters not found" + return (input_names, quant_params) + else: + return (input_names, len(input_names) * [None]) # return a list of None's + + +def _get_output_param( + program: ExportedProgram, is_quantized: bool +) -> Tuple[str, Union[QuantizationParams, None]]: + """ + Get name and optionally quantization parameters for the inputs to this + model. + + Args: + program (ExportedProgram): The program to get output parameters from. + Returns: + Tuple[str, Optional[QuantizationParams]]: A tuple containing the + output node name and its quantization parameters. + """ + output_node = None + for node in program.graph.nodes: + if node.op == "output": + output_node = node + break + + if is_quantized: + quant_params = None + for node in program.graph.nodes: + if ( + node.target + == torch.ops.quantized_decomposed.dequantize_per_tensor.default + and node == output_node.args[0][0] + ): + quant_params = QuantizationParams( + node_name=node.args[0].name, scale=node.args[1], zp=node.args[2] + ) + break # break early, there's only one output node + assert quant_params is not None, "Quantization paramerters not found" + return (output_node.name, quant_params) + else: + return (output_node.name, None) class Partition(Partition): @@ -75,44 +151,23 @@ def __init__( self, model: torch.nn.Module, inputs: Tuple[torch.Tensor], - backend: ArmBackendSelector = ArmBackendSelector.TOSA, - profile: TosaProfile = TosaProfile.BI, - permute_memory_to_nhwc: bool = False, + compile_spec: List[CompileSpec] = None, ): """ Args: model (torch.nn.Module): The model to test inputs (Tuple[torch.Tensor]): The inputs to the model - backend (ArmBackendSelector): The backend to use. E.g. TOSA or - ETHOS_U55. - TOSA: Lower to TOSA and test numerical correctness compared to - torch reference. - ETHOS_U55: Lower to TOSA, then let Vela compile. Only - functional test, no numerical checks. - profile (TosaProfile): The TOSA profile to use. Either - TosaProfile.BI or TosaProfile.MI - permute_memory_to_nhwc (bool) : flag for enabling the memory format - permutation to nhwc as required by TOSA + compile_spec (List[CompileSpec]): The compile spec to use """ + + # Use the TosaTestUtils if you are using a TOSA backend self.tosa_test_util = None - self.is_quantized = profile == TosaProfile.BI - self.permute_memory_to_nhwc = permute_memory_to_nhwc - - if backend == ArmBackendSelector.TOSA: - self.tosa_test_util = TosaTestUtils(profile=profile) - # The spec below tiggers arm_backend.py to output two files: - # 1) output.tosa - # 2) desc.json - # Saved on disk in self.tosa_test_util.intermediate_path - self.compile_spec = generate_tosa_compile_spec( - permute_memory_to_nhwc, self.tosa_test_util.intermediate_path - ) - elif backend == ArmBackendSelector.ETHOS_U55: - self.compile_spec = generate_ethosu_compile_spec( - config="ethos-u55-128", permute_memory_to_nhwc=permute_memory_to_nhwc - ) - else: - raise ValueError(f"Unknown backend: {backend}") + if is_tosa(compile_spec): + intermediate_path = get_intermediate_path(compile_spec) + self.tosa_test_util = TosaTestUtils(intermediate_path=intermediate_path) + + self.compile_spec = compile_spec + super().__init__(model, inputs) def quantize(self, quantize_stage: Optional[Quantize] = None): @@ -134,8 +189,14 @@ def partition(self, partition_stage: Optional[Partition] = None): partition_stage = Partition(arm_partitioner) return super().partition(partition_stage) - def run_method( - self, stage: Optional[str] = None, inputs: Optional[Tuple[torch.Tensor]] = None + def run_method_and_compare_outputs( + self, + stage: Optional[str] = None, + inputs: Optional[Tuple[torch.Tensor]] = None, + num_runs=1, + atol=1e-03, + rtol=1e-03, + qtol=0, ): """ This function runs the tosa_reference_model tool to get output data @@ -150,130 +211,74 @@ def run_method( Todo: * A lot of the stuff in this method should be broken out into a run_artifact() method on a ToExecutorch stage class. - * See "TODO" inline below """ assert ( self.tosa_test_util is not None ), "self.tosa_test_util is not initialized, cannot use run_method()" - inputs_to_run = inputs or self.inputs + + number_of_runs = 1 if inputs is not None else num_runs + stage = stage or self.cur export_stage = self.stages[self.stage_name(Export)] - (input_names, qp_input) = self._get_input_params(export_stage.artifact) - (output_name, qp_output) = self._get_output_param(export_stage.artifact) + is_quantized = self.stages["Quantize"] is not None + (input_names, qp_input) = _get_input_params(export_stage.artifact, is_quantized) + (output_name, qp_output) = _get_output_param( + export_stage.artifact, is_quantized + ) # Calculate the reference output using the original module or the quant - # module. self.quantization_scale is used by compare_outputs() to - # calculate the tolerance - self.quantization_scale = None if qp_output is None else qp_output.scale - if self.is_quantized: - module_for_ref = self.stages[self.stage_name(Quantize)].artifact + # module. + quantization_scale = None + if is_quantized: + quantization_scale = qp_output.scale + quantize_stage = self.stages[self.stage_name(Quantize)] + module_for_ref = quantize_stage.artifact + print(f"Comparing Stage {stage} with Stage {quantize_stage}") else: module_for_ref = self.original_module - self.reference_output = self._calculate_reference_output( - module_for_ref, inputs_to_run - ) - - # Transpose input data which is on NCHW format to NHWC format, - if self.permute_memory_to_nhwc and len(inputs_to_run[0].shape) == 4: - NHWC_Order = (0, 2, 3, 1) - inputs_to_run = (np.transpose(inputs_to_run[0], NHWC_Order),) - - # Run the TOSA ref model to get the output tensor, which will be - # compared to the torch output in compare_outputs() - tosa_output = self.tosa_test_util.run_tosa_ref_model( - params_input=(input_names, qp_input), - param_output=(output_name, qp_output), - inputs=inputs_to_run, - ) - - # Transpose back to NCHW format for comparison to torch output - if self.permute_memory_to_nhwc and len(tosa_output.shape) == 4: - NCHW_Order = (0, 3, 1, 2) - tosa_output = (np.transpose(tosa_output, NCHW_Order),) - - self.stage_output = tosa_output + print(f"Comparing Stage {stage} with original module") + + # Loop inputs and compare TOSA ref model output with Torch reference + # for each loop iteration. + for run_iteration in range(number_of_runs): + inputs_to_run = inputs if inputs else next(self.generate_random_inputs()) + input_shapes = [generated_input.shape for generated_input in inputs_to_run] + print(f"Run {run_iteration} with input shapes: {input_shapes}") + + # Get Torch reference data... + reference_output = self._calculate_reference_output( + module_for_ref, inputs_to_run + ) - return self + # ...now get TOSA ref model data + # Transpose input data which is on NCHW format to NHWC format, + is_nhwc = is_permute_memory(self.compile_spec) + if is_nhwc and len(inputs_to_run[0].shape) == 4: + NHWC_Order = (0, 2, 3, 1) + inputs_to_run = (np.transpose(inputs_to_run[0], NHWC_Order),) + + # Run the TOSA ref model to get the output tensor, which will be + # compared to the torch output in compare_outputs() + tosa_output = self.tosa_test_util.run_tosa_ref_model( + params_input=(input_names, qp_input), + param_output=(output_name, qp_output), + inputs=inputs_to_run, + ) - def _get_input_params( - self, program: ExportedProgram - ) -> Tuple[str, Union[List[QuantizationParams], List[None]]]: - """ - Get name and optionally quantization parameters for the inputs to this - model. + # Transpose back to NCHW format for comparison to torch output + if is_nhwc and len(tosa_output.shape) == 4: + NCHW_Order = (0, 3, 1, 2) + tosa_output = (np.transpose(tosa_output, NCHW_Order),) - Args: - program (ExportedProgram): The program to get input parameters from - Returns: - Tuple[str, Optional[QuantizationParams]]: A tuple containing the - input node names and their quantization parameters. - """ - input_names = [] - # E.g. bias and weights are 'placeholders' as well. This is used to - # get only the use inputs. - usr_inputs = program.graph_signature.user_inputs - for node in program.graph.nodes: - if node.op == "placeholder" and node.name in usr_inputs: - input_names.append(node.name) - continue - - if self.is_quantized: - quant_params = [] - for node in program.graph.nodes: - if ( - node.target - == torch.ops.quantized_decomposed.quantize_per_tensor.default - and node.args[0].name in input_names - ): - qp = QuantizationParams( - node_name=node.args[0].name, scale=node.args[1], zp=node.args[2] - ) - quant_params.append(qp) - if len(quant_params) == len( - input_names - ): # break early if we have all the inputs quantized parameters - break - assert len(quant_params) != 0, "Quantization paramerters not found" - return (input_names, quant_params) - else: - return (input_names, len(input_names) * [None]) # return a list of None's + stage_output = tosa_output - def _get_output_param( - self, program: ExportedProgram - ) -> Tuple[str, Union[QuantizationParams, None]]: - """ - Get name and optionally quantization parameters for the inputs to this - model. - - Args: - program (ExportedProgram): The program to get output parameters from. - Returns: - Tuple[str, Optional[QuantizationParams]]: A tuple containing the - output node name and its quantization parameters. - """ - output_node = None - for node in program.graph.nodes: - if node.op == "output": - output_node = node - break + # Output from running artifact at stage + self._compare_outputs( + reference_output, stage_output, quantization_scale, atol, rtol, qtol + ) - if self.is_quantized: - quant_params = None - for node in program.graph.nodes: - if ( - node.target - == torch.ops.quantized_decomposed.dequantize_per_tensor.default - and node == output_node.args[0][0] - ): - quant_params = QuantizationParams( - node_name=node.args[0].name, scale=node.args[1], zp=node.args[2] - ) - break # break early, there's only one output node - assert quant_params is not None, "Quantization paramerters not found" - return (output_node.name, quant_params) - else: - return (output_node.name, None) + return self @staticmethod def _calculate_reference_output( diff --git a/backends/arm/test/tosautil/tosa_test_utils.py b/backends/arm/test/tosautil/tosa_test_utils.py index 2ac9794fe89..df9b1455281 100644 --- a/backends/arm/test/tosautil/tosa_test_utils.py +++ b/backends/arm/test/tosautil/tosa_test_utils.py @@ -15,8 +15,6 @@ import numpy as np import torch -from executorch.backends.arm.test.test_models import TosaProfile - logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) @@ -39,15 +37,11 @@ def __init__(self, node_name: str, zp: int, scale: float): class TosaTestUtils: def __init__( self, - intermediate_path: Optional[str] = None, + intermediate_path: str, tosa_ref_model_path: Optional[str] = None, - profile: Optional[TosaProfile] = None, ): - self.intermediate_path = intermediate_path or tempfile.mkdtemp( - prefix="arm_tosa_" - ) + self.intermediate_path = intermediate_path self.tosa_ref_model_path = tosa_ref_model_path or "tosa_reference_model" - self.profile = profile or TosaProfile.MI assert os.path.exists( self.intermediate_path ), f"TOSA artifact path don't exist! Path: {self.intermediate_path}" @@ -145,7 +139,7 @@ def run_tosa_ref_model( as a next step. See: https://review.mlplatform.org/plugins/gitiles/tosa/reference_model/#executable-usage """ - + is_quantized = param_output[1] is not None desc_file_path = os.path.join(self.intermediate_path, "desc.json") assert os.path.exists( desc_file_path @@ -159,7 +153,7 @@ def run_tosa_ref_model( ): data_np = data.detach().numpy() - if self.profile is TosaProfile.BI: + if is_quantized: assert ( quant_param.node_name == input_name ), "These quantization params do not match the input tensor name" @@ -190,7 +184,7 @@ def run_tosa_ref_model( # Load the output file (OFM) and return it as a numpy array tosa_ref_output = np.load(ofm_file_npy) - if self.profile is TosaProfile.BI: + if is_quantized: # Need to dequant back to FP32 for comparison with torch output quant_param = param_output[1] assert (