# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

#
# Main implementation of AoT flow to partition and preprocess for Arm target
# backends. Converts via TOSA as an intermediate form supported by AoT and
# JIT compiler flows.
#

import re
from typing import List

from executorch.exir.backend.compile_spec_schema import CompileSpec
from packaging.version import Version


class TosaSpecification:
    """
    This class implements a representation of TOSA specification
    (https://www.mlplatform.org/tosa/tosa_spec.html) with a version, a profile
    (with extension) and a level (8k).
    For 0.80 releases the profile is BI or MI, with u55 handled as an unofficial
    extension.
    For 1.00 releases the profile is INT or FP, and the extensions are for
        INT: int16, int4, var, cf
        FP: bf16, fp8e4m3, fp8e5m2, fft, var, cf

    The TOSA specification is encoded in the string representation
        TOSA-major.minor.patch+profile[+level][+extensions]

    For 0.80 MI implies BI, while for 1.0 the profiles have to be explicitly
    specified.

    Profiles are uppercase; extensions and the level are lowercase.
    """

    version: Version

    def support_integer(self) -> bool:
        """
        Returns true if any integer operations are supported for the specification.
        """
        raise NotImplementedError

    def support_float(self) -> bool:
        """
        Returns true if any float operations are supported for the specification.
        """
        raise NotImplementedError

    def __init__(self, version: Version):
        self.version = version

    @staticmethod
    def create_from_compilespecs(
        compile_specs: List[CompileSpec],
    ) -> "TosaSpecification":
        """
        Search the CompileSpec list for 'tosa_version' and instantiate a
        specification from the first matching spec's decoded value.

        Raises:
            ValueError: if no spec has the key 'tosa_version', or if its value
                cannot be parsed by create_from_string.
        """
        for spec in compile_specs:
            if spec.key == "tosa_version":
                return TosaSpecification.create_from_string(spec.value.decode())
        raise ValueError(
            "No TOSA version key found in any of the supplied CompileSpecs"
        )

    @staticmethod
    def create_from_string(repr: str) -> "TosaSpecification":
        """
        Creates a TOSA specification class from a string representation, e.g.:
            TOSA-0.80.0+MI
            TOSA-0.80.0+BI+8k
            TOSA-0.80.0+BI+u55   # Ethos-U55 extension to handle TOSA subset
            TOSA-1.00.0+INT+FP+int4+cf

        Only 0.80.x (-> Tosa_0_80) and 1.0.x (-> Tosa_1_00) are supported.

        Note: the parameter name `repr` shadows the builtin; it is kept for
        backward compatibility with keyword callers.

        Raises:
            ValueError: if the string does not match the expected layout, the
                version is not supported, or the profile/extension combination
                is rejected by the version-specific class.
        """
        pattern = r"^(TOSA)-([\d.]+)\+(.+)$"
        parsed = re.match(pattern, repr)
        if not parsed:
            raise ValueError(
                f"Failed to parse TOSA specification representation: {repr}"
            )

        # group(1) is always the literal "TOSA" because the pattern enforces it,
        # so no separate name check is needed.
        version = Version(parsed.group(2))
        extras = parsed.group(3).split("+")

        if version.major == 0 and version.minor == 80:
            return Tosa_0_80(version, extras)
        if version.major == 1 and version.minor == 0:
            return Tosa_1_00(version, extras)
        raise ValueError(f"Wrong TOSA version: {version} from {repr}")


