xref: /aosp_15_r20/external/executorch/backends/arm/tosa_specification.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright 2024 Arm Limited and/or its affiliates.
2*523fa7a6SAndroid Build Coastguard Worker#
3*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
4*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
5*523fa7a6SAndroid Build Coastguard Worker
6*523fa7a6SAndroid Build Coastguard Worker# pyre-unsafe
7*523fa7a6SAndroid Build Coastguard Worker
8*523fa7a6SAndroid Build Coastguard Worker#
9*523fa7a6SAndroid Build Coastguard Worker# Main implementation of AoT flow to partition and preprocess for Arm target
10*523fa7a6SAndroid Build Coastguard Worker# backends. Converts via TOSA as an intermediate form supported by AoT and
11*523fa7a6SAndroid Build Coastguard Worker# JIT compiler flows.
12*523fa7a6SAndroid Build Coastguard Worker#
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Workerimport re
15*523fa7a6SAndroid Build Coastguard Workerfrom typing import List
16*523fa7a6SAndroid Build Coastguard Worker
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.compile_spec_schema import CompileSpec
18*523fa7a6SAndroid Build Coastguard Workerfrom packaging.version import Version
19*523fa7a6SAndroid Build Coastguard Worker
20*523fa7a6SAndroid Build Coastguard Worker
class TosaSpecification:
    """
    This class implements a representation of the TOSA specification
    (https://www.mlplatform.org/tosa/tosa_spec.html) with a version, a profile
    (with extensions) and a level (8k).
    For 0.80 releases the profile is BI or MI, with u55 handled as an unofficial extension.
    For 1.00 releases the profile is INT or FP, and the extensions are for
        INT: int16, int4, var, cf
        FP: bf16, fp8e4m3, fp8e5m2, fft, var, cf

    The TOSA specification is encoded in the string representation
        TOSA-major.minor.patch+profile[+level][+extensions]

    For 0.80, MI implies BI, while for 1.0 the profiles have to be explicitly specified.

    Profiles are uppercase letters, while extensions and the level are lowercase.

    Only 0.80.x and 1.0.x versions are accepted by create_from_string; they are
    handled by the concrete subclasses Tosa_0_80 and Tosa_1_00.
    """

    version: Version

    def __init__(self, version: Version):
        self.version = version

    def support_integer(self) -> bool:
        """
        Returns true if any integer operations are supported for the specification.

        Must be overridden by the version-specific subclasses.
        """
        raise NotImplementedError

    def support_float(self) -> bool:
        """
        Returns true if any float operations are supported for the specification.

        Must be overridden by the version-specific subclasses.
        """
        raise NotImplementedError

    @staticmethod
    def create_from_compilespecs(
        compile_specs: List[CompileSpec],
    ) -> "TosaSpecification":
        """
        Search the CompileSpec list for the 'tosa_version' key and instantiate a
        specification from the first matching entry.

        Args:
            compile_specs: Compile specs to search. The value of the
                'tosa_version' entry is converted with ``.decode()`` and then
                parsed by ``create_from_string``.

        Returns:
            The specification parsed from the first 'tosa_version' entry.

        Raises:
            ValueError: If no 'tosa_version' key is present, or if its value
                cannot be parsed.
        """
        for spec in compile_specs:
            if spec.key == "tosa_version":
                return TosaSpecification.create_from_string(spec.value.decode())
        raise ValueError(
            "No TOSA version key found in any of the supplied CompileSpecs"
        )

    @staticmethod
    def create_from_string(repr: str) -> "TosaSpecification":
        """
        Creates a TOSA specification object from a string representation, e.g.:
        TOSA-0.80.0+MI
        TOSA-0.80.0+BI+8k
        TOSA-0.80.0+BI+u55   # Ethos-U55 extension to handle TOSA subset
        TOSA-1.00.0+INT+FP+int4+cf

        Args:
            repr: The specification string. (The name shadows the builtin
                ``repr``; it is kept so keyword callers keep working.)

        Returns:
            A Tosa_0_80 instance for 0.80.x versions, or a Tosa_1_00 instance
            for 1.0.x versions.

        Raises:
            ValueError: If the string does not have the TOSA-<version>+<extras>
                form, if the version is not 0.80.x or 1.0.x (e.g. 0.90), or if
                the subclass rejects the profile/extension list.
        """
        # The literal "TOSA" prefix is enforced by the pattern itself.
        parsed = re.match(r"^TOSA-([\d.]+)\+(.+)$", repr)
        if not parsed:
            raise ValueError(
                f"Failed to parse TOSA specification representation: {repr}"
            )

        version = Version(parsed.group(1))
        extras = parsed.group(2).split("+")

        if version.major == 0 and version.minor == 80:
            return Tosa_0_80(version, extras)
        if version.major == 1 and version.minor == 0:
            return Tosa_1_00(version, extras)
        raise ValueError(f"Wrong TOSA version: {version} from {repr}")
