xref: /aosp_15_r20/external/pytorch/torch/_inductor/cpu_vec_isa.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import dataclasses
3import functools
4import os
5import platform
6import re
7import subprocess
8import sys
9from typing import Any, Callable, Dict, List
10
11import torch
12from torch._inductor import config
13
14
15_IS_WINDOWS = sys.platform == "win32"
16
17
def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str:
    """Return the cache fingerprint used for the ISA dry-compile check.

    Dry-compiling the ISA probe costs roughly one second at every startup
    (see https://github.com/pytorch/pytorch/issues/100378). The probe only
    tests whether the compiler can build for an ISA, so its outcome depends
    only on the compiler version, the ISA flags and the PyTorch version.
    Those three values are combined here so that the probe binary's hashed
    output path changes whenever any of them does, and an existing binary
    can be reused instead of recompiled.

    Args:
        isa_flags: compiler flags that select the ISA being probed.

    Returns:
        ``"<compiler version info>=<isa_flags>=<torch version>"``.
    """
    from torch._inductor.cpp_builder import get_compiler_version_info, get_cpp_compiler

    compiler = get_cpp_compiler()
    return f"{get_compiler_version_info(compiler)}={isa_flags}={torch.__version__}"
31
32
class VecISA:
    """A CPU SIMD instruction set that inductor can generate vectorized C++ for.

    Concrete subclasses set the class attributes below. An instance is truthy
    only when the ISA is actually usable on this machine (see ``__bool__``).
    """

    _bit_width: int  # vector width in bits, returned by bit_width()
    _macro: List[str]  # macros returned by build_macro() to select this ISA
    _arch_flags: str  # compiler flags returned by build_arch_flags()
    _dtype_nelements: Dict[torch.dtype, int]  # elements per vector, by dtype

    # Note [Checking for Vectorized Support in Inductor]
    # TorchInductor CPU vectorization reuses PyTorch's vectorization utilities, so
    # TorchInductor depends on Sleef* to accelerate mathematical functions such as
    # exp, pow, sin and cos.
    # However, PyTorch and TorchInductor may be built with different compilers. If
    # the PyTorch release package is built with gcc-7/g++-7, libtorch_cpu.so does
    # not expose the Sleef* AVX512 symbols, because gcc-7/g++-7 cannot pass the
    # AVX512 check in CMake's FindAVX.cmake. TorchInductor, on the other hand, may
    # use a newer gcc/g++ that can compile AVX512 code, so the two could disagree
    # about the Sleef version. We therefore dry-compile the code below to check
    # that both the current hardware and PyTorch support AVX512 or AVX2. ARM is
    # assumed to need the same check.
    # In fbcode, PyTorch and inductor codegen use the same compiler, so this
    # runtime check is unnecessary there.
    _avx_code = """
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX)
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#endif

alignas(64) float in_out_ptr0[16] = {0.0};

extern "C" void __avx_chk_kernel() {
    auto tmp0 = at::vec::Vectorized<float>(1);
    auto tmp1 = tmp0.exp();
    tmp1.store(in_out_ptr0);
}
"""  # noqa: B950

    # Python snippet that check_build() runs in a child interpreter.
    # "__lib_path__" is replaced with the path of the built probe library.
    _avx_py_load = """
import torch
from ctypes import cdll
cdll.LoadLibrary("__lib_path__")
"""

    def bit_width(self) -> int:
        """Return the vector width in bits."""
        return self._bit_width

    def nelements(self, dtype: torch.dtype = torch.float) -> int:
        """Return how many elements of *dtype* fit in one vector.

        Raises:
            KeyError: if this ISA has no entry for *dtype*.
        """
        return self._dtype_nelements[dtype]

    def build_macro(self) -> List[str]:
        """Return the macros that select this ISA when compiling."""
        return self._macro

    def build_arch_flags(self) -> str:
        """Return the compiler flags that enable this ISA."""
        return self._arch_flags

    def __hash__(self) -> int:
        # ISAs are identified by their string name (see the subclasses' __str__).
        return hash(str(self))

    def check_build(self, code: str) -> bool:
        """Return True if *code* compiles for this ISA and the result loads.

        The source is written to inductor's code cache with the compiler /
        ISA-flags / torch-version fingerprint as extra hash input. A probe
        library already built for the same fingerprint is therefore reused
        rather than recompiled. The library is then loaded via ``ctypes`` in a
        child interpreter that has imported ``torch``.

        Any failure while building or loading yields ``False``. Errors raised
        while acquiring the cache file lock are not caught.
        """
        from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT, write
        from torch._inductor.cpp_builder import (
            CppBuilder,
            CppTorchOptions,
            normalize_path_separator,
        )

        key, input_path = write(
            code,
            "cpp",
            extra=_get_isa_dry_compile_fingerprint(self._arch_flags),
        )
        from filelock import FileLock

        lock_dir = get_lock_dir()
        lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
        with lock:
            output_dir = os.path.dirname(input_path)
            build_options = CppTorchOptions(vec_isa=self, warning_all=False)
            # Used for every ISA (x86, ARM, POWER, Z), not only x86.
            isa_check_builder = CppBuilder(
                key,
                [input_path],
                build_options,
                output_dir,
            )
            try:
                output_path = normalize_path_separator(
                    isa_check_builder.get_target_file_path()
                )
                # Only compile when no binary exists for this fingerprint yet.
                if not os.path.isfile(output_path):
                    isa_check_builder.build()

                # Load the library after `import torch` in a child interpreter,
                # so its symbols are resolved against this PyTorch build.
                # Running it in a subprocess presumably also keeps a crash in
                # the probe from taking down this process.
                subprocess.check_call(
                    [
                        sys.executable,
                        "-c",
                        VecISA._avx_py_load.replace("__lib_path__", output_path),
                    ],
                    cwd=output_dir,
                    stderr=subprocess.DEVNULL,
                    # os.pathsep is ";" on Windows and ":" elsewhere; a hard-coded
                    # ":" would mangle sys.path for the child on Windows.
                    env={**os.environ, "PYTHONPATH": os.pathsep.join(sys.path)},
                )
            except Exception:
                # Deliberately best-effort: any build or load failure simply
                # means the ISA is treated as unsupported.
                return False

            return True

    @functools.lru_cache(None)  # noqa: B019
    def __bool__(self) -> bool:
        """Return whether this ISA is usable for codegen on this machine.

        ``config.cpp.vec_isa_ok`` overrides the check when it is set. fbcode
        skips the check (see the Note above). Otherwise the ``_avx_code`` probe
        is dry-compiled and loaded. The result is memoized by ``lru_cache``, so
        changing the config after the first call has no effect.
        """
        if config.cpp.vec_isa_ok is not None:
            return config.cpp.vec_isa_ok

        if config.is_fbcode():
            return True

        return self.check_build(VecISA._avx_code)
