Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@
TOSA_PRO_FP_SupportList,
TOSA_PRO_INT_SupportList,
)
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.arm.tosa.specification import Tosa_1_00
from executorch.backends.arm.tosa.specification import (
Tosa_1_00,
TosaSpecification,
TosaSpecMapping,
)
from executorch.exir import ExportedProgram
from executorch.exir.backend.utils import WhyNoPartitionReporter
from executorch.exir.dialects._ops import ops as exir_ops
Expand Down Expand Up @@ -116,10 +119,9 @@ def is_node_tosa_supported(


# container for all SupportedTosaOperatorCheck classes
_tosa_spec_support: dict[TosaSpecification, list[Type[SupportedTOSAOperatorCheck]]] = {
TosaSpecification.create_from_string("TOSA-1.0+INT"): [],
TosaSpecification.create_from_string("TOSA-1.0+FP"): [],
}
_tosa_spec_support: TosaSpecMapping[Type[SupportedTOSAOperatorCheck]] = (
TosaSpecMapping()
)


def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
Expand All @@ -134,7 +136,7 @@ def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):

"""
for tosa_spec in checker.tosa_specs:
_tosa_spec_support[tosa_spec].append(checker)
_tosa_spec_support.add(tosa_spec, checker)
return checker


Expand All @@ -150,12 +152,12 @@ def get_registered_tosa_support_checks(
list[Type[SupportedTOSAOperatorCheck]]: Registered checker classes.

"""
if tosa_spec not in _tosa_spec_support:
checks = _tosa_spec_support.get(tosa_spec)
if not checks:
raise RuntimeError(
f"TOSA specification not valid: {tosa_spec} not in {list(_tosa_spec_support.keys())}"
f"TOSA specification not valid: {tosa_spec} not in {list(_tosa_spec_support._mapping.keys())}"
)

return _tosa_spec_support[tosa_spec]
return checks


