xref: /aosp_15_r20/external/executorch/backends/arm/tosa_specification.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7
8#
9# Main implementation of AoT flow to partition and preprocess for Arm target
10# backends. Converts via TOSA as an intermediate form supported by AoT and
11# JIT compiler flows.
12#
13
14import re
15from typing import List
16
17from executorch.exir.backend.compile_spec_schema import CompileSpec
18from packaging.version import Version
19
20
class TosaSpecification:
    """
    This class implements a representation of TOSA specification
    (https://www.mlplatform.org/tosa/tosa_spec.html) with a version, a profile
    (with extension) and a level (8k).
    For 0.80 releases the profile is BI or MI, with u55 handled as an unofficial extension
    For 1.00 releases the profile is INT or FP, and the extensions are for
        INT: int16, int4, var, cf
        FP: bf16, fp8e4m3, fp8e5m2, fft, var, cf

    The TOSA specification is encoded in the string representation
        TOSA-major.minor.patch+profile[+level][+extensions]

    For 0.80 MI implies BI, while for 1.0 the profiles have to be specified
    explicitly.

    Profiles are uppercase letters; extensions and the level are lowercase.

    This base class only stores the version. The capability queries
    (support_integer / support_float) raise NotImplementedError here and are
    expected to be overridden by the version specific subclasses.
    """

    version: Version

    def __init__(self, version: Version):
        """
        Store the parsed TOSA version on the instance.
        """
        self.version = version

    def support_integer(self) -> bool:
        """
        Returns true if any integer operations are supported for the specification.
        """
        raise NotImplementedError

    def support_float(self) -> bool:
        """
        Returns true if any float operations are supported for the specification.
        """
        raise NotImplementedError

    @staticmethod
    def create_from_compilespecs(
        compile_specs: List[CompileSpec],
    ) -> "TosaSpecification":
        """
        Search the CompileSpec list for the first entry whose key is
        'tosa_version' and instantiate a specification from its value.

        The value is decoded from bytes with the default codec and handed to
        create_from_string().

        Raises:
            ValueError: if no entry with the key 'tosa_version' exists, or if
                the found value cannot be turned into a specification.
        """
        for spec in compile_specs:
            if spec.key == "tosa_version":
                return TosaSpecification.create_from_string(spec.value.decode())
        raise ValueError(
            "No TOSA version key found in any of the supplied CompileSpecs"
        )

    @staticmethod
    def create_from_string(repr: str) -> "TosaSpecification":
        """
        Creates a TOSA specification class from a string representation:
        TOSA-0.80.0+MI
        TOSA-0.80.0+BI+8k
        TOSA-0.80.0+BI+u55   # Ethos-U55 extension to handle TOSA subset
        TOSA-1.00.0+INT+FP+int4+cf

        Only 0.80.x and 1.0.x versions are accepted; any other version (for
        example TOSA-0.90.0+MI) is rejected.

        Raises:
            ValueError: if the string does not have the form
                'TOSA-<version>+<extras>', if the version is not supported,
                or if the selected specification class rejects the extras.
        """
        # The literal "TOSA" prefix is enforced by the pattern itself, so no
        # separate check of the name is needed after a successful match.
        pattern = r"^(TOSA)-([\d.]+)\+(.+)$"
        parsed = re.match(pattern, repr)
        if parsed is None:
            raise ValueError(
                f"Failed to parse TOSA specification representation: {repr}"
            )

        version = Version(parsed.group(2))
        # Everything after the version: profiles, level and extensions,
        # separated by '+'. The subclasses sort out which is which.
        extras = parsed.group(3).split("+")

        if version.major == 0 and version.minor == 80:
            return Tosa_0_80(version, extras)
        if version.major == 1 and version.minor == 0:
            return Tosa_1_00(version, extras)
        raise ValueError(f"Wrong TOSA version: {version} from {repr}")
