# mypy: allow-untyped-defs
import dataclasses
import functools
import os
import platform
import re
import subprocess
import sys
from typing import Any, Callable, Dict, List

import torch
from torch._inductor import config


_IS_WINDOWS = sys.platform == "win32"


def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str:
    """Build the cache fingerprint used for an ISA dry-compile.

    Args:
        isa_flags: The architecture flags of the ISA being probed.

    Returns:
        A string combining the C++ compiler version info, ``isa_flags`` and
        the PyTorch version, joined with ``=``.
    """
    # ISA dry compile will cost about 1 sec time each startup time.
    # Please check the issue: https://github.com/pytorch/pytorch/issues/100378
    # Actually, dry compile is checking compile capability for ISA.
    # We just record the compiler version, ISA options and PyTorch version info,
    # and fold them into the hash path of the output binary, so that an
    # already-built binary can be reused instead of being compiled again.
    from torch._inductor.cpp_builder import get_compiler_version_info, get_cpp_compiler

    compiler_info = get_compiler_version_info(get_cpp_compiler())
    return f"{compiler_info}={isa_flags}={torch.__version__}"


class VecISA:
    """Base description of a CPU vector ISA usable by the C++ backend.

    Subclasses provide the class attributes below and a ``__str__``. Truthiness
    of an instance (``__bool__``) reports whether the ISA is actually usable,
    which may involve dry-compiling and loading a small probe library.
    """

    _bit_width: int
    _macro: List[str]
    _arch_flags: str
    _dtype_nelements: Dict[torch.dtype, int]

    # Note [Checking for Vectorized Support in Inductor]
    # TorchInductor CPU vectorization reuses PyTorch vectorization utility functions
    # Hence, TorchInductor would depend on Sleef* to accelerate mathematical functions
    # like exp, pow, sin, cos and etc.
    # But PyTorch and TorchInductor might use different compilers to build code. If
    # PyTorch uses gcc-7/g++-7 to build the release package, the libtorch_cpu.so
    # will not expose the Sleef* AVX512 symbols since gcc-7/g++-7 cannot pass
    # avx512 check in CMake - FindAVX.cmake. But TorchInductor install the latest
    # gcc/g++ compiler by default while it could support the AVX512 compilation.
    # Therefore, there would be a conflict sleef version between PyTorch and
    # TorchInductor. Hence, we dry-compile the following code to check whether current
    # HW platform and PyTorch both could support AVX512 or AVX2. And suppose ARM
    # also needs the logic
    # In fbcode however, we are using the same compiler for pytorch and for inductor codegen,
    # making the runtime check unnecessary.
    _avx_code = """
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX)
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#endif

alignas(64) float in_out_ptr0[16] = {0.0};

extern "C" void __avx_chk_kernel() {
    auto tmp0 = at::vec::Vectorized<float>(1);
    auto tmp1 = tmp0.exp();
    tmp1.store(in_out_ptr0);
}
"""  # noqa: B950

    # Script run in a child interpreter; "__lib_path__" is substituted with the
    # path of the freshly built probe library before execution.
    _avx_py_load = """
import torch
from ctypes import cdll
cdll.LoadLibrary("__lib_path__")
"""

    def bit_width(self) -> int:
        """Return the vector register width, in bits, declared by this ISA."""
        return self._bit_width

    def nelements(self, dtype: torch.dtype = torch.float) -> int:
        """Return the number of ``dtype`` lanes per vector.

        Raises:
            KeyError: If ``dtype`` has no entry in ``_dtype_nelements``.
        """
        return self._dtype_nelements[dtype]

    def build_macro(self) -> List[str]:
        """Return the preprocessor macros to define when building for this ISA."""
        return self._macro

    def build_arch_flags(self) -> str:
        """Return the compiler architecture flags for this ISA."""
        return self._arch_flags

    def __hash__(self) -> int:
        # Hash on the string form; instances are used as keys by the
        # lru_cache wrapped around __bool__.
        return hash(str(self))

    @staticmethod
    def _load_check_env() -> Dict[str, str]:
        """Return the environment for the subprocess that loads the probe library.

        The current environment is copied and ``PYTHONPATH`` is set to this
        process's ``sys.path`` so the child interpreter resolves imports the
        same way. Entries are joined with ``os.pathsep`` (``:`` on POSIX,
        ``;`` on Windows) rather than a hard-coded ``:``.
        """
        return {**os.environ, "PYTHONPATH": os.pathsep.join(sys.path)}

    def check_build(self, code: str) -> bool:
        """Dry-compile ``code`` for this ISA and verify the result can be loaded.

        The source is written to the code cache (keyed additionally by the
        compiler/ISA/PyTorch fingerprint), compiled if the target library does
        not already exist, and then loaded with ``ctypes`` in a separate Python
        process. The whole build-and-load step runs under a file lock.

        Args:
            code: C++ source of the probe kernel.

        Returns:
            ``True`` if the library was built (or was already present) and
            loaded successfully; ``False`` if building or loading raised.
        """
        from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT, write
        from torch._inductor.cpp_builder import (
            CppBuilder,
            CppTorchOptions,
            normalize_path_separator,
        )

        key, input_path = write(
            code,
            "cpp",
            extra=_get_isa_dry_compile_fingerprint(self._arch_flags),
        )
        from filelock import FileLock

        lock_dir = get_lock_dir()
        lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
        with lock:
            output_dir = os.path.dirname(input_path)
            build_options = CppTorchOptions(vec_isa=self, warning_all=False)
            isa_help_builder = CppBuilder(
                key,
                [input_path],
                build_options,
                output_dir,
            )
            try:
                # Compile only when the output file does not exist yet.
                output_path = normalize_path_separator(
                    isa_help_builder.get_target_file_path()
                )
                if not os.path.isfile(output_path):
                    isa_help_builder.build()

                # Check the build result by loading it in a fresh interpreter.
                subprocess.check_call(
                    [
                        sys.executable,
                        "-c",
                        VecISA._avx_py_load.replace("__lib_path__", output_path),
                    ],
                    cwd=output_dir,
                    stderr=subprocess.DEVNULL,
                    env=VecISA._load_check_env(),
                )
            except Exception:
                # Any build or load failure simply means "ISA not usable".
                return False

            return True

    @functools.lru_cache(None)  # noqa: B019
    def __bool__(self) -> bool:
        """Report whether this ISA is usable; the result is cached per instance."""
        # An explicit config override wins over any probing.
        if config.cpp.vec_isa_ok is not None:
            return config.cpp.vec_isa_ok

        if config.is_fbcode():
            return True

        return self.check_build(VecISA._avx_code)