149
150
@dataclasses.dataclass
class VecNEON(VecISA):
    """ARM Advanced SIMD (NEON) vectorization target."""

    # Reported as 256 bits because that is required to leverage the compute
    # implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h.
    _bit_width = 256
    _arch_flags = ""  # Unused
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
    _macro = ["CPU_CAPABILITY_NEON"]
    if sys.platform == "darwin" and platform.processor() == "arm":
        # On arm64 macOS, also request the Sleef-backed 256-bit ARM path.
        _macro.append("AT_BUILD_ARM_VEC256_WITH_SLEEF")

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__

    def __str__(self) -> str:
        # "asimd" is the name under which armv8-a kernels report advanced SIMD.
        return "asimd"
164
165
@dataclasses.dataclass
class VecAVX512(VecISA):
    """x86 AVX-512 vectorization target (512-bit vectors)."""

    _bit_width = 512
    _dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32}
    _macro = ["CPU_CAPABILITY_AVX512"]
    # TODO: use cflags
    _arch_flags = (
        "/arch:AVX512"
        if _IS_WINDOWS
        else "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma"
    )

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__

    def __str__(self) -> str:
        return "avx512"
181
182
@dataclasses.dataclass
class VecAMX(VecAVX512):
    """AVX-512 plus Intel AMX tile instructions (bf16 and int8 tile dot products)."""

    _arch_flags = VecAVX512._arch_flags + " -mamx-tile -mamx-bf16 -mamx-int8"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__

    # Minimal kernel touching the AMX intrinsics; __bool__ dry-compiles it.
    _amx_code = """
#include <cstdint>
#include <immintrin.h>

struct amx_tilecfg {
  uint8_t palette_id;
  uint8_t start_row;
  uint8_t reserved_0[14];
  uint16_t colsb[16];
  uint8_t rows[16];
};

extern "C" void __amx_chk_kernel() {
  amx_tilecfg cfg = {0};
  _tile_loadconfig(&cfg);
  _tile_zero(0);
  _tile_dpbf16ps(0, 1, 2);
  _tile_dpbusd(0, 1, 2);
}
"""

    def __str__(self) -> str:
        # Space-separated name: callers check each component ("avx512", "amx_tile").
        return f"{super().__str__()} amx_tile"

    @functools.lru_cache(None)  # noqa: B019
    def __bool__(self) -> bool:
        """Return whether AMX is usable; the result is memoized.

        All of the following must hold, checked in this order:
        - AVX-512 itself is usable.
        - We are not in fbcode.
        - The AMX probe builds and loads.
        - ``torch.cpu._init_amx()`` succeeds.
        """
        return bool(
            super().__bool__()
            and not config.is_fbcode()
            and self.check_build(VecAMX._amx_code)
            and torch.cpu._init_amx()
        )