99*523fa7a6SAndroid Build Coastguard Worker
100*523fa7a6SAndroid Build Coastguard Worker
class Tosa_0_80(TosaSpecification):
    """
    TOSA 0.80 specification.

    Exactly one profile (BI or MI) must be given, optionally combined with the
    8k level and the unofficial u55 extension (Ethos-U55 TOSA subset).
    """

    profile: str
    level_8k: bool
    is_U55_subset: bool
    available_profiles = ["BI", "MI"]  # MT is not defined

    def __init__(self, version: Version, extras: List[str]):
        """
        Args:
            version: The specification version; must satisfy
                0.80 <= version < 0.90.
            extras: The '+'-separated tokens following the version, for
                example ["BI", "8k", "u55"]. The caller's list is not modified.

        Raises:
            ValueError: If the version is outside [0.80, 0.90), if the extras
                do not contain exactly one profile, or if any unrecognised
                extra remains.
        """
        super().__init__(version)
        # Raise rather than assert so the check survives `python -O`.
        if not (Version("0.80") <= version < Version("0.90")):
            raise ValueError(f"Version {version} is not a TOSA 0.80 version.")

        # Work on a copy: tokens are consumed below and the caller's list
        # must not be mutated as a side effect.
        extras = list(extras)

        profiles = [e for e in extras if e in Tosa_0_80.available_profiles]
        if not profiles:
            raise ValueError(
                f"No profile ({Tosa_0_80.available_profiles}) found in: {extras}."
            )
        if len(profiles) > 1:
            raise ValueError(
                f"Bad combination of extras: {extras}, more than one of {Tosa_0_80.available_profiles} found."
            )

        # Exactly one profile is present, so pick it.
        self.profile = profiles[0]
        extras.remove(self.profile)

        self.level_8k = "8k" in extras
        if self.level_8k:
            extras.remove("8k")
        self.is_U55_subset = "u55" in extras
        if self.is_U55_subset:
            extras.remove("u55")

        # Anything left over is not a recognised 0.80 token.
        if extras:
            raise ValueError(f"Unhandled extras found: {extras}")

    def __repr__(self) -> str:
        """
        Return the string form, e.g. "TOSA-0.80.0+BI+8k+u55".

        Level and extension tokens are emitted in lowercase, exactly as
        __init__ accepts them, so the output can be parsed back.
        """
        extensions = ""
        if self.level_8k:
            extensions += "+8k"
        if self.is_U55_subset:
            extensions += "+u55"
        return f"TOSA-{self.version}+{self.profile}{extensions}"

    def __hash__(self) -> int:
        # Hash the Version object rather than str(version): Version equality
        # ignores trailing zero components (0.80 == 0.80.0) and so does its
        # hash, which keeps __hash__ consistent with __eq__.
        return hash((self.version, self.profile))

    def __eq__(self, other: object) -> bool:
        """
        Two 0.80 specifications are equal when their version and profile match.

        level_8k and is_U55_subset are not compared.
        NOTE(review): presumably deliberate, so that e.g. a "+u55" spec matches
        the plain profile spec when used as a lookup key — confirm.
        """
        if isinstance(other, Tosa_0_80):
            return (self.version == other.version) and (self.profile == other.profile)
        return False

    def support_integer(self) -> bool:
        """Integer operations are supported by both the BI and MI profiles."""
        return True

    def support_float(self) -> bool:
        """Float operations are supported only by the MI profile."""
        return self.profile == "MI"