99
100
class Tosa_0_80(TosaSpecification):
    """
    TOSA 0.80.x specification: exactly one profile (BI or MI), an optional
    '8k' level and an optional 'u55' subset marker.
    """

    profile: str
    level_8k: bool
    is_U55_subset: bool
    available_profiles = ["BI", "MI"]  # MT is not defined

    def __init__(self, version: Version, extras: List[str]):
        """
        Build the specification from a version and the '+'-separated extras.

        Args:
            version: must satisfy 0.80 <= version < 0.90.
            extras: exactly one profile from available_profiles, optionally
                followed by '8k' and/or 'u55'. The list is not modified.

        Raises:
            ValueError: if the number of profiles in extras is not exactly
                one, or if extras contains anything besides the profile,
                '8k' and 'u55'.
        """
        super().__init__(version)
        assert Version("0.80") <= version < Version("0.90")

        # Work on a copy so the caller's list is left untouched.
        remaining = list(extras)

        # Exactly one profile must be present among the extras.
        found_profiles = [e for e in remaining if e in Tosa_0_80.available_profiles]
        if len(found_profiles) != 1:
            raise ValueError(
                f"Bad combination of extras: {extras}, expected exactly one of "
                f"{Tosa_0_80.available_profiles} but found {len(found_profiles)}."
            )

        self.profile = found_profiles[0]
        remaining.remove(self.profile)

        self.level_8k = "8k" in remaining
        if self.level_8k:
            remaining.remove("8k")
        self.is_U55_subset = "u55" in remaining
        if self.is_U55_subset:
            remaining.remove("u55")

        # Anything still left over is not understood for this version.
        if remaining:
            raise ValueError(f"Unhandled extras found: {remaining}")

    def __repr__(self):
        """
        Return the string form 'TOSA-<version>+<profile>[+8k][+u55]'.

        The level is written in lowercase, matching what the constructor
        accepts, so that the output can be parsed again.
        """
        extensions = ""
        if self.level_8k:
            extensions += "+8k"
        if self.is_U55_subset:
            extensions += "+u55"
        return f"TOSA-{str(self.version)}+{self.profile}{extensions}"

    # NOTE: hashing and equality only consider version and profile; the 8k
    # level and the u55 marker do not take part in the comparison.
    def __hash__(self) -> int:
        return hash(str(self.version) + self.profile)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Tosa_0_80):
            return (self.version == other.version) and (self.profile == other.profile)
        return False

    def support_integer(self):
        """Integer operations are available for both BI and MI."""
        return True

    def support_float(self):
        """Float operations are only available with the MI profile."""
        return self.profile == "MI"
152
153
class Tosa_1_00(TosaSpecification):
    """
    TOSA 1.0.x specification: one or more profiles (INT, FP), an optional
    '8k' level and any extensions that are valid for the selected profiles.
    """

    profiles: List[str]
    level_8k: bool
    extensions: List[str]

    available_profiles = ["INT", "FP"]
    valid_extensions = {
        "INT": ["int16", "int4", "var", "cf"],
        "FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"],
    }

    def __init__(self, version: Version, extras: List[str]):
        """
        Build the specification from a version and the '+'-separated extras.

        Args:
            version: the parsed TOSA version.
            extras: profiles, optional '8k' level and extensions in any
                order. The list is not modified.

        Raises:
            ValueError: if no profile is given, if more profile entries are
                given than there are available profiles, or if an extension
                is not valid for any of the selected profiles.
        """
        super().__init__(version)

        # Profiles keep the order in which they appear in extras.
        profiles = [e for e in extras if e in Tosa_1_00.available_profiles]

        # Check that we have at least one profile in the extras list
        if not profiles:
            raise ValueError(
                f"No profile ({Tosa_1_00.available_profiles}) found in: {extras}."
            )

        # and not more than number of available profiles
        if len(profiles) > len(Tosa_1_00.available_profiles):
            raise ValueError(
                f"Too many profiles ({Tosa_1_00.available_profiles}) found in: {extras}."
            )

        self.profiles = profiles

        # Build a fresh list without the profiles so the caller's list is
        # neither mutated nor aliased by self.extensions.
        remaining = [e for e in extras if e not in Tosa_1_00.available_profiles]

        self.level_8k = "8k" in remaining
        if self.level_8k:
            remaining.remove("8k")

        # An extension is accepted if any of the selected profiles allows it.
        combined_extensions = []
        for p in self.profiles:
            combined_extensions += Tosa_1_00.valid_extensions[p]

        if not all(e in combined_extensions for e in remaining):
            raise ValueError(
                f"Bad extensions for TOSA-{version}{self._get_profiles_string()}: {remaining}"
            )

        # all the rest of the extras are handled extensions
        self.extensions = remaining

    def _get_profiles_string(self) -> str:
        """Return the profiles as '+P1+P2...' (empty string if none)."""
        return "".join(["+" + p for p in self.profiles])

    def _get_extensions_string(self) -> str:
        """Return the extensions as '+e1+e2...' (empty string if none)."""
        return "".join(["+" + e for e in self.extensions])

    def __repr__(self):
        """
        Return the string form 'TOSA-<version>+<profiles>[+8k][+extensions]'.
        """
        level = "+8k" if self.level_8k else ""
        return (
            f"TOSA-{self.version}{self._get_profiles_string()}"
            f"{level}{self._get_extensions_string()}"
        )

    # NOTE: hashing and equality only consider the version and the profiles
    # (in the order given); level and extensions do not take part.
    def __hash__(self) -> int:
        return hash(str(self.version) + self._get_profiles_string())

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Tosa_1_00):
            return (self.version == other.version) and (
                self._get_profiles_string() == other._get_profiles_string()
            )
        return False

    def support_integer(self):
        """Integer operations are available when the INT profile is selected."""
        return "INT" in self.profiles

    def support_float(self):
        """Float operations are available when the FP profile is selected."""
        return "FP" in self.profiles
227