221
222
@dataclasses.dataclass
class VecAVX2(VecISA):
    """x86 AVX2 vectorization target (256-bit vectors).

    Outside Windows the flags also enable FMA and F16C.
    """

    _bit_width = 256
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
    _macro = ["CPU_CAPABILITY_AVX2"]
    # TODO: use cflags
    _arch_flags = "/arch:AVX2" if _IS_WINDOWS else "-mavx2 -mfma -mf16c"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__

    def __str__(self) -> str:
        return "avx2"
236
237
@dataclasses.dataclass
class VecZVECTOR(VecISA):
    """s390x vector-facility target (chosen when cpuinfo lists ``vxe``)."""

    _bit_width = 256
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
    _arch_flags = "-mvx -mzvector"
    _macro = [
        "CPU_CAPABILITY_ZVECTOR",
        "CPU_CAPABILITY=ZVECTOR",
        "HAVE_ZVECTOR_CPU_DEFINITION",
    ]

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__

    def __str__(self) -> str:
        return "zvector"
253
254
@dataclasses.dataclass
class VecVSX(VecISA):
    """POWER VSX vectorization target (used on ppc64le)."""

    # VSX SIMD is 128 bits wide, but ATen emulates it as 256.
    _bit_width = 256
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
    _arch_flags = "-mvsx"
    _macro = ["CPU_CAPABILITY_VSX"]

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__

    def __str__(self) -> str:
        return "vsx"
266
267
class InvalidVecISA(VecISA):
    """Placeholder ISA meaning "no usable vectorization"; always falsy."""

    _bit_width = 0
    _arch_flags = ""
    _macro = [""]
    _dtype_nelements = {}

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__

    def __str__(self) -> str:
        return "INVALID_VEC_ISA"

    def __bool__(self) -> bool:  # type: ignore[override]
        # Never usable, so the base class's cached dry-compile probe is bypassed.
        return False
281
282
283def x86_isa_checker() -> List[str]:
284    supported_isa: List[str] = []
285
286    def _check_and_append_supported_isa(
287        dest: List[str], isa_supported: bool, isa_name: str
288    ) -> None:
289        if isa_supported:
290            dest.append(isa_name)
291
292    Arch = platform.machine()
293    """
294    Arch value is x86_64 on Linux, and the value is AMD64 on Windows.
295    """
296    if Arch != "x86_64" and Arch != "AMD64":
297        return supported_isa
298
299    avx2 = torch.cpu._is_avx2_supported()
300    avx512 = torch.cpu._is_avx512_supported()
301    amx_tile = torch.cpu._is_amx_tile_supported()
302
303    _check_and_append_supported_isa(supported_isa, avx2, "avx2")
304    _check_and_append_supported_isa(supported_isa, avx512, "avx512")
305    _check_and_append_supported_isa(supported_isa, amx_tile, "amx_tile")
306
307    return supported_isa
308
309
# Returned by pick_vec_isa() whenever no usable vectorization target is found.
invalid_vec_isa = InvalidVecISA()

