xref: /aosp_15_r20/external/pytorch/torch/_inductor/codegen/cuda/cutlass_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3import logging
4import os
5import sys
6from dataclasses import dataclass
7from pathlib import Path
8from typing import Any, List, Optional
9
10import sympy
11
12import torch
13
14from ... import config
15from ...ir import Layout
16from ...runtime.runtime_utils import cache_dir
17from .cuda_env import get_cuda_arch, get_cuda_version
18
19
20log = logging.getLogger(__name__)
21
22
23def _rename_cutlass_import(content: str, cutlass_modules: List[str]) -> str:
24    for cutlass_module in cutlass_modules:
25        content = content.replace(
26            f"from {cutlass_module} import ",
27            f"from cutlass_library.{cutlass_module} import ",
28        )
29    return content
30
31
def _gen_cutlass_file(
    file_name: str, cutlass_modules: List[str], src_dir: str, dst_dir: str
) -> None:
    """
    Copies ``file_name`` from ``src_dir`` to ``dst_dir`` with its CUTLASS imports renamed.

    The source file is read in full, passed through ``_rename_cutlass_import``
    with ``cutlass_modules``, and the result is written to a file of the same
    name under ``dst_dir`` (overwriting any existing file there).
    """
    source_path = os.path.abspath(os.path.join(src_dir, file_name))
    target_path = os.path.abspath(os.path.join(dst_dir, file_name))

    with open(source_path) as src_file:
        original_text = src_file.read()

    rewritten_text = _rename_cutlass_import(original_text, cutlass_modules)

    with open(target_path, "w") as dst_file:
        dst_file.write(rewritten_text)
48
49
@functools.lru_cache(None)
def try_import_cutlass() -> bool:
    """
    Tries to make the CUTLASS python package importable and reports whether it worked.

    Returns True immediately when ``config.is_fbcode()`` is true. Otherwise a
    ``cutlass_library`` symlink pointing at ``<cutlass_dir>/python/cutlass_library``
    is ensured inside ``<cache_dir>/torch_cutlass_library``, that directory is
    appended to ``sys.path``, and the ``cutlass_library`` submodules are imported.

    Returns:
        True if the CUTLASS modules could be imported, False if the CUTLASS
        checkout is missing or the import failed (both are logged at debug level).

    Raises:
        AssertionError: if something other than a symlink to the expected
            CUTLASS directory already occupies the link location.

    The result is memoized by ``functools.lru_cache``, so the filesystem work
    and the import attempt happen at most once per process.
    """
    if config.is_fbcode():
        return True

    # Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path.
    # This is a temporary hack to avoid CUTLASS module naming conflicts.
    # TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.

    cutlass_py_full_path = os.path.abspath(
        os.path.join(config.cuda.cutlass_dir, "python/cutlass_library")
    )
    tmp_cutlass_py_full_path = os.path.abspath(
        os.path.join(cache_dir(), "torch_cutlass_library")
    )
    dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")

    if not os.path.isdir(cutlass_py_full_path):
        log.debug(
            "Failed to import CUTLASS packages: CUTLASS repo does not exist: %s",
            cutlass_py_full_path,
        )
        return False

    if tmp_cutlass_py_full_path not in sys.path:
        # lexists (unlike exists) also reports dangling symlinks, so a stale
        # link is diagnosed by the checks below instead of making os.symlink
        # fail with an opaque FileExistsError.
        if not os.path.lexists(dst_link):
            os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
            try:
                os.symlink(cutlass_py_full_path, dst_link)
            except FileExistsError:
                # Another process sharing the cache dir created the entry
                # between the check and the symlink call; it is validated below.
                pass
        # NOTE: kept as asserts so the raised exception type stays AssertionError.
        assert os.path.islink(
            dst_link
        ), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
        # realpath on the link itself resolves relative link targets against
        # the link's directory rather than the current working directory.
        assert os.path.realpath(dst_link) == os.path.realpath(
            cutlass_py_full_path
        ), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
        sys.path.append(tmp_cutlass_py_full_path)

    try:
        import cutlass_library.generator  # noqa: F401
        import cutlass_library.library  # noqa: F401
        import cutlass_library.manifest  # noqa: F401

        return True

    except ImportError as e:
        log.debug(
            "Failed to import CUTLASS packages: %s, ignoring the CUTLASS backend.",
            str(e),
        )
    return False
