# mypy: allow-untyped-defs
import functools
import logging
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, List, Optional

import sympy

import torch

from ... import config
from ...ir import Layout
from ...runtime.runtime_utils import cache_dir
from .cuda_env import get_cuda_arch, get_cuda_version


log = logging.getLogger(__name__)


def _rename_cutlass_import(content: str, cutlass_modules: List[str]) -> str:
    for cutlass_module in cutlass_modules:
        content = content.replace(
            f"from {cutlass_module} import ",
            f"from cutlass_library.{cutlass_module} import ",
        )
    return content


def _gen_cutlass_file(
    file_name: str, cutlass_modules: List[str], src_dir: str, dst_dir: str
) -> None:
    orig_full_path = os.path.abspath(os.path.join(src_dir, file_name))
    text = ""
    with open(orig_full_path) as f:
        text = f.read()
    text = _rename_cutlass_import(text, cutlass_modules)
    dst_full_path = os.path.abspath(
        os.path.join(
            dst_dir,
            file_name,
        )
    )
    with open(dst_full_path, "w") as f:
        f.write(text)


@functools.lru_cache(None)
def try_import_cutlass() -> bool:
    if config.is_fbcode():
        return True

    # Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path.
    # This is a temporary hack to avoid CUTLASS module naming conflicts.
    # TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.

    cutlass_py_full_path = os.path.abspath(
        os.path.join(config.cuda.cutlass_dir, "python/cutlass_library")
    )
    tmp_cutlass_py_full_path = os.path.abspath(
        os.path.join(cache_dir(), "torch_cutlass_library")
    )
    dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")

    if os.path.isdir(cutlass_py_full_path):
        if tmp_cutlass_py_full_path not in sys.path:
            if os.path.exists(dst_link):
                assert os.path.islink(
                    dst_link
                ), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
                assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
                    cutlass_py_full_path
                ), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
            else:
                os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
                os.symlink(cutlass_py_full_path, dst_link)
            sys.path.append(tmp_cutlass_py_full_path)
        try:
            import cutlass_library.generator  # noqa: F401
            import cutlass_library.library  # noqa: F401
            import cutlass_library.manifest  # noqa: F401

            return True

        except ImportError as e:
            log.debug(
                "Failed to import CUTLASS packages: %s, ignoring the CUTLASS backend.",
                str(e),
            )
    else:
        log.debug(
            "Failed to import CUTLASS packages: CUTLASS repo does not exist: %s",
            cutlass_py_full_path,
        )
    return False


def _normalize_cuda_arch(arch: str) -> str:
    if int(arch) >= 90:
        return "90"
    elif int(arch) >= 80:
        return "80"
    elif int(arch) >= 75:
        return "75"
    elif int(arch) >= 70:
        return "70"
    else:
        raise NotImplementedError(f"Unsupported cuda arch: {arch}")
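

# Hedged illustration (module-private, never called in this file): _normalize_cuda_arch()
# buckets a detected compute capability down to the nearest CUTLASS generator target, so an
# sm_86 or sm_89 device reuses the SM80 kernel set generated further below. The function
# name here is hypothetical and exists only as a worked example of that bucketing.
def _example_normalized_arches() -> dict:
    # Expected result under the rules above:
    # {"75": "75", "80": "80", "86": "80", "89": "80", "90": "90"}
    return {arch: _normalize_cuda_arch(arch) for arch in ("75", "80", "86", "89", "90")}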


@dataclass
class CUTLASSArgs:
    """
    CUTLASS args used to initialize a CUTLASS Manifest.
    """

    architectures: Optional[str] = None
    cuda_version: Optional[str] = None

    operations = "all"
    build_dir = ""
    curr_build_dir = ""
    generator_target = ""
    kernels = "all"
    ignore_kernels = ""
    # TODO: these three look dead?
    kernel_filter_file: None = None
    selected_kernel_list: None = None
    interface_dir: None = None
    filter_by_cc = True
    disable_full_archs_compilation = False

    def __post_init__(self):
        if self.architectures is None or self.cuda_version is None:
            raise RuntimeError(
                f"{self.architectures=} or {self.cuda_version=} is None!"
            )
        self.architectures = _normalize_cuda_arch(self.architectures)


@functools.lru_cache(None)
def _gen_ops_cached(arch, version) -> List[Any]:
    # Note: Cache needs to be specific for cuda architecture and version

    # Import cutlass python scripts.
    assert try_import_cutlass()
    import cutlass_library.generator as cutlass_generator
    import cutlass_library.manifest as cutlass_manifest

    if arch is None or version is None:
        log.error(
            "Cannot detect cuda arch %s or cuda version %s. "
            "Will discard all cutlass ops. "
            "Please consider setting _inductor.cuda.arch and _inductor.cuda.version configs.",
            arch,
            version,
        )
        return []
    arch = _normalize_cuda_arch(arch)
    args = CUTLASSArgs(architectures=arch, cuda_version=version)
    manifest = cutlass_manifest.Manifest(args)

    if arch == "90":
        cutlass_generator.GenerateSM90(manifest, args.cuda_version)
        cutlass_generator.GenerateSM80(manifest, args.cuda_version)
    else:
        try:
            func = getattr(cutlass_generator, "GenerateSM" + arch)
            func(manifest, args.cuda_version)
        except AttributeError as e:
            raise NotImplementedError(
                "Arch " + arch + " is not supported by current cutlass lib."
            ) from e
    return manifest.operations


def gen_ops() -> List[Any]:
    """
    Generates all supported CUTLASS operations.
    """
    arch = get_cuda_arch()
    version = get_cuda_version()
    return _gen_ops_cached(arch, version)


def torch_dtype_to_cutlass_type(
    torch_dtype: torch.dtype,
) -> "cutlass_library.library.DataType":  # type: ignore[name-defined]  # noqa: F821
    # Import cutlass python scripts.
    assert try_import_cutlass()
    import cutlass_library  # type: ignore[import]

    if torch_dtype == torch.float:
        return cutlass_library.library.DataType.f32
    elif torch_dtype == torch.half:
        return cutlass_library.library.DataType.f16
    elif torch_dtype == torch.bfloat16:
        return cutlass_library.library.DataType.bf16
    else:
        raise NotImplementedError(f"Unsupported data type: {torch_dtype=}")


def dtype_match(
    torch_dtype: Optional[torch.dtype],
    cutlass_dtype: "cutlass_library.library.DataType",  # type: ignore[name-defined]  # noqa: F821
) -> bool:
    # Import cutlass python scripts.
    assert try_import_cutlass()
    import cutlass_library

    if torch_dtype == torch.float:
        return (
            cutlass_dtype == cutlass_library.library.DataType.f32
            or cutlass_dtype == cutlass_library.library.DataType.tf32
        )
    elif torch_dtype == torch.half:
        return cutlass_dtype == cutlass_library.library.DataType.f16
    elif torch_dtype == torch.bfloat16:
        return cutlass_dtype == cutlass_library.library.DataType.bf16
    elif torch_dtype == torch.int8:
        return cutlass_dtype == cutlass_library.library.DataType.s8
    elif torch_dtype == torch.uint8:
        return cutlass_dtype == cutlass_library.library.DataType.u8
    elif torch_dtype == torch.int32:
        return cutlass_dtype == cutlass_library.library.DataType.s32
    else:
        return False
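

# Hedged illustration (module-private, never called in this file): torch_dtype_to_cutlass_type()
# is a strict one-to-one mapping, whereas dtype_match() above is intentionally looser for
# float32, which is allowed to match both f32 and tf32 CUTLASS kernels. The function name is
# hypothetical and only demonstrates that asymmetry.
def _example_float32_dtype_match() -> bool:
    assert try_import_cutlass()
    import cutlass_library

    DataType = cutlass_library.library.DataType
    return (
        torch_dtype_to_cutlass_type(torch.float) == DataType.f32
        and dtype_match(torch.float, DataType.f32)
        and dtype_match(torch.float, DataType.tf32)
    )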


def get_accumulator_dtype(
    input_torch_dtypes: List[torch.dtype],
) -> Optional[torch.dtype]:
    """
    Given a pair of input torch dtypes, returns the inferred accumulator torch dtype.
    """

    if len(input_torch_dtypes) != 2:
        return None

    torch_dtype = None
    if input_torch_dtypes[0] == input_torch_dtypes[1]:
        torch_dtype = input_torch_dtypes[0]
    else:
        size0 = torch.tensor([], dtype=input_torch_dtypes[0]).element_size()
        size1 = torch.tensor([], dtype=input_torch_dtypes[1]).element_size()
        if size0 > size1:
            dtype0, dtype1 = input_torch_dtypes
        else:
            dtype1, dtype0 = input_torch_dtypes
        if dtype0 in [torch.half, torch.bfloat16] and dtype1 in [
            torch.int8,
            torch.uint8,
        ]:
            torch_dtype = dtype0

    if torch_dtype == torch.half:
        if torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction:
            return torch_dtype
        else:
            return torch.float
    if torch_dtype in {torch.bfloat16, torch.float}:
        return torch.float
    if torch_dtype == torch.int8:
        return torch.int32
    raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes=}")


def get_alignments(torch_dtype: torch.dtype) -> List[int]:
    """
    Returns all possible valid CUTLASS alignments in terms of the number of elements for a given dtype.
    CUTLASS gemm / conv SM80 APIs support 16 bytes max alignment, and 2 bytes min alignment.
    """

    if torch_dtype in (torch.half, torch.bfloat16):
        return [8, 4, 2, 1]
    elif torch_dtype == torch.float:
        return [4, 2, 1]
    elif torch_dtype in (torch.uint8, torch.int8):
        return [16, 8, 4, 2]
    elif torch_dtype == torch.int32:
        return [4, 2, 1]
    else:
        raise NotImplementedError(f"unsupported {torch_dtype=} for alignments")


def get_max_alignment(inductor_layout: Layout) -> int:
    """
    Returns the max alignment (in terms of number of elements) for a given Inductor Layout.
    """

    dtype = inductor_layout.dtype
    size = inductor_layout.size
    offset = inductor_layout.offset

    def is_static_int(number):
        return isinstance(number, (int, sympy.Integer))

    try:
        contiguous_dim = inductor_layout.stride.index(1)
    except ValueError:
        # No dim with stride 1 found, return 1
        return 1
    if (
        is_static_int(size[contiguous_dim])
        and is_static_int(offset)
        and all(is_static_int(s) for s in inductor_layout.stride)
    ):
        alignments = get_alignments(dtype)
        for alignment in alignments:
            if (
                int(size[contiguous_dim]) % alignment != 0
                or int(offset) % alignment != 0
            ):
                continue
            if all(
                (dim == contiguous_dim)
                or (inductor_layout.stride[dim] % alignment == 0)
                for dim in range(len(size))
            ):
                return alignment
    return 1
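

# Hedged illustration (module-private, never called in this file): for a contiguous row-major
# fp16 matrix of size [128, 64] (stride [64, 1], offset 0), every candidate returned by
# get_alignments(torch.half) divides the innermost size, the outer stride and the offset, so
# the widest candidate, 8 elements (16 bytes), wins. This restates on plain ints the check
# get_max_alignment() performs on an Inductor Layout; the function name is hypothetical.
def _example_fp16_max_alignment() -> int:
    size, stride, offset = [128, 64], [64, 1], 0
    contiguous_dim = stride.index(1)
    for alignment in get_alignments(torch.half):
        if size[contiguous_dim] % alignment != 0 or offset % alignment != 0:
            continue
        if all(
            dim == contiguous_dim or stride[dim] % alignment == 0
            for dim in range(len(size))
        ):
            return alignment  # 8 for this example
    return 1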


class CUDACompileSourceCapturingContext:
    # Helper class for benchmarking and testing CUTLASS kernels in isolation.
    # Can be used to capture the source code passed to CUDACodeCache.compile.

    def __init__(self):
        self.sources = []
        self._compile_patch = None

    def __enter__(self, *args, **kwargs):
        import unittest.mock as mock

        import torch._inductor.codecache

        _compile_method_orig = torch._inductor.codecache.CUDACodeCache.compile

        def my_compile(source_code, dst_file_ext):
            self.sources.append(source_code)
            return _compile_method_orig(source_code, dst_file_ext)

        self._compile_patch = mock.patch(
            "torch._inductor.codecache.CUDACodeCache.compile", my_compile
        )
        return self._compile_patch.__enter__(*args, **kwargs)  # type: ignore[union-attr]

    def __exit__(self, *args, **kwargs):
        return self._compile_patch.__exit__(*args, **kwargs)  # type: ignore[union-attr]


def cuda_standalone_runner_compile_command(srcpath: Path, exepath: Path):
    # Returns the command string used to compile a (captured) CUDA GEMM kernel source into a
    # standalone executable that is ready to run.
    # Passes the correct preprocessor define to nvcc to ensure the standalone runner is enabled.
    from torch._inductor.codecache import cuda_compile_command

    extra_args = ["-DGENERATE_STANDALONE_RUNNER=1", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]
    compile_command = cuda_compile_command(
        [str(srcpath)], str(exepath), "exe", extra_args=extra_args
    )
    return compile_command
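

# Hedged usage sketch (illustration only, never called in this file): capture the CUTLASS
# kernel sources emitted while some compilation runs, dump the last one, and log the nvcc
# command that would build it as a standalone runner. `trigger_compilation` is a placeholder
# for whatever causes CUDACodeCache.compile to run (e.g. executing a torch.compile'd GEMM
# with the CUTLASS backend enabled); the output paths and the function name are arbitrary.
def _example_capture_and_build_standalone_runner(trigger_compilation) -> None:
    capture_ctx = CUDACompileSourceCapturingContext()
    with capture_ctx:
        trigger_compilation()
    if not capture_ctx.sources:
        log.info("No CUDA sources were captured.")
        return
    srcpath = Path(cache_dir()) / "cutlass_standalone_runner.cu"
    exepath = Path(cache_dir()) / "cutlass_standalone_runner"
    srcpath.write_text(capture_ctx.sources[-1])
    log.info(
        "Standalone runner compile command: %s",
        cuda_standalone_runner_compile_command(srcpath, exepath),
    )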