# Candidate targets in priority order. On x86, valid_vec_isa_list() keeps the
# entries that pass its checks in this order, and pick_vec_isa() prefers the
# earliest qualifying one, so reordering changes which ISA is chosen.
supported_vec_isa_list: List[VecISA] = [
    VecAMX(),
    VecAVX512(),
    VecAVX2(),
    VecNEON(),
]
312
313
def _s390x_has_vxe(cpuinfo_path: str = "/proc/cpuinfo") -> bool:
    """Return True if a ``features`` line in *cpuinfo_path* lists ``vxe``.

    Features are compared as whole whitespace-separated tokens. ``vxe`` is
    therefore found at the start, middle or end of the list, and ``vxe2``
    alone does not count.
    """
    with open(cpuinfo_path) as cpu_info:
        for line in cpu_info:
            features_match = re.match(r"^features\s*:\s*(.*)$", line)
            if features_match and "vxe" in features_match.group(1).split():
                return True
    return False


# Cache the result, so that cpuinfo is read and the ISA dry-compile checks run
# at most once per process. Only the derived ISA list is kept, not the raw
# cpuinfo content, which is mostly irrelevant to the ISA check.
@functools.lru_cache(None)
def valid_vec_isa_list() -> List[VecISA]:
    """Return the vectorization targets usable on this machine, in priority order.

    - arm64 macOS: NEON, then no further checks.
    - Platforms other than Linux and Windows: nothing more is added.
    - s390x: ZVECTOR if /proc/cpuinfo lists the ``vxe`` feature.
    - ppc64le: VSX.
    - aarch64: NEON.
    - x86: each entry of ``supported_vec_isa_list`` whose name components are
      all reported by ``x86_isa_checker()`` and whose dry-compile check passes.
    """
    isa_list: List[VecISA] = []
    if sys.platform == "darwin" and platform.processor() == "arm":
        isa_list.append(VecNEON())

    if sys.platform not in ["linux", "win32"]:
        return isa_list

    arch = platform.machine()
    if arch == "s390x":
        # The original regex put ^/$ inside character classes (where they are
        # literal characters), so it missed "vxe" at either end of the list.
        if _s390x_has_vxe():
            isa_list.append(VecZVECTOR())
    elif arch == "ppc64le":
        isa_list.append(VecVSX())
    elif arch == "aarch64":
        isa_list.append(VecNEON())
    elif arch in ["x86_64", "AMD64"]:
        # arch is "x86_64" on Linux and "AMD64" on Windows.
        cpu_supported_x86_isa = x86_isa_checker()
        for isa in supported_vec_isa_list:
            # Every space-separated component of the ISA name (e.g. "avx512 amx_tile")
            # must be reported by the CPU, and the ISA must pass its own
            # (cached) dry-compile check via __bool__.
            if all(flag in cpu_supported_x86_isa for flag in str(isa).split()) and isa:
                isa_list.append(isa)

    return isa_list
354
355
def pick_vec_isa() -> VecISA:
    """Choose the vectorization target inductor should generate C++ code for.

    - fbcode on x86: AVX2, without any probing.
    - ``config.cpp.simdlen`` unset: the first (highest-priority) entry of
      ``valid_vec_isa_list()``.
    - ``simdlen`` set: the first valid ISA whose bit width equals it.
    - ``invalid_vec_isa`` when nothing qualifies.
    """
    if config.is_fbcode() and platform.machine() in ["x86_64", "AMD64"]:
        return VecAVX2()

    candidates: List[VecISA] = valid_vec_isa_list()
    if not candidates:
        return invalid_vec_isa

    requested_width = config.cpp.simdlen
    if requested_width is None:
        # No explicit vector length requested: let the priority order decide.
        return candidates[0]

    return next(
        (isa for isa in candidates if requested_width == isa.bit_width()),
        invalid_vec_isa,
    )
374