98
99
100def _normalize_cuda_arch(arch: str) -> str:
101    if int(arch) >= 90:
102        return "90"
103    elif int(arch) >= 80:
104        return "80"
105    elif int(arch) >= 75:
106        return "75"
107    elif int(arch) >= 70:
108        return "70"
109    else:
110        raise NotImplementedError(f"Unsupported cuda arch: {arch}")
111
112
@dataclass
class CUTLASSArgs:
    """
    CUTLASS args used to initialize a CUTLASS Manifest.

    ``architectures`` and ``cuda_version`` are mandatory in practice: leaving
    either one as None makes construction fail with RuntimeError. After
    construction ``architectures`` holds the value returned by
    ``_normalize_cuda_arch``.
    """

    # Constructor arguments (dataclass fields).
    architectures: Optional[str] = None
    cuda_version: Optional[str] = None

    # Plain class-level constants; they are deliberately left unannotated so
    # that they do not become dataclass fields.
    operations = "all"
    build_dir = ""
    curr_build_dir = ""
    generator_target = ""
    kernels = "all"
    ignore_kernels = ""
    # TODO: these three look dead?
    kernel_filter_file: None = None
    selected_kernel_list: None = None
    interface_dir: None = None
    filter_by_cc = True
    disable_full_archs_compilation = False

    def __post_init__(self):
        # Reject incomplete configurations before touching the arch value.
        incomplete = self.architectures is None or self.cuda_version is None
        if incomplete:
            raise RuntimeError(
                f"{self.architectures=} or {self.cuda_version=} is None!"
            )
        normalized_arch = _normalize_cuda_arch(self.architectures)
        self.architectures = normalized_arch
141
142
@functools.lru_cache(None)
def _gen_ops_cached(arch, version) -> List[Any]:
    """
    Builds a CUTLASS manifest for the given arch / CUDA version and returns its operations.

    Returns an empty list (after logging an error) when ``arch`` or ``version``
    is None. For normalized arch "90" both the SM90 and SM80 generators are
    run; for any other arch the matching ``GenerateSM<arch>`` generator is
    looked up by name.

    Raises:
        NotImplementedError: if looking up or running the generator for a
            non-"90" arch raises AttributeError.
    """
    # Note: Cache needs to be specific for cuda architecture and version

    # Import cutlass python scripts.
    assert try_import_cutlass()
    import cutlass_library.generator as cutlass_generator
    import cutlass_library.manifest as cutlass_manifest

    if arch is None or version is None:
        log.error(
            "Cannot detect cuda arch %s or cuda version %s. "
            "Will discard all cutlass ops. "
            "Please consider setting _inductor.cuda.arch and _inductor.cuda.version configs.",
            arch,
            version,
        )
        return []

    normalized_arch = _normalize_cuda_arch(arch)
    manifest_args = CUTLASSArgs(architectures=normalized_arch, cuda_version=version)
    manifest = cutlass_manifest.Manifest(manifest_args)

    if normalized_arch != "90":
        try:
            generate = getattr(cutlass_generator, f"GenerateSM{normalized_arch}")
            generate(manifest, manifest_args.cuda_version)
        except AttributeError as e:
            raise NotImplementedError(
                f"Arch {normalized_arch} is not supported by current cutlass lib."
            ) from e
        return manifest.operations

    # arch "90": run the SM90 generator first, then the SM80 one.
    cutlass_generator.GenerateSM90(manifest, manifest_args.cuda_version)
    cutlass_generator.GenerateSM80(manifest, manifest_args.cuda_version)
    return manifest.operations
177
178
def gen_ops() -> List[Any]:
    """
    Generates all supported CUTLASS operations.

    The CUDA arch and version are obtained from ``get_cuda_arch()`` and
    ``get_cuda_version()`` and forwarded to the cached generator.
    """
    return _gen_ops_cached(get_cuda_arch(), get_cuda_version())