def tosa_support_factory(
Expand Down
39 changes: 30 additions & 9 deletions backends/arm/operators/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
"""

import json

import logging
from typing import Any, Dict, List, Optional

import torch
Expand All @@ -20,9 +22,14 @@
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
from executorch.backends.arm.debug.schema import DebugHook
from executorch.backends.arm.tosa.mapping import TosaArg
from executorch.backends.arm.tosa.specification import TosaSpecification
from executorch.backends.arm.tosa.specification import (
TosaSpecification,
TosaSpecMapping,
)
from torch.export import ExportedProgram

logger = logging.getLogger(__name__)


class NodeVisitor:
"""Provide a visitor pattern to lower edge IR to TOSA.
Expand Down Expand Up @@ -125,23 +132,31 @@ def define_node(


# container for all node visitors
_node_visitor_dicts: Dict[TosaSpecification, Dict] = {
TosaSpecification.create_from_string("TOSA-1.0+INT"): {},
TosaSpecification.create_from_string("TOSA-1.0+FP"): {},
}
_node_visitor_tuples: TosaSpecMapping[tuple] = TosaSpecMapping()


def register_node_visitor(visitor):
"""Register a concrete ``NodeVisitor`` class for its TOSA specs."""
for tosa_spec in visitor.tosa_specs:
_node_visitor_dicts[tosa_spec][visitor.target] = visitor
# Try to get the tuple to make sure it doesn't exist
visitor_tuple = (visitor.target, visitor)
try:
tuples = _node_visitor_tuples.get(tosa_spec)
except KeyError:
tuples = []

if visitor_tuple in tuples:
raise RuntimeError(
f"Visitor for target {visitor.target} already registered for TOSA spec {tosa_spec}"
)
_node_visitor_tuples.add(tosa_spec, visitor_tuple)
return visitor


def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
"""Return a mapping from target names to visitor instances for a spec."""
node_visitors = {}
tosa_spec = None
node_visitors: Dict[str, NodeVisitor] = {}
tosa_spec: TosaSpecification | None = None
for arg in args:
if isinstance(arg, TosaSpecification):
tosa_spec = arg
Expand All @@ -150,7 +165,13 @@ def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
if tosa_spec is None:
raise RuntimeError("No TOSA specification supplied.")

for target, visitor in _node_visitor_dicts[tosa_spec].items():
# Use the mapping to get the dict for this spec (handles combined specs)
for node_visitor_tuple in _node_visitor_tuples.get(tosa_spec):
target, visitor = node_visitor_tuple
if target in node_visitors and node_visitors[target].__class__ != visitor:
logger.warning(
f"Target {target} already has visitor class {node_visitors[target].__class__.__name__} registered, overwriting with class: {visitor.__name__}"
)
node_visitors[target] = visitor(*args)

return node_visitors
1 change: 0 additions & 1 deletion backends/arm/operators/op_index_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from torch.fx import Node


@register_node_visitor
class CommonIndexTensorVisitor(NodeVisitor):
target = "aten.index.Tensor"

Expand Down
103 changes: 102 additions & 1 deletion backends/arm/test/misc/test_tosa_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

import unittest

from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification
from executorch.backends.arm.tosa.specification import (
Tosa_1_00,
TosaSpecification,
TosaSpecMapping,
)

from parameterized import parameterized # type: ignore[import-untyped]

Expand Down Expand Up @@ -66,3 +70,100 @@ def test_correct_string_representation(self, version_string: str):
tosa_spec = TosaSpecification.create_from_string(version_string)
assert isinstance(tosa_spec, Tosa_1_00)
assert f"{tosa_spec}" == version_string


class TestTosaSpecMapping(unittest.TestCase):
"""Tests the TosaSpecMapping class"""

def test_mapping(self):
mapping = TosaSpecMapping()
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A")
# check that the mapping is correct
vals = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT"))

assert vals == ["A"]
assert len(vals) == 1

def test_mapping_multiple(self):
mapping = TosaSpecMapping()
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A")
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "B")
# check that the mapping is correct
vals = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT"))

assert vals == ["A", "B"]
assert len(vals) == 2

def test_mapping_different_profiles(self):
mapping = TosaSpecMapping()
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A")
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "B")
# check that the mapping is correct
vals_int = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT"))
vals_fp = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+FP"))

assert vals_int == ["A"]
assert vals_fp == ["B"]
assert len(vals_int) == 1
assert len(vals_fp) == 1

def test_mapping_different_profiles_combined_consumer(self):
mapping = TosaSpecMapping()
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A")
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "B")
# check that the mapping is correct
combined_vals = mapping.get(
TosaSpecification.create_from_string("TOSA-1.0+INT+FP")
)

assert "A" in combined_vals
assert "B" in combined_vals
assert len(combined_vals) == 2

def test_mapping_no_spec(self):
mapping = TosaSpecMapping()
with self.assertRaises(KeyError):
mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT"))

def test_mapping_no_values_for_spec(self):
mapping = TosaSpecMapping()
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "A")
with self.assertRaises(KeyError):
mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT"))

def test_spec_with_different_profiles(self):
mapping = TosaSpecMapping()
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "A")
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "B")
# check that the mapping is correct
vals_int = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT"))
vals_fp = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+FP"))
vals_int_fp = mapping.get(
TosaSpecification.create_from_string("TOSA-1.0+INT+FP")
)

assert vals_fp == ["A"]
assert vals_int == ["B"]
assert len(vals_int) == 1
assert len(vals_fp) == 1
assert len(vals_int_fp) == 2

def test_combined_profiles(self):
mapping = TosaSpecMapping()
with self.assertRaises(ValueError):
# Don't allow multiple profiles in a single spec
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT+FP"), "A")

def test_spec_add_with_extension(self):
mapping = TosaSpecMapping()
with self.assertRaises(ValueError):
mapping.add(
TosaSpecification.create_from_string("TOSA-1.0.0+INT+int16"), "A"
)