152*523fa7a6SAndroid Build Coastguard Worker
153*523fa7a6SAndroid Build Coastguard Worker
class Tosa_1_00(TosaSpecification):
    """
    TOSA 1.0 specification.

    One or both of the INT and FP profiles must be given, optionally combined
    with the 8k level and any extensions that are valid for the selected
    profiles (see valid_extensions).
    """

    profiles: List[str]
    level_8k: bool
    extensions: List[str]

    available_profiles = ["INT", "FP"]
    valid_extensions = {
        "INT": ["int16", "int4", "var", "cf"],
        "FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"],
    }

    def __init__(self, version: Version, extras: List[str]):
        """
        Args:
            version: The specification version.
            extras: The '+'-separated tokens following the version, for
                example ["INT", "FP", "int4", "cf"]. The caller's list is not
                modified.

        Raises:
            ValueError: If no profile is given, if a profile is repeated, or
                if an extension is not valid for any of the selected profiles.
        """
        super().__init__(version)

        # Work on a copy: tokens are consumed below and the caller's list
        # must not be mutated as a side effect.
        extras = list(extras)

        profiles = [e for e in extras if e in Tosa_1_00.available_profiles]
        if not profiles:
            raise ValueError(
                f"No profile ({Tosa_1_00.available_profiles}) found in: {extras}."
            )
        # Each profile may appear at most once. This also bounds the count by
        # the number of available profiles, and keeps equality/hashing
        # (which use the set of profiles) unambiguous.
        if len(profiles) != len(set(profiles)):
            raise ValueError(
                f"Duplicate profiles ({Tosa_1_00.available_profiles}) found in: {extras}."
            )

        self.profiles = profiles
        for p in self.profiles:
            extras.remove(p)

        self.level_8k = "8k" in extras
        if self.level_8k:
            extras.remove("8k")

        # Union of the extensions allowed by every selected profile.
        combined_extensions = {
            e for p in self.profiles for e in Tosa_1_00.valid_extensions[p]
        }
        bad_extensions = [e for e in extras if e not in combined_extensions]
        if bad_extensions:
            raise ValueError(
                f"Bad extensions for TOSA-{version}{self._get_profiles_string()}: {bad_extensions}"
            )

        # All the remaining extras are valid extensions for the profiles.
        self.extensions = extras

    def _get_profiles_string(self) -> str:
        """Return the profiles as '+'-prefixed tokens, e.g. "+INT+FP"."""
        return "".join(["+" + p for p in self.profiles])

    def _get_extensions_string(self) -> str:
        """Return the extensions as '+'-prefixed tokens, e.g. "+int4+cf"."""
        return "".join(["+" + e for e in self.extensions])

    def __repr__(self) -> str:
        """
        Return the string form TOSA-<version>+<profiles>[+8k][+<extensions>],
        e.g. "TOSA-1.0.0+INT+FP+8k+int4+cf", which can be parsed back.
        """
        level = "+8k" if self.level_8k else ""
        return f"TOSA-{self.version}{self._get_profiles_string()}{level}{self._get_extensions_string()}"

    def __hash__(self) -> int:
        # Hash the Version object (consistent with Version equality, which
        # ignores trailing zeros) and the *set* of profiles, so that
        # "INT+FP" and "FP+INT" hash alike, matching __eq__.
        return hash((self.version, frozenset(self.profiles)))

    def __eq__(self, other: object) -> bool:
        """
        Two 1.0 specifications are equal when their version and set of
        profiles match; profile order, level and extensions are not compared.
        """
        if isinstance(other, Tosa_1_00):
            return (self.version == other.version) and (
                set(self.profiles) == set(other.profiles)
            )
        return False

    def support_integer(self) -> bool:
        """Integer operations are supported when the INT profile is selected."""
        return "INT" in self.profiles

    def support_float(self) -> bool:
        """Float operations are supported when the FP profile is selected."""
        return "FP" in self.profiles
227