186
187
def torch_dtype_to_cutlass_type(
    torch_dtype: torch.dtype,
) -> "cutlass_library.library.DataType":  # type: ignore[name-defined] # noqa: F821
    """
    Translates a torch dtype into the corresponding CUTLASS ``DataType``.

    Handles ``torch.float``, ``torch.half`` and ``torch.bfloat16``.

    Raises:
        NotImplementedError: for any other dtype.
    """
    # Import cutlass python scripts.
    assert try_import_cutlass()
    import cutlass_library  # type: ignore[import]

    data_type = cutlass_library.library.DataType
    conversions = (
        (torch.float, data_type.f32),
        (torch.half, data_type.f16),
        (torch.bfloat16, data_type.bf16),
    )
    for candidate, cutlass_type in conversions:
        if torch_dtype == candidate:
            return cutlass_type
    raise NotImplementedError(f"Unsupported data type: {torch_dtype=}")
203
204
def dtype_match(
    torch_dtype: Optional[torch.dtype],
    cutlass_dtype: "cutlass_library.library.DataType",  # type: ignore[name-defined]  # noqa: F821
) -> bool:
    """
    Tells whether ``cutlass_dtype`` is an accepted CUTLASS counterpart of ``torch_dtype``.

    ``torch.float`` accepts both ``f32`` and ``tf32``; half, bfloat16, int8,
    uint8 and int32 each accept exactly one CUTLASS type. Any other
    ``torch_dtype`` (including None) yields False.
    """
    # Import cutlass python scripts.
    assert try_import_cutlass()
    import cutlass_library

    data_type = cutlass_library.library.DataType
    accepted_by_torch_dtype = (
        (torch.float, (data_type.f32, data_type.tf32)),
        (torch.half, (data_type.f16,)),
        (torch.bfloat16, (data_type.bf16,)),
        (torch.int8, (data_type.s8,)),
        (torch.uint8, (data_type.u8,)),
        (torch.int32, (data_type.s32,)),
    )
    for candidate, accepted in accepted_by_torch_dtype:
        if torch_dtype == candidate:
            return any(cutlass_dtype == cutlass_type for cutlass_type in accepted)
    return False
230
231
def get_accumulator_dtype(
    input_torch_dtypes: List[torch.dtype],
) -> Optional[torch.dtype]:
    """
    Given a pair of input torch dtypes, returns the inferred accumulator torch dtype.

    Returns None when the list does not contain exactly two dtypes. When the
    two dtypes differ, only the combination of a half / bfloat16 operand with
    an int8 / uint8 operand is accepted, and the wider dtype drives the result.

    Raises:
        NotImplementedError: if no accumulator dtype is defined for the inputs.
    """
    if len(input_torch_dtypes) != 2:
        return None

    first = input_torch_dtypes[0]
    second = input_torch_dtypes[1]

    if first == second:
        compute_dtype = first
    else:
        first_size = torch.tensor([], dtype=first).element_size()
        second_size = torch.tensor([], dtype=second).element_size()
        # The operand with the strictly larger element size is the "wide"
        # one; on a tie the second input is treated as wide.
        if first_size > second_size:
            wide, narrow = first, second
        else:
            wide, narrow = second, first
        mixed_supported = wide in [torch.half, torch.bfloat16] and narrow in [
            torch.int8,
            torch.uint8,
        ]
        compute_dtype = wide if mixed_supported else None

    if compute_dtype == torch.half:
        reduced_precision_ok = (
            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
        )
        return compute_dtype if reduced_precision_ok else torch.float
    if compute_dtype in {torch.bfloat16, torch.float}:
        return torch.float
    if compute_dtype == torch.int8:
        return torch.int32
    raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes=}")