@dataclasses.dataclass
class VecNEON(VecISA):
    _bit_width = 256  # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h
    _macro = ["CPU_CAPABILITY_NEON"]
    if sys.platform == "darwin" and platform.processor() == "arm":
        _macro.append("AT_BUILD_ARM_VEC256_WITH_SLEEF")
    _arch_flags = ""  # Unused
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "asimd"  # detects the presence of advanced SIMD on armv8-a kernels

    # Re-declared explicitly so the dataclass decorator (eq=True by default)
    # keeps the class hashable instead of setting __hash__ to None.
    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecAVX512(VecISA):
    _bit_width = 512
    _macro = ["CPU_CAPABILITY_AVX512"]
    _arch_flags = (
        "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma"
        if not _IS_WINDOWS
        else "/arch:AVX512"
    )  # TODO: use cflags
    _dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32}

    def __str__(self) -> str:
        return "avx512"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecAMX(VecAVX512):
    _arch_flags = VecAVX512._arch_flags + " -mamx-tile -mamx-bf16 -mamx-int8"

    def __str__(self) -> str:
        return super().__str__() + " amx_tile"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__

    _amx_code = """
#include <cstdint>
#include <immintrin.h>

struct amx_tilecfg {
  uint8_t palette_id;
  uint8_t start_row;
  uint8_t reserved_0[14];
  uint16_t colsb[16];
  uint8_t rows[16];
};

extern "C" void __amx_chk_kernel() {
  amx_tilecfg cfg = {0};
  _tile_loadconfig(&cfg);
  _tile_zero(0);
  _tile_dpbf16ps(0, 1, 2);
  _tile_dpbusd(0, 1, 2);
}
"""

    @functools.lru_cache(None)  # noqa: B019
    def __bool__(self) -> bool:
        """Report whether AMX is usable on top of the AVX512 base check.

        Returns ``False`` in fbcode, and otherwise requires the AMX probe to
        build and ``torch.cpu._init_amx()`` to succeed.
        """
        if super().__bool__():
            if config.is_fbcode():
                return False
            if self.check_build(VecAMX._amx_code) and torch.cpu._init_amx():
                return True
        return False