class Tosa_0_80(TosaSpecification):
    """TOSA 0.80.x specification: exactly one of BI/MI, optional 8k and u55."""

    profile: str
    level_8k: bool
    is_U55_subset: bool
    available_profiles = ["BI", "MI"]  # MT is not defined

    def __init__(self, version: Version, extras: List[str]):
        """
        Args:
            version: a version in [0.80, 0.90).
            extras: the '+'-separated tokens after the version; not modified.

        Raises:
            ValueError: if there is not exactly one profile, or unknown tokens
                remain after the profile, '8k' and 'u55' are consumed.
        """
        super().__init__(version)
        assert version >= Version("0.80") and version < Version("0.90")

        # Work on a private copy so the caller's list is left untouched.
        extras = list(extras)

        profiles = [e for e in extras if e in Tosa_0_80.available_profiles]
        if not profiles:
            raise ValueError(
                f"No profile ({Tosa_0_80.available_profiles}) found in: {extras}."
            )
        if len(profiles) > 1:
            raise ValueError(
                f"Bad combination of extras: {extras}, more than one of {Tosa_0_80.available_profiles} found."
            )

        self.profile = profiles[0]
        extras.remove(self.profile)

        self.level_8k = "8k" in extras
        if self.level_8k:
            extras.remove("8k")
        self.is_U55_subset = "u55" in extras
        if self.is_U55_subset:
            extras.remove("u55")

        if len(extras) > 0:
            raise ValueError(f"Unhandled extras found: {extras}")

    def __repr__(self):
        # Emit the level in lowercase so the result round-trips through
        # TosaSpecification.create_from_string.
        extensions = ""
        if self.level_8k:
            extensions += "+8k"
        if self.is_U55_subset:
            extensions += "+u55"
        return f"TOSA-{str(self.version)}+{self.profile}{extensions}"

    # Equality and hashing consider only the version and the profile; the
    # 8k level and u55 flag are ignored.
    def __hash__(self) -> int:
        return hash(str(self.version) + self.profile)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Tosa_0_80):
            return (self.version == other.version) and (self.profile == other.profile)
        return False

    def support_integer(self) -> bool:
        # Both BI and MI support integer operations (MI implies BI).
        return True

    def support_float(self) -> bool:
        return self.profile == "MI"


class Tosa_1_00(TosaSpecification):
    """TOSA 1.0.x specification: one or both of INT/FP, optional 8k and extensions."""

    profiles: List[str]
    level_8k: bool
    extensions: List[str]

    available_profiles = ["INT", "FP"]
    valid_extensions = {
        "INT": ["int16", "int4", "var", "cf"],
        "FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"],
    }

    def __init__(self, version: Version, extras: List[str]):
        """
        Args:
            version: a 1.0.x version.
            extras: the '+'-separated tokens after the version; not modified.

        Raises:
            ValueError: if no profile is given, a profile is repeated, or an
                extension is not valid for any of the selected profiles.
        """
        super().__init__(version)

        # Work on a private copy so the caller's list is left untouched and
        # self.extensions does not alias it.
        extras = list(extras)

        profiles = [e for e in extras if e in Tosa_1_00.available_profiles]
        if not profiles:
            raise ValueError(
                f"No profile ({Tosa_1_00.available_profiles}) found in: {extras}."
            )
        # Each profile may appear at most once; this also bounds the count by
        # the number of available profiles.
        if len(set(profiles)) != len(profiles):
            raise ValueError(
                f"Too many profiles ({Tosa_1_00.available_profiles}) found in: {extras}."
            )

        self.profiles = profiles
        for p in self.profiles:
            extras.remove(p)

        self.level_8k = "8k" in extras
        if self.level_8k:
            extras.remove("8k")

        # An extension is accepted if any selected profile allows it.
        allowed_extensions = {
            ext for p in self.profiles for ext in Tosa_1_00.valid_extensions[p]
        }
        if not all(e in allowed_extensions for e in extras):
            raise ValueError(
                f"Bad extensions for TOSA-{version}{self._get_profiles_string()}: {extras}"
            )

        # All remaining extras are validated extensions.
        self.extensions = extras

    def _get_profiles_string(self) -> str:
        return "".join(["+" + p for p in self.profiles])

    def _get_extensions_string(self) -> str:
        return "".join(["+" + e for e in self.extensions])

    def __repr__(self):
        # Layout: TOSA-version+profiles[+8k][+extensions], which round-trips
        # through TosaSpecification.create_from_string.
        level = "+8k" if self.level_8k else ""
        return (
            f"TOSA-{self.version}{self._get_profiles_string()}"
            f"{level}{self._get_extensions_string()}"
        )

    # Equality and hashing consider only the version and the (ordered)
    # profiles; the 8k level and extensions are ignored.
    def __hash__(self) -> int:
        return hash(str(self.version) + self._get_profiles_string())

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Tosa_1_00):
            return (self.version == other.version) and (
                self._get_profiles_string() == other._get_profiles_string()
            )
        return False

    def support_integer(self) -> bool:
        return "INT" in self.profiles

    def support_float(self) -> bool:
        return "FP" in self.profiles