def test_spec_non_canonical_key(self):
mapping = TosaSpecMapping()
mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A")

val = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT+u55"))
assert val == ["A"]
81 changes: 80 additions & 1 deletion backends/arm/tosa/specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,71 @@

import contextvars
import re
from typing import List
from typing import Dict, Generic, List, Set, TypeVar

from packaging.version import Version

T = TypeVar("T")


class TosaSpecMapping(Generic[T]):
def __init__(self):
self._mapping: Dict[TosaSpecification, List[T]] = {}

def add(self, spec: "TosaSpecification", value: T) -> None:
"""
Adds a value to the mapping for the given TOSA specification.
The specification is normalized to its canonical form, which means that
only the version and profiles are considered, without extensions.
This allows for grouping of values under the same TOSA specification
regardless of the extensions they may have.
"""

if spec.is_U55_subset or spec.extensions:
raise ValueError(
f"TosaSpecMapping does not support extensions, got: {spec}"
)

if isinstance(spec, Tosa_1_00) and len(spec.profiles) > 1:
raise ValueError(
f"TosaSpecMapping does not support multiple profiles, got: {spec}"
)

norm_spec = spec._canonical_key()
if norm_spec not in self._mapping:
self._mapping[norm_spec] = []
self._mapping[norm_spec].append(value)

@staticmethod
def _get_base_specs(spec: "TosaSpecification") -> List["TosaSpecification"]:
# Handles combined TOSA-1.0+FP+INT, etc.
if isinstance(spec, Tosa_1_00):
profiles: Set[str] = set(spec.profiles)
if profiles == {"FP", "INT"}:
version = spec.version
return [
TosaSpecification.create_from_string(f"TOSA-{version}+FP"),
TosaSpecification.create_from_string(f"TOSA-{version}+INT"),
]
return [spec]

def get(self, spec: "TosaSpecification") -> List[T]:
"""
Returns a list of values associated with the given TOSA specification.
The specification is normalized to its canonical form, which means that
only the version and profiles are considered, without extensions.
"""

base_specs = self._get_base_specs(spec)
result: List[T] = []
for base in base_specs:
norm_base = base._canonical_key()
result.extend(self._mapping.get(norm_base, []))
if len(result) == 0:
raise KeyError(f"No values found for TOSA specification: {spec}")

return result # Do not deduplicate with set(), as values may be unhashable


class TosaSpecification:
"""Represent a TOSA specification.
Expand All @@ -34,6 +95,7 @@ class TosaSpecification:

version: Version
is_U55_subset: bool
extensions: List[str]

def support_integer(self) -> bool:
"""Return True if integer operations are supported."""
Expand All @@ -52,6 +114,7 @@ def __init__(self, version: Version, extras: List[str]):

"""
self.version = version
self.extensions = []

self.is_U55_subset = "u55" in extras
if self.is_U55_subset:
Expand Down Expand Up @@ -89,6 +152,12 @@ def create_from_string(repr: str) -> "TosaSpecification":

raise ValueError(f"Failed to parse TOSA specification representation: {repr}")

def _canonical_key(self) -> "TosaSpecification":
"""
Returns a new TosaSpecification instance with only version and profiles (no extensions).
"""
raise NotImplementedError


class Tosa_1_00(TosaSpecification):
"""Provide TOSA 1.00 profile and extension semantics.
Expand Down Expand Up @@ -232,6 +301,16 @@ def support_extension(self, extension: str) -> bool:

return False

def _canonical_key(self) -> "Tosa_1_00":
"""
Returns a new Tosa_1_00 instance with only major.minor version and profiles (no extensions).
Patch version is set to zero for normalization.
"""
from packaging.version import Version

norm_version = Version(f"{self.version.major}.{self.version.minor}.0")
return Tosa_1_00(norm_version, self.profiles.copy())


class TosaLoweringContext:
"""Manage the TOSA specification context for lowering.
Expand Down
Loading