From 65126f34639b38f745a6a3b4eb327cc3141908fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Thu, 22 May 2025 13:30:08 +0200 Subject: [PATCH 1/2] Arm backend: Add a support class for handling maping of TOSA spec MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce a support class for mapping TOSA specs to values in order to handle combinations of +INT and +FP profiles. As an intemittent step allow for registration of same class for same target, but for different TOSA specification and emits a warning. Signed-off-by: Per Åstrand Change-Id: Ic379725af3b8f32dc4849e4771017a5e05a6fe6f --- .../tosa_supported_operators.py | 24 +++--- backends/arm/operators/node_visitor.py | 39 ++++++--- backends/arm/operators/op_index_tensor.py | 1 - backends/arm/tosa/specification.py | 81 ++++++++++++++++++- 4 files changed, 123 insertions(+), 22 deletions(-) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 737c71a7039..e4050f1dc49 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -39,8 +39,11 @@ TOSA_PRO_FP_SupportList, TOSA_PRO_INT_SupportList, ) -from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.arm.tosa.specification import Tosa_1_00 +from executorch.backends.arm.tosa.specification import ( + Tosa_1_00, + TosaSpecification, + TosaSpecMapping, +) from executorch.exir import ExportedProgram from executorch.exir.backend.utils import WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops @@ -116,10 +119,9 @@ def is_node_tosa_supported( # container for all SupportedTosaOperatorCheck classes -_tosa_spec_support: dict[TosaSpecification, list[Type[SupportedTOSAOperatorCheck]]] = { - TosaSpecification.create_from_string("TOSA-1.0+INT"): [], - TosaSpecification.create_from_string("TOSA-1.0+FP"): [], -} +_tosa_spec_support: TosaSpecMapping[Type[SupportedTOSAOperatorCheck]] = ( + TosaSpecMapping() +) def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]): @@ -134,7 +136,7 @@ def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]): """ for tosa_spec in checker.tosa_specs: - _tosa_spec_support[tosa_spec].append(checker) + _tosa_spec_support.add(tosa_spec, checker) return checker @@ -150,12 +152,12 @@ def get_registered_tosa_support_checks( list[Type[SupportedTOSAOperatorCheck]]: Registered checker classes. """ - if tosa_spec not in _tosa_spec_support: + checks = _tosa_spec_support.get(tosa_spec) + if not checks: raise RuntimeError( - f"TOSA specification not valid: {tosa_spec} not in {list(_tosa_spec_support.keys())}" + f"TOSA specification not valid: {tosa_spec} not in {list(_tosa_spec_support._mapping.keys())}" ) - - return _tosa_spec_support[tosa_spec] + return checks def tosa_support_factory( diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index c03c27574b8..c54ae67e541 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -12,6 +12,8 @@ """ import json + +import logging from typing import Any, Dict, List, Optional import torch @@ -20,9 +22,14 @@ from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.debug.schema import DebugHook from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.specification import TosaSpecification +from executorch.backends.arm.tosa.specification import ( + TosaSpecification, + TosaSpecMapping, +) from torch.export import ExportedProgram +logger = logging.getLogger(__name__) + class NodeVisitor: """Provide a visitor pattern to lower edge IR to TOSA. @@ -125,23 +132,31 @@ def define_node( # container for all node visitors -_node_visitor_dicts: Dict[TosaSpecification, Dict] = { - TosaSpecification.create_from_string("TOSA-1.0+INT"): {}, - TosaSpecification.create_from_string("TOSA-1.0+FP"): {}, -} +_node_visitor_tuples: TosaSpecMapping[tuple] = TosaSpecMapping() def register_node_visitor(visitor): """Register a concrete ``NodeVisitor`` class for its TOSA specs.""" for tosa_spec in visitor.tosa_specs: - _node_visitor_dicts[tosa_spec][visitor.target] = visitor + # Try to get the tuple to make sure it doesn't exist + visitor_tuple = (visitor.target, visitor) + try: + tuples = _node_visitor_tuples.get(tosa_spec) + except KeyError: + tuples = [] + + if visitor_tuple in tuples: + raise RuntimeError( + f"Visitor for target {visitor.target} already registered for TOSA spec {tosa_spec}" + ) + _node_visitor_tuples.add(tosa_spec, visitor_tuple) return visitor def get_node_visitors(*args) -> Dict[str, NodeVisitor]: """Return a mapping from target names to visitor instances for a spec.""" - node_visitors = {} - tosa_spec = None + node_visitors: Dict[str, NodeVisitor] = {} + tosa_spec: TosaSpecification | None = None for arg in args: if isinstance(arg, TosaSpecification): tosa_spec = arg @@ -150,7 +165,13 @@ def get_node_visitors(*args) -> Dict[str, NodeVisitor]: if tosa_spec is None: raise RuntimeError("No TOSA specification supplied.") - for target, visitor in _node_visitor_dicts[tosa_spec].items(): + # Use the mapping to get the dict for this spec (handles combined specs) + for node_visitor_tuple in _node_visitor_tuples.get(tosa_spec): + target, visitor = node_visitor_tuple + if target in node_visitors and node_visitors[target].__class__ != visitor: + logger.warning( + f"Target {target} already has visitor class {node_visitors[target].__class__.__name__} registered, overwriting with class: {visitor.__name__}" + ) node_visitors[target] = visitor(*args) return node_visitors diff --git a/backends/arm/operators/op_index_tensor.py b/backends/arm/operators/op_index_tensor.py index b2adb785ef6..710b5f8e1d8 100644 --- a/backends/arm/operators/op_index_tensor.py +++ b/backends/arm/operators/op_index_tensor.py @@ -24,7 +24,6 @@ from torch.fx import Node -@register_node_visitor class CommonIndexTensorVisitor(NodeVisitor): target = "aten.index.Tensor" diff --git a/backends/arm/tosa/specification.py b/backends/arm/tosa/specification.py index 7afa7d9f0de..6fca2163d41 100644 --- a/backends/arm/tosa/specification.py +++ b/backends/arm/tosa/specification.py @@ -12,10 +12,71 @@ import contextvars import re -from typing import List +from typing import Dict, Generic, List, Set, TypeVar from packaging.version import Version +T = TypeVar("T") + + +class TosaSpecMapping(Generic[T]): + def __init__(self): + self._mapping: Dict[TosaSpecification, List[T]] = {} + + def add(self, spec: "TosaSpecification", value: T) -> None: + """ + Adds a value to the mapping for the given TOSA specification. + The specification is normalized to its canonical form, which means that + only the version and profiles are considered, without extensions. + This allows for grouping of values under the same TOSA specification + regardless of the extensions they may have. + """ + + if spec.is_U55_subset or spec.extensions: + raise ValueError( + f"TosaSpecMapping does not support extensions, got: {spec}" + ) + + if isinstance(spec, Tosa_1_00) and len(spec.profiles) > 1: + raise ValueError( + f"TosaSpecMapping does not support multiple profiles, got: {spec}" + ) + + norm_spec = spec._canonical_key() + if norm_spec not in self._mapping: + self._mapping[norm_spec] = [] + self._mapping[norm_spec].append(value) + + @staticmethod + def _get_base_specs(spec: "TosaSpecification") -> List["TosaSpecification"]: + # Handles combined TOSA-1.0+FP+INT, etc. + if isinstance(spec, Tosa_1_00): + profiles: Set[str] = set(spec.profiles) + if profiles == {"FP", "INT"}: + version = spec.version + return [ + TosaSpecification.create_from_string(f"TOSA-{version}+FP"), + TosaSpecification.create_from_string(f"TOSA-{version}+INT"), + ] + return [spec] + + def get(self, spec: "TosaSpecification") -> List[T]: + """ + Returns a list of values associated with the given TOSA specification. + The specification is normalized to its canonical form, which means that + only the version and profiles are considered, without extensions. + """ + + base_specs = self._get_base_specs(spec) + result: List[T] = [] + for base in base_specs: + norm_base = base._canonical_key() + result.extend(self._mapping.get(norm_base, [])) + if len(result) == 0: + raise KeyError(f"No values found for TOSA specification: {spec}") + + return result # Do not deduplicate with set(), as values may be unhashable + class TosaSpecification: """Represent a TOSA specification. @@ -34,6 +95,7 @@ class TosaSpecification: version: Version is_U55_subset: bool + extensions: List[str] def support_integer(self) -> bool: """Return True if integer operations are supported.""" @@ -52,6 +114,7 @@ def __init__(self, version: Version, extras: List[str]): """ self.version = version + self.extensions = [] self.is_U55_subset = "u55" in extras if self.is_U55_subset: @@ -89,6 +152,12 @@ def create_from_string(repr: str) -> "TosaSpecification": raise ValueError(f"Failed to parse TOSA specification representation: {repr}") + def _canonical_key(self) -> "TosaSpecification": + """ + Returns a new TosaSpecification instance with only version and profiles (no extensions). + """ + raise NotImplementedError + class Tosa_1_00(TosaSpecification): """Provide TOSA 1.00 profile and extension semantics. @@ -232,6 +301,16 @@ def support_extension(self, extension: str) -> bool: return False + def _canonical_key(self) -> "Tosa_1_00": + """ + Returns a new Tosa_1_00 instance with only major.minor version and profiles (no extensions). + Patch version is set to zero for normalization. + """ + from packaging.version import Version + + norm_version = Version(f"{self.version.major}.{self.version.minor}.0") + return Tosa_1_00(norm_version, self.profiles.copy()) + class TosaLoweringContext: """Manage the TOSA specification context for lowering. From 8131ea884a2d2488e575bc4e9a6cf744e8fa59a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Thu, 9 Oct 2025 12:02:01 +0200 Subject: [PATCH 2/2] Arm backend: Add testcases for TosaSpecMapping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Per Åstrand Change-Id: If7a868ec4a1d10c6d7c9307eccc25b56e2d84c5c --- backends/arm/test/misc/test_tosa_spec.py | 103 ++++++++++++++++++++++- 1 file changed, 102 insertions(+), 1 deletion(-) diff --git a/backends/arm/test/misc/test_tosa_spec.py b/backends/arm/test/misc/test_tosa_spec.py index 190c50f4aa1..91a5bc19728 100644 --- a/backends/arm/test/misc/test_tosa_spec.py +++ b/backends/arm/test/misc/test_tosa_spec.py @@ -5,7 +5,11 @@ import unittest -from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification +from executorch.backends.arm.tosa.specification import ( + Tosa_1_00, + TosaSpecification, + TosaSpecMapping, +) from parameterized import parameterized # type: ignore[import-untyped] @@ -66,3 +70,100 @@ def test_correct_string_representation(self, version_string: str): tosa_spec = TosaSpecification.create_from_string(version_string) assert isinstance(tosa_spec, Tosa_1_00) assert f"{tosa_spec}" == version_string + + +class TestTosaSpecMapping(unittest.TestCase): + """Tests the TosaSpecMapping class""" + + def test_mapping(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A") + # check that the mapping is correct + vals = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT")) + + assert vals == ["A"] + assert len(vals) == 1 + + def test_mapping_multiple(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A") + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "B") + # check that the mapping is correct + vals = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT")) + + assert vals == ["A", "B"] + assert len(vals) == 2 + + def test_mapping_different_profiles(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A") + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "B") + # check that the mapping is correct + vals_int = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT")) + vals_fp = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+FP")) + + assert vals_int == ["A"] + assert vals_fp == ["B"] + assert len(vals_int) == 1 + assert len(vals_fp) == 1 + + def test_mapping_different_profiles_combined_consumer(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A") + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "B") + # check that the mapping is correct + combined_vals = mapping.get( + TosaSpecification.create_from_string("TOSA-1.0+INT+FP") + ) + + assert "A" in combined_vals + assert "B" in combined_vals + assert len(combined_vals) == 2 + + def test_mapping_no_spec(self): + mapping = TosaSpecMapping() + with self.assertRaises(KeyError): + mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT")) + + def test_mapping_no_values_for_spec(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "A") + with self.assertRaises(KeyError): + mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT")) + + def test_spec_with_different_profiles(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "A") + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "B") + # check that the mapping is correct + vals_int = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT")) + vals_fp = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+FP")) + vals_int_fp = mapping.get( + TosaSpecification.create_from_string("TOSA-1.0+INT+FP") + ) + + assert vals_fp == ["A"] + assert vals_int == ["B"] + assert len(vals_int) == 1 + assert len(vals_fp) == 1 + assert len(vals_int_fp) == 2 + + def test_combined_profiles(self): + mapping = TosaSpecMapping() + with self.assertRaises(ValueError): + # Don't allow multiple profiles in a single spec + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT+FP"), "A") + + def test_spec_add_with_extension(self): + mapping = TosaSpecMapping() + with self.assertRaises(ValueError): + mapping.add( + TosaSpecification.create_from_string("TOSA-1.0.0+INT+int16"), "A" + ) + + def test_spec_non_canonical_key(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A") + + val = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT+u55")) + assert val == ["A"]