268
269
def get_alignments(torch_dtype: torch.dtype) -> List[int]:
    """
    Returns all possible valid CUTLASS alignments in terms of the number of elements for a given dtype.
    CUTLASS gemm / conv SM80 APIs support 16 bytes max alignment, and 2 bytes min alignment.

    The alignments are listed from largest to smallest and a new list is
    returned on every call.

    Raises:
        NotImplementedError: if ``torch_dtype`` has no alignment table entry.
    """
    alignment_table = (
        ((torch.half, torch.bfloat16), (8, 4, 2, 1)),
        ((torch.float,), (4, 2, 1)),
        ((torch.uint8, torch.int8), (16, 8, 4, 2)),
        ((torch.int32,), (4, 2, 1)),
    )
    for dtypes, alignments in alignment_table:
        if torch_dtype in dtypes:
            return list(alignments)
    raise NotImplementedError(f"unsupported {torch_dtype=} for alignments")
286
287
def get_max_alignment(inductor_layout: Layout) -> int:
    """
    Returns the max alignment (in terms of number of elements) for a given Inductor Layout.

    The result is the first entry of ``get_alignments(dtype)`` that evenly
    divides the size of the stride-1 dimension, the offset, and the stride of
    every other dimension. 1 is returned when there is no stride-1 dimension,
    when the relevant size / offset / strides are not static integers, or when
    no candidate alignment fits.
    """
    dtype = inductor_layout.dtype
    size = inductor_layout.size
    offset = inductor_layout.offset
    strides = inductor_layout.stride

    def is_static_int(number):
        return isinstance(number, (int, sympy.Integer))

    try:
        contiguous_dim = strides.index(1)
    except ValueError:
        # No dim with stride 1 found, return 1
        return 1

    everything_static = (
        is_static_int(size[contiguous_dim])
        and is_static_int(offset)
        and all(is_static_int(stride) for stride in strides)
    )
    if not everything_static:
        return 1

    for alignment in get_alignments(dtype):
        size_aligned = int(size[contiguous_dim]) % alignment == 0
        offset_aligned = int(offset) % alignment == 0
        if not (size_aligned and offset_aligned):
            continue
        # Every dimension other than the contiguous one must have a stride
        # that is a multiple of the candidate alignment.
        has_misaligned_stride = any(
            dim != contiguous_dim and strides[dim] % alignment != 0
            for dim in range(len(size))
        )
        if not has_misaligned_stride:
            return alignment
    return 1
324
325
class CUDACompileSourceCapturingContext:
    # Helper class for Benchmarking and Testing CUTLASS Kernels in isolation.
    # Can be used to capture the sourcecode passed to CUDACodeCache.compile
    #
    # While the context is active, CUDACodeCache.compile is patched with a
    # wrapper that appends each source string to `self.sources` and then
    # delegates to the original compile method.

    def __init__(self):
        # Active mock.patch object; set in __enter__.
        self._compile_patch = None
        # Source strings seen by the patched compile, in call order.
        self.sources = []

    def __enter__(self, *args, **kwargs):
        from unittest import mock

        from torch._inductor import codecache

        original_compile = codecache.CUDACodeCache.compile

        def recording_compile(source_code, dst_file_ext):
            self.sources.append(source_code)
            return original_compile(source_code, dst_file_ext)

        patcher = mock.patch(
            "torch._inductor.codecache.CUDACodeCache.compile", recording_compile
        )
        self._compile_patch = patcher
        # Note: this returns whatever the patcher's __enter__ returns, not self.
        return patcher.__enter__(*args, **kwargs)  # type: ignore[union-attr]

    def __exit__(self, *args, **kwargs):
        # Undo the patch and propagate the patcher's exit result.
        return self._compile_patch.__exit__(*args, **kwargs)  # type: ignore[union-attr]
352
353
def cuda_standalone_runner_compile_command(srcpath: Path, exepath: Path):
    """
    Returns the compile command for turning a (captured) CUDA GEMM Kernel source
    into a standalone executable that's ready to run.

    The command is produced by ``cuda_compile_command`` with destination type
    "exe" and two extra preprocessor defines: one that enables the standalone
    runner and one that sets the CUTLASS debug trace level to 1.
    """
    from torch._inductor.codecache import cuda_compile_command

    return cuda_compile_command(
        [str(srcpath)],
        str(exepath),
        "exe",
        extra_args=[
            "-DGENERATE_STANDALONE_RUNNER=1",
            "-DCUTLASS_DEBUG_TRACE_LEVEL=1",
        ],
    )
364