@dataclasses.dataclass
class VecAVX2(VecISA):
    _bit_width = 256
    _macro = ["CPU_CAPABILITY_AVX2"]
    _arch_flags = (
        "-mavx2 -mfma -mf16c" if not _IS_WINDOWS else "/arch:AVX2"
    )  # TODO: use cflags
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "avx2"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecZVECTOR(VecISA):
    _bit_width = 256
    _macro = [
        "CPU_CAPABILITY_ZVECTOR",
        "CPU_CAPABILITY=ZVECTOR",
        "HAVE_ZVECTOR_CPU_DEFINITION",
    ]
    _arch_flags = "-mvx -mzvector"
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "zvector"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecVSX(VecISA):
    _bit_width = 256  # VSX simd supports 128 bit_width, but aten is emulating it as 256
    _macro = ["CPU_CAPABILITY_VSX"]
    _arch_flags = "-mvsx"
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "vsx"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


class InvalidVecISA(VecISA):
    """Placeholder ISA that is always falsy; signals "no vectorization"."""

    _bit_width = 0
    _macro = [""]
    _arch_flags = ""
    _dtype_nelements = {}

    def __str__(self) -> str:
        return "INVALID_VEC_ISA"

    def __bool__(self) -> bool:  # type: ignore[override]
        return False

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


def x86_isa_checker() -> List[str]:
    """Return the names of the x86 vector ISAs the current CPU reports.

    Returns:
        A list drawn, in this order, from ``"avx2"``, ``"avx512"`` and
        ``"amx_tile"``; empty when the machine is not x86-64.
    """
    # platform.machine() is "x86_64" on Linux and "AMD64" on Windows.
    if platform.machine() not in ("x86_64", "AMD64"):
        return []

    probes = (
        ("avx2", torch.cpu._is_avx2_supported()),
        ("avx512", torch.cpu._is_avx512_supported()),
        ("amx_tile", torch.cpu._is_amx_tile_supported()),
    )
    return [isa_name for isa_name, isa_supported in probes if isa_supported]


invalid_vec_isa = InvalidVecISA()
supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON()]


# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
# might have too much redundant content that is useless for ISA check. Hence,
# we only cache some key isa information.
@functools.lru_cache(None)
def valid_vec_isa_list() -> List[VecISA]:
    """Return the vector ISAs usable on this machine; computed once and cached.

    On macOS only the NEON entry (for an "arm" processor) is considered. On
    Linux and Windows the result depends on ``platform.machine()``; for x86-64
    each candidate in ``supported_vec_isa_list`` must be reported by the CPU
    and also evaluate truthy.
    """
    on_apple_arm = sys.platform == "darwin" and platform.processor() == "arm"
    isa_list: List[VecISA] = [VecNEON()] if on_apple_arm else []

    if sys.platform not in ["linux", "win32"]:
        return isa_list

    arch = platform.machine()
    if arch == "s390x":
        with open("/proc/cpuinfo") as _cpu_info:
            for line in _cpu_info:
                featuresmatch = re.match(r"^features\s*:\s*(.*)$", line)
                if featuresmatch is None:
                    continue
                # One entry is added per "features" line that mentions vxe.
                if any(
                    re.search(r"[\^ ]+vxe[\$ ]+", group)
                    for group in featuresmatch.groups()
                ):
                    isa_list.append(VecZVECTOR())
    elif arch == "ppc64le":
        isa_list.append(VecVSX())
    elif arch == "aarch64":
        isa_list.append(VecNEON())
    elif arch in ["x86_64", "AMD64"]:
        # arch value is x86_64 on Linux, and the value is AMD64 on Windows.
        cpu_flags = x86_isa_checker()
        # Check the cheap CPU-flag test first; truthiness of the ISA may
        # trigger a dry compile.
        isa_list += [
            isa
            for isa in supported_vec_isa_list
            if all(flag in cpu_flags for flag in str(isa).split()) and isa
        ]

    return isa_list


def pick_vec_isa() -> VecISA:
    """Choose the vector ISA to generate code for.

    Returns ``VecAVX2()`` in fbcode on x86-64. Otherwise returns the first
    valid ISA when ``config.cpp.simdlen`` is ``None``, the first valid ISA
    whose bit width equals ``config.cpp.simdlen`` when it is set, and
    ``invalid_vec_isa`` when nothing qualifies.
    """
    if config.is_fbcode() and (platform.machine() in ["x86_64", "AMD64"]):
        return VecAVX2()

    candidates: List[VecISA] = valid_vec_isa_list()
    if not candidates:
        return invalid_vec_isa

    # A simdlen of None means the vectorization length is chosen automatically.
    if config.cpp.simdlen is None:
        return candidates[0]

    return next(
        (isa for isa in candidates if config.cpp.simdlen == isa.bit_width()),
        invalid_vec_isa,
    )