1# mypy: allow-untyped-defs 2import copy 3import glob 4import importlib 5import importlib.abc 6import os 7import re 8import shlex 9import shutil 10import setuptools 11import subprocess 12import sys 13import sysconfig 14import warnings 15import collections 16from pathlib import Path 17import errno 18 19import torch 20import torch._appdirs 21from .file_baton import FileBaton 22from ._cpp_extension_versioner import ExtensionVersioner 23from .hipify import hipify_python 24from .hipify.hipify_python import GeneratedFileCleaner 25from typing import Dict, List, Optional, Union, Tuple 26from torch.torch_version import TorchVersion, Version 27 28from setuptools.command.build_ext import build_ext 29 30IS_WINDOWS = sys.platform == 'win32' 31IS_MACOS = sys.platform.startswith('darwin') 32IS_LINUX = sys.platform.startswith('linux') 33LIB_EXT = '.pyd' if IS_WINDOWS else '.so' 34EXEC_EXT = '.exe' if IS_WINDOWS else '' 35CLIB_PREFIX = '' if IS_WINDOWS else 'lib' 36CLIB_EXT = '.dll' if IS_WINDOWS else '.so' 37SHARED_FLAG = '/DLL' if IS_WINDOWS else '-shared' 38 39_HERE = os.path.abspath(__file__) 40_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE)) 41TORCH_LIB_PATH = os.path.join(_TORCH_PATH, 'lib') 42 43 44SUBPROCESS_DECODE_ARGS = ('oem',) if IS_WINDOWS else () 45MINIMUM_GCC_VERSION = (5, 0, 0) 46MINIMUM_MSVC_VERSION = (19, 0, 24215) 47 48VersionRange = Tuple[Tuple[int, ...], Tuple[int, ...]] 49VersionMap = Dict[str, VersionRange] 50# The following values were taken from the following GitHub gist that 51# summarizes the minimum valid major versions of g++/clang++ for each supported 52# CUDA version: https://gist.github.com/ax3l/9489132 53# Or from include/crt/host_config.h in the CUDA SDK 54# The second value is the exclusive(!) upper bound, i.e. min <= version < max 55CUDA_GCC_VERSIONS: VersionMap = { 56 '11.0': (MINIMUM_GCC_VERSION, (10, 0)), 57 '11.1': (MINIMUM_GCC_VERSION, (11, 0)), 58 '11.2': (MINIMUM_GCC_VERSION, (11, 0)), 59 '11.3': (MINIMUM_GCC_VERSION, (11, 0)), 60 '11.4': ((6, 0, 0), (12, 0)), 61 '11.5': ((6, 0, 0), (12, 0)), 62 '11.6': ((6, 0, 0), (12, 0)), 63 '11.7': ((6, 0, 0), (12, 0)), 64} 65 66MINIMUM_CLANG_VERSION = (3, 3, 0) 67CUDA_CLANG_VERSIONS: VersionMap = { 68 '11.1': (MINIMUM_CLANG_VERSION, (11, 0)), 69 '11.2': (MINIMUM_CLANG_VERSION, (12, 0)), 70 '11.3': (MINIMUM_CLANG_VERSION, (12, 0)), 71 '11.4': (MINIMUM_CLANG_VERSION, (13, 0)), 72 '11.5': (MINIMUM_CLANG_VERSION, (13, 0)), 73 '11.6': (MINIMUM_CLANG_VERSION, (14, 0)), 74 '11.7': (MINIMUM_CLANG_VERSION, (14, 0)), 75} 76 77__all__ = ["get_default_build_root", "check_compiler_ok_for_platform", "get_compiler_abi_compatibility_and_version", "BuildExtension", 78 "CppExtension", "CUDAExtension", "include_paths", "library_paths", "load", "load_inline", "is_ninja_available", 79 "verify_ninja_availability", "remove_extension_h_precompiler_headers", "get_cxx_compiler", "check_compiler_is_gcc"] 80# Taken directly from python stdlib < 3.9 81# See https://github.com/pytorch/pytorch/issues/48617 82def _nt_quote_args(args: Optional[List[str]]) -> List[str]: 83 """Quote command-line arguments for DOS/Windows conventions. 84 85 Just wraps every argument which contains blanks in double quotes, and 86 returns a new argument list. 87 """ 88 # Cover None-type 89 if not args: 90 return [] 91 return [f'"{arg}"' if ' ' in arg else arg for arg in args] 92 93def _find_cuda_home() -> Optional[str]: 94 """Find the CUDA install path.""" 95 # Guess #1 96 cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') 97 if cuda_home is None: 98 # Guess #2 99 nvcc_path = shutil.which("nvcc") 100 if nvcc_path is not None: 101 cuda_home = os.path.dirname(os.path.dirname(nvcc_path)) 102 else: 103 # Guess #3 104 if IS_WINDOWS: 105 cuda_homes = glob.glob( 106 'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*') 107 if len(cuda_homes) == 0: 108 cuda_home = '' 109 else: 110 cuda_home = cuda_homes[0] 111 else: 112 cuda_home = '/usr/local/cuda' 113 if not os.path.exists(cuda_home): 114 cuda_home = None 115 if cuda_home and not torch.cuda.is_available(): 116 print(f"No CUDA runtime is found, using CUDA_HOME='{cuda_home}'", 117 file=sys.stderr) 118 return cuda_home 119 120def _find_rocm_home() -> Optional[str]: 121 """Find the ROCm install path.""" 122 # Guess #1 123 rocm_home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') 124 if rocm_home is None: 125 # Guess #2 126 hipcc_path = shutil.which('hipcc') 127 if hipcc_path is not None: 128 rocm_home = os.path.dirname(os.path.dirname( 129 os.path.realpath(hipcc_path))) 130 # can be either <ROCM_HOME>/hip/bin/hipcc or <ROCM_HOME>/bin/hipcc 131 if os.path.basename(rocm_home) == 'hip': 132 rocm_home = os.path.dirname(rocm_home) 133 else: 134 # Guess #3 135 fallback_path = '/opt/rocm' 136 if os.path.exists(fallback_path): 137 rocm_home = fallback_path 138 if rocm_home and torch.version.hip is None: 139 print(f"No ROCm runtime is found, using ROCM_HOME='{rocm_home}'", 140 file=sys.stderr) 141 return rocm_home 142 143 144def _join_rocm_home(*paths) -> str: 145 """ 146 Join paths with ROCM_HOME, or raises an error if it ROCM_HOME is not set. 147 148 This is basically a lazy way of raising an error for missing $ROCM_HOME 149 only once we need to get any ROCm-specific path. 150 """ 151 if ROCM_HOME is None: 152 raise OSError('ROCM_HOME environment variable is not set. ' 153 'Please set it to your ROCm install root.') 154 elif IS_WINDOWS: 155 raise OSError('Building PyTorch extensions using ' 156 'ROCm and Windows is not supported.') 157 return os.path.join(ROCM_HOME, *paths) 158 159 160ABI_INCOMPATIBILITY_WARNING = ''' 161 162 !! WARNING !! 163 164!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 165Your compiler ({}) may be ABI-incompatible with PyTorch! 166Please use a compiler that is ABI-compatible with GCC 5.0 and above. 167See https://gcc.gnu.org/onlinedocs/libstdc++/manual/abi.html. 168 169See https://gist.github.com/goldsborough/d466f43e8ffc948ff92de7486c5216d6 170for instructions on how to install GCC 5 or higher. 171!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 172 173 !! WARNING !! 174''' 175WRONG_COMPILER_WARNING = ''' 176 177 !! WARNING !! 178 179!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 180Your compiler ({user_compiler}) is not compatible with the compiler Pytorch was 181built with for this platform, which is {pytorch_compiler} on {platform}. Please 182use {pytorch_compiler} to to compile your extension. Alternatively, you may 183compile PyTorch from source using {user_compiler}, and then you can also use 184{user_compiler} to compile your extension. 185 186See https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md for help 187with compiling PyTorch from source. 188!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 189 190 !! WARNING !! 191''' 192CUDA_MISMATCH_MESSAGE = ''' 193The detected CUDA version ({0}) mismatches the version that was used to compile 194PyTorch ({1}). Please make sure to use the same CUDA versions. 195''' 196CUDA_MISMATCH_WARN = "The detected CUDA version ({0}) has a minor version mismatch with the version that was used to compile PyTorch ({1}). Most likely this shouldn't be a problem." 197CUDA_NOT_FOUND_MESSAGE = ''' 198CUDA was not found on the system, please set the CUDA_HOME or the CUDA_PATH 199environment variable or add NVCC to your system PATH. The extension compilation will fail. 200''' 201ROCM_HOME = _find_rocm_home() 202HIP_HOME = _join_rocm_home('hip') if ROCM_HOME else None 203IS_HIP_EXTENSION = True if ((ROCM_HOME is not None) and (torch.version.hip is not None)) else False 204ROCM_VERSION = None 205if torch.version.hip is not None: 206 ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2]) 207 208CUDA_HOME = _find_cuda_home() if torch.cuda._is_compiled() else None 209CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH') 210# PyTorch releases have the version pattern major.minor.patch, whereas when 211# PyTorch is built from source, we append the git commit hash, which gives 212# it the below pattern. 213BUILT_FROM_SOURCE_VERSION_PATTERN = re.compile(r'\d+\.\d+\.\d+\w+\+\w+') 214 215COMMON_MSVC_FLAGS = ['/MD', '/wd4819', '/wd4251', '/wd4244', '/wd4267', '/wd4275', '/wd4018', '/wd4190', '/wd4624', '/wd4067', '/wd4068', '/EHsc'] 216 217MSVC_IGNORE_CUDAFE_WARNINGS = [ 218 'base_class_has_different_dll_interface', 219 'field_without_dll_interface', 220 'dll_interface_conflict_none_assumed', 221 'dll_interface_conflict_dllexport_assumed' 222] 223 224COMMON_NVCC_FLAGS = [ 225 '-D__CUDA_NO_HALF_OPERATORS__', 226 '-D__CUDA_NO_HALF_CONVERSIONS__', 227 '-D__CUDA_NO_BFLOAT16_CONVERSIONS__', 228 '-D__CUDA_NO_HALF2_OPERATORS__', 229 '--expt-relaxed-constexpr' 230] 231 232COMMON_HIP_FLAGS = [ 233 '-fPIC', 234 '-D__HIP_PLATFORM_AMD__=1', 235 '-DUSE_ROCM=1', 236 '-DHIPBLAS_V2', 237] 238 239COMMON_HIPCC_FLAGS = [ 240 '-DCUDA_HAS_FP16=1', 241 '-D__HIP_NO_HALF_OPERATORS__=1', 242 '-D__HIP_NO_HALF_CONVERSIONS__=1', 243] 244 245JIT_EXTENSION_VERSIONER = ExtensionVersioner() 246 247PLAT_TO_VCVARS = { 248 'win32' : 'x86', 249 'win-amd64' : 'x86_amd64', 250} 251 252def get_cxx_compiler(): 253 if IS_WINDOWS: 254 compiler = os.environ.get('CXX', 'cl') 255 else: 256 compiler = os.environ.get('CXX', 'c++') 257 return compiler 258 259def _is_binary_build() -> bool: 260 return not BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__) 261 262 263def _accepted_compilers_for_platform() -> List[str]: 264 # gnu-c++ and gnu-cc are the conda gcc compilers 265 return ['clang++', 'clang'] if IS_MACOS else ['g++', 'gcc', 'gnu-c++', 'gnu-cc', 'clang++', 'clang'] 266 267def _maybe_write(filename, new_content): 268 r''' 269 Equivalent to writing the content into the file but will not touch the file 270 if it already had the right content (to avoid triggering recompile). 271 ''' 272 if os.path.exists(filename): 273 with open(filename) as f: 274 content = f.read() 275 276 if content == new_content: 277 # The file already contains the right thing! 278 return 279 280 with open(filename, 'w') as source_file: 281 source_file.write(new_content) 282 283def get_default_build_root() -> str: 284 """ 285 Return the path to the root folder under which extensions will built. 286 287 For each extension module built, there will be one folder underneath the 288 folder returned by this function. For example, if ``p`` is the path 289 returned by this function and ``ext`` the name of an extension, the build 290 folder for the extension will be ``p/ext``. 291 292 This directory is **user-specific** so that multiple users on the same 293 machine won't meet permission issues. 294 """ 295 return os.path.realpath(torch._appdirs.user_cache_dir(appname='torch_extensions')) 296 297 298def check_compiler_ok_for_platform(compiler: str) -> bool: 299 """ 300 Verify that the compiler is the expected one for the current platform. 301 302 Args: 303 compiler (str): The compiler executable to check. 304 305 Returns: 306 True if the compiler is gcc/g++ on Linux or clang/clang++ on macOS, 307 and always True for Windows. 308 """ 309 if IS_WINDOWS: 310 return True 311 compiler_path = shutil.which(compiler) 312 if compiler_path is None: 313 return False 314 # Use os.path.realpath to resolve any symlinks, in particular from 'c++' to e.g. 'g++'. 315 compiler_path = os.path.realpath(compiler_path) 316 # Check the compiler name 317 if any(name in compiler_path for name in _accepted_compilers_for_platform()): 318 return True 319 # If compiler wrapper is used try to infer the actual compiler by invoking it with -v flag 320 env = os.environ.copy() 321 env['LC_ALL'] = 'C' # Don't localize output 322 version_string = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS) 323 if IS_LINUX: 324 # Check for 'gcc' or 'g++' for sccache wrapper 325 pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE) 326 results = re.findall(pattern, version_string) 327 if len(results) != 1: 328 # Clang is also a supported compiler on Linux 329 # Though on Ubuntu it's sometimes called "Ubuntu clang version" 330 return 'clang version' in version_string 331 compiler_path = os.path.realpath(results[0].strip()) 332 # On RHEL/CentOS c++ is a gcc compiler wrapper 333 if os.path.basename(compiler_path) == 'c++' and 'gcc version' in version_string: 334 return True 335 return any(name in compiler_path for name in _accepted_compilers_for_platform()) 336 if IS_MACOS: 337 # Check for 'clang' or 'clang++' 338 return version_string.startswith("Apple clang") 339 return False 340 341 342def get_compiler_abi_compatibility_and_version(compiler) -> Tuple[bool, TorchVersion]: 343 """ 344 Determine if the given compiler is ABI-compatible with PyTorch alongside its version. 345 346 Args: 347 compiler (str): The compiler executable name to check (e.g. ``g++``). 348 Must be executable in a shell process. 349 350 Returns: 351 A tuple that contains a boolean that defines if the compiler is (likely) ABI-incompatible with PyTorch, 352 followed by a `TorchVersion` string that contains the compiler version separated by dots. 353 """ 354 if not _is_binary_build(): 355 return (True, TorchVersion('0.0.0')) 356 if os.environ.get('TORCH_DONT_CHECK_COMPILER_ABI') in ['ON', '1', 'YES', 'TRUE', 'Y']: 357 return (True, TorchVersion('0.0.0')) 358 359 # First check if the compiler is one of the expected ones for the particular platform. 360 if not check_compiler_ok_for_platform(compiler): 361 warnings.warn(WRONG_COMPILER_WARNING.format( 362 user_compiler=compiler, 363 pytorch_compiler=_accepted_compilers_for_platform()[0], 364 platform=sys.platform)) 365 return (False, TorchVersion('0.0.0')) 366 367 if IS_MACOS: 368 # There is no particular minimum version we need for clang, so we're good here. 369 return (True, TorchVersion('0.0.0')) 370 try: 371 if IS_LINUX: 372 minimum_required_version = MINIMUM_GCC_VERSION 373 versionstr = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion']) 374 version = versionstr.decode(*SUBPROCESS_DECODE_ARGS).strip().split('.') 375 else: 376 minimum_required_version = MINIMUM_MSVC_VERSION 377 compiler_info = subprocess.check_output(compiler, stderr=subprocess.STDOUT) 378 match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode(*SUBPROCESS_DECODE_ARGS).strip()) 379 version = ['0', '0', '0'] if match is None else list(match.groups()) 380 except Exception: 381 _, error, _ = sys.exc_info() 382 warnings.warn(f'Error checking compiler version for {compiler}: {error}') 383 return (False, TorchVersion('0.0.0')) 384 385 if tuple(map(int, version)) >= minimum_required_version: 386 return (True, TorchVersion('.'.join(version))) 387 388 compiler = f'{compiler} {".".join(version)}' 389 warnings.warn(ABI_INCOMPATIBILITY_WARNING.format(compiler)) 390 391 return (False, TorchVersion('.'.join(version))) 392 393 394def _check_cuda_version(compiler_name: str, compiler_version: TorchVersion) -> None: 395 if not CUDA_HOME: 396 raise RuntimeError(CUDA_NOT_FOUND_MESSAGE) 397 398 nvcc = os.path.join(CUDA_HOME, 'bin', 'nvcc') 399 cuda_version_str = subprocess.check_output([nvcc, '--version']).strip().decode(*SUBPROCESS_DECODE_ARGS) 400 cuda_version = re.search(r'release (\d+[.]\d+)', cuda_version_str) 401 if cuda_version is None: 402 return 403 404 cuda_str_version = cuda_version.group(1) 405 cuda_ver = Version(cuda_str_version) 406 if torch.version.cuda is None: 407 return 408 409 torch_cuda_version = Version(torch.version.cuda) 410 if cuda_ver != torch_cuda_version: 411 # major/minor attributes are only available in setuptools>=49.4.0 412 if getattr(cuda_ver, "major", None) is None: 413 raise ValueError("setuptools>=49.4.0 is required") 414 if cuda_ver.major != torch_cuda_version.major: 415 raise RuntimeError(CUDA_MISMATCH_MESSAGE.format(cuda_str_version, torch.version.cuda)) 416 warnings.warn(CUDA_MISMATCH_WARN.format(cuda_str_version, torch.version.cuda)) 417 418 if not (sys.platform.startswith('linux') and 419 os.environ.get('TORCH_DONT_CHECK_COMPILER_ABI') not in ['ON', '1', 'YES', 'TRUE', 'Y'] and 420 _is_binary_build()): 421 return 422 423 cuda_compiler_bounds: VersionMap = CUDA_CLANG_VERSIONS if compiler_name.startswith('clang') else CUDA_GCC_VERSIONS 424 425 if cuda_str_version not in cuda_compiler_bounds: 426 warnings.warn(f'There are no {compiler_name} version bounds defined for CUDA version {cuda_str_version}') 427 else: 428 min_compiler_version, max_excl_compiler_version = cuda_compiler_bounds[cuda_str_version] 429 # Special case for 11.4.0, which has lower compiler bounds than 11.4.1 430 if "V11.4.48" in cuda_version_str and cuda_compiler_bounds == CUDA_GCC_VERSIONS: 431 max_excl_compiler_version = (11, 0) 432 min_compiler_version_str = '.'.join(map(str, min_compiler_version)) 433 max_excl_compiler_version_str = '.'.join(map(str, max_excl_compiler_version)) 434 435 version_bound_str = f'>={min_compiler_version_str}, <{max_excl_compiler_version_str}' 436 437 if compiler_version < TorchVersion(min_compiler_version_str): 438 raise RuntimeError( 439 f'The current installed version of {compiler_name} ({compiler_version}) is less ' 440 f'than the minimum required version by CUDA {cuda_str_version} ({min_compiler_version_str}). ' 441 f'Please make sure to use an adequate version of {compiler_name} ({version_bound_str}).' 442 ) 443 if compiler_version >= TorchVersion(max_excl_compiler_version_str): 444 raise RuntimeError( 445 f'The current installed version of {compiler_name} ({compiler_version}) is greater ' 446 f'than the maximum required version by CUDA {cuda_str_version}. ' 447 f'Please make sure to use an adequate version of {compiler_name} ({version_bound_str}).' 448 ) 449 450 451class BuildExtension(build_ext): 452 """ 453 A custom :mod:`setuptools` build extension . 454 455 This :class:`setuptools.build_ext` subclass takes care of passing the 456 minimum required compiler flags (e.g. ``-std=c++17``) as well as mixed 457 C++/CUDA compilation (and support for CUDA files in general). 458 459 When using :class:`BuildExtension`, it is allowed to supply a dictionary 460 for ``extra_compile_args`` (rather than the usual list) that maps from 461 languages (``cxx`` or ``nvcc``) to a list of additional compiler flags to 462 supply to the compiler. This makes it possible to supply different flags to 463 the C++ and CUDA compiler during mixed compilation. 464 465 ``use_ninja`` (bool): If ``use_ninja`` is ``True`` (default), then we 466 attempt to build using the Ninja backend. Ninja greatly speeds up 467 compilation compared to the standard ``setuptools.build_ext``. 468 Fallbacks to the standard distutils backend if Ninja is not available. 469 470 .. note:: 471 By default, the Ninja backend uses #CPUS + 2 workers to build the 472 extension. This may use up too many resources on some systems. One 473 can control the number of workers by setting the `MAX_JOBS` environment 474 variable to a non-negative number. 475 """ 476 477 @classmethod 478 def with_options(cls, **options): 479 """Return a subclass with alternative constructor that extends any original keyword arguments to the original constructor with the given options.""" 480 class cls_with_options(cls): # type: ignore[misc, valid-type] 481 def __init__(self, *args, **kwargs): 482 kwargs.update(options) 483 super().__init__(*args, **kwargs) 484 485 return cls_with_options 486 487 def __init__(self, *args, **kwargs) -> None: 488 super().__init__(*args, **kwargs) 489 self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", False) 490 491 self.use_ninja = kwargs.get('use_ninja', True) 492 if self.use_ninja: 493 # Test if we can use ninja. Fallback otherwise. 494 msg = ('Attempted to use ninja as the BuildExtension backend but ' 495 '{}. Falling back to using the slow distutils backend.') 496 if not is_ninja_available(): 497 warnings.warn(msg.format('we could not find ninja.')) 498 self.use_ninja = False 499 500 def finalize_options(self) -> None: 501 super().finalize_options() 502 if self.use_ninja: 503 self.force = True 504 505 def build_extensions(self) -> None: 506 compiler_name, compiler_version = self._check_abi() 507 508 cuda_ext = False 509 extension_iter = iter(self.extensions) 510 extension = next(extension_iter, None) 511 while not cuda_ext and extension: 512 for source in extension.sources: 513 _, ext = os.path.splitext(source) 514 if ext == '.cu': 515 cuda_ext = True 516 break 517 extension = next(extension_iter, None) 518 519 if cuda_ext and not IS_HIP_EXTENSION: 520 _check_cuda_version(compiler_name, compiler_version) 521 522 for extension in self.extensions: 523 # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when 524 # extra_compile_args is a dict. Otherwise, default torch flags do 525 # not get passed. Necessary when only one of 'cxx' and 'nvcc' is 526 # passed to extra_compile_args in CUDAExtension, i.e. 527 # CUDAExtension(..., extra_compile_args={'cxx': [...]}) 528 # or 529 # CUDAExtension(..., extra_compile_args={'nvcc': [...]}) 530 if isinstance(extension.extra_compile_args, dict): 531 for ext in ['cxx', 'nvcc']: 532 if ext not in extension.extra_compile_args: 533 extension.extra_compile_args[ext] = [] 534 535 self._add_compile_flag(extension, '-DTORCH_API_INCLUDE_EXTENSION_H') 536 # See note [Pybind11 ABI constants] 537 for name in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]: 538 val = getattr(torch._C, f"_PYBIND11_{name}") 539 if val is not None and not IS_WINDOWS: 540 self._add_compile_flag(extension, f'-DPYBIND11_{name}="{val}"') 541 self._define_torch_extension_name(extension) 542 self._add_gnu_cpp_abi_flag(extension) 543 544 if 'nvcc_dlink' in extension.extra_compile_args: 545 assert self.use_ninja, f"With dlink=True, ninja is required to build cuda extension {extension.name}." 546 547 # Register .cu, .cuh, .hip, and .mm as valid source extensions. 548 self.compiler.src_extensions += ['.cu', '.cuh', '.hip'] 549 if torch.backends.mps.is_built(): 550 self.compiler.src_extensions += ['.mm'] 551 # Save the original _compile method for later. 552 if self.compiler.compiler_type == 'msvc': 553 self.compiler._cpp_extensions += ['.cu', '.cuh'] 554 original_compile = self.compiler.compile 555 original_spawn = self.compiler.spawn 556 else: 557 original_compile = self.compiler._compile 558 559 def append_std17_if_no_std_present(cflags) -> None: 560 # NVCC does not allow multiple -std to be passed, so we avoid 561 # overriding the option if the user explicitly passed it. 562 cpp_format_prefix = '/{}:' if self.compiler.compiler_type == 'msvc' else '-{}=' 563 cpp_flag_prefix = cpp_format_prefix.format('std') 564 cpp_flag = cpp_flag_prefix + 'c++17' 565 if not any(flag.startswith(cpp_flag_prefix) for flag in cflags): 566 cflags.append(cpp_flag) 567 568 def unix_cuda_flags(cflags): 569 cflags = (COMMON_NVCC_FLAGS + 570 ['--compiler-options', "'-fPIC'"] + 571 cflags + _get_cuda_arch_flags(cflags)) 572 573 # NVCC does not allow multiple -ccbin/--compiler-bindir to be passed, so we avoid 574 # overriding the option if the user explicitly passed it. 575 _ccbin = os.getenv("CC") 576 if ( 577 _ccbin is not None 578 and not any(flag.startswith(('-ccbin', '--compiler-bindir')) for flag in cflags) 579 ): 580 cflags.extend(['-ccbin', _ccbin]) 581 582 return cflags 583 584 def convert_to_absolute_paths_inplace(paths): 585 # Helper function. See Note [Absolute include_dirs] 586 if paths is not None: 587 for i in range(len(paths)): 588 if not os.path.isabs(paths[i]): 589 paths[i] = os.path.abspath(paths[i]) 590 591 def unix_wrap_single_compile(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None: 592 # Copy before we make any modifications. 593 cflags = copy.deepcopy(extra_postargs) 594 try: 595 original_compiler = self.compiler.compiler_so 596 if _is_cuda_file(src): 597 nvcc = [_join_rocm_home('bin', 'hipcc') if IS_HIP_EXTENSION else _join_cuda_home('bin', 'nvcc')] 598 self.compiler.set_executable('compiler_so', nvcc) 599 if isinstance(cflags, dict): 600 cflags = cflags['nvcc'] 601 if IS_HIP_EXTENSION: 602 cflags = COMMON_HIPCC_FLAGS + cflags + _get_rocm_arch_flags(cflags) 603 else: 604 cflags = unix_cuda_flags(cflags) 605 elif isinstance(cflags, dict): 606 cflags = cflags['cxx'] 607 if IS_HIP_EXTENSION: 608 cflags = COMMON_HIP_FLAGS + cflags 609 append_std17_if_no_std_present(cflags) 610 611 original_compile(obj, src, ext, cc_args, cflags, pp_opts) 612 finally: 613 # Put the original compiler back in place. 614 self.compiler.set_executable('compiler_so', original_compiler) 615 616 def unix_wrap_ninja_compile(sources, 617 output_dir=None, 618 macros=None, 619 include_dirs=None, 620 debug=0, 621 extra_preargs=None, 622 extra_postargs=None, 623 depends=None): 624 r"""Compiles sources by outputting a ninja file and running it.""" 625 # NB: I copied some lines from self.compiler (which is an instance 626 # of distutils.UnixCCompiler). See the following link. 627 # https://github.com/python/cpython/blob/f03a8f8d5001963ad5b5b28dbd95497e9cc15596/Lib/distutils/ccompiler.py#L564-L567 628 # This can be fragile, but a lot of other repos also do this 629 # (see https://github.com/search?q=_setup_compile&type=Code) 630 # so it is probably OK; we'll also get CI signal if/when 631 # we update our python version (which is when distutils can be 632 # upgraded) 633 634 # Use absolute path for output_dir so that the object file paths 635 # (`objects`) get generated with absolute paths. 636 output_dir = os.path.abspath(output_dir) 637 638 # See Note [Absolute include_dirs] 639 convert_to_absolute_paths_inplace(self.compiler.include_dirs) 640 641 _, objects, extra_postargs, pp_opts, _ = \ 642 self.compiler._setup_compile(output_dir, macros, 643 include_dirs, sources, 644 depends, extra_postargs) 645 common_cflags = self.compiler._get_cc_args(pp_opts, debug, extra_preargs) 646 extra_cc_cflags = self.compiler.compiler_so[1:] 647 with_cuda = any(map(_is_cuda_file, sources)) 648 649 # extra_postargs can be either: 650 # - a dict mapping cxx/nvcc to extra flags 651 # - a list of extra flags. 652 if isinstance(extra_postargs, dict): 653 post_cflags = extra_postargs['cxx'] 654 else: 655 post_cflags = list(extra_postargs) 656 if IS_HIP_EXTENSION: 657 post_cflags = COMMON_HIP_FLAGS + post_cflags 658 append_std17_if_no_std_present(post_cflags) 659 660 cuda_post_cflags = None 661 cuda_cflags = None 662 if with_cuda: 663 cuda_cflags = common_cflags 664 if isinstance(extra_postargs, dict): 665 cuda_post_cflags = extra_postargs['nvcc'] 666 else: 667 cuda_post_cflags = list(extra_postargs) 668 if IS_HIP_EXTENSION: 669 cuda_post_cflags = cuda_post_cflags + _get_rocm_arch_flags(cuda_post_cflags) 670 cuda_post_cflags = COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS + cuda_post_cflags 671 else: 672 cuda_post_cflags = unix_cuda_flags(cuda_post_cflags) 673 append_std17_if_no_std_present(cuda_post_cflags) 674 cuda_cflags = [shlex.quote(f) for f in cuda_cflags] 675 cuda_post_cflags = [shlex.quote(f) for f in cuda_post_cflags] 676 677 if isinstance(extra_postargs, dict) and 'nvcc_dlink' in extra_postargs: 678 cuda_dlink_post_cflags = unix_cuda_flags(extra_postargs['nvcc_dlink']) 679 else: 680 cuda_dlink_post_cflags = None 681 _write_ninja_file_and_compile_objects( 682 sources=sources, 683 objects=objects, 684 cflags=[shlex.quote(f) for f in extra_cc_cflags + common_cflags], 685 post_cflags=[shlex.quote(f) for f in post_cflags], 686 cuda_cflags=cuda_cflags, 687 cuda_post_cflags=cuda_post_cflags, 688 cuda_dlink_post_cflags=cuda_dlink_post_cflags, 689 build_directory=output_dir, 690 verbose=True, 691 with_cuda=with_cuda) 692 693 # Return *all* object filenames, not just the ones we just built. 694 return objects 695 696 def win_cuda_flags(cflags): 697 return (COMMON_NVCC_FLAGS + 698 cflags + _get_cuda_arch_flags(cflags)) 699 700 def win_wrap_single_compile(sources, 701 output_dir=None, 702 macros=None, 703 include_dirs=None, 704 debug=0, 705 extra_preargs=None, 706 extra_postargs=None, 707 depends=None): 708 709 self.cflags = copy.deepcopy(extra_postargs) 710 extra_postargs = None 711 712 def spawn(cmd): 713 # Using regex to match src, obj and include files 714 src_regex = re.compile('/T(p|c)(.*)') 715 src_list = [ 716 m.group(2) for m in (src_regex.match(elem) for elem in cmd) 717 if m 718 ] 719 720 obj_regex = re.compile('/Fo(.*)') 721 obj_list = [ 722 m.group(1) for m in (obj_regex.match(elem) for elem in cmd) 723 if m 724 ] 725 726 include_regex = re.compile(r'((\-|\/)I.*)') 727 include_list = [ 728 m.group(1) 729 for m in (include_regex.match(elem) for elem in cmd) if m 730 ] 731 732 if len(src_list) >= 1 and len(obj_list) >= 1: 733 src = src_list[0] 734 obj = obj_list[0] 735 if _is_cuda_file(src): 736 nvcc = _join_cuda_home('bin', 'nvcc') 737 if isinstance(self.cflags, dict): 738 cflags = self.cflags['nvcc'] 739 elif isinstance(self.cflags, list): 740 cflags = self.cflags 741 else: 742 cflags = [] 743 744 cflags = win_cuda_flags(cflags) + ['-std=c++17', '--use-local-env'] 745 for flag in COMMON_MSVC_FLAGS: 746 cflags = ['-Xcompiler', flag] + cflags 747 for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS: 748 cflags = ['-Xcudafe', '--diag_suppress=' + ignore_warning] + cflags 749 cmd = [nvcc, '-c', src, '-o', obj] + include_list + cflags 750 elif isinstance(self.cflags, dict): 751 cflags = COMMON_MSVC_FLAGS + self.cflags['cxx'] 752 append_std17_if_no_std_present(cflags) 753 cmd += cflags 754 elif isinstance(self.cflags, list): 755 cflags = COMMON_MSVC_FLAGS + self.cflags 756 append_std17_if_no_std_present(cflags) 757 cmd += cflags 758 759 return original_spawn(cmd) 760 761 try: 762 self.compiler.spawn = spawn 763 return original_compile(sources, output_dir, macros, 764 include_dirs, debug, extra_preargs, 765 extra_postargs, depends) 766 finally: 767 self.compiler.spawn = original_spawn 768 769 def win_wrap_ninja_compile(sources, 770 output_dir=None, 771 macros=None, 772 include_dirs=None, 773 debug=0, 774 extra_preargs=None, 775 extra_postargs=None, 776 depends=None): 777 778 if not self.compiler.initialized: 779 self.compiler.initialize() 780 output_dir = os.path.abspath(output_dir) 781 782 # Note [Absolute include_dirs] 783 # Convert relative path in self.compiler.include_dirs to absolute path if any, 784 # For ninja build, the build location is not local, the build happens 785 # in a in script created build folder, relative path lost their correctness. 786 # To be consistent with jit extension, we allow user to enter relative include_dirs 787 # in setuptools.setup, and we convert the relative path to absolute path here 788 convert_to_absolute_paths_inplace(self.compiler.include_dirs) 789 790 _, objects, extra_postargs, pp_opts, _ = \ 791 self.compiler._setup_compile(output_dir, macros, 792 include_dirs, sources, 793 depends, extra_postargs) 794 common_cflags = extra_preargs or [] 795 cflags = [] 796 if debug: 797 cflags.extend(self.compiler.compile_options_debug) 798 else: 799 cflags.extend(self.compiler.compile_options) 800 common_cflags.extend(COMMON_MSVC_FLAGS) 801 cflags = cflags + common_cflags + pp_opts 802 with_cuda = any(map(_is_cuda_file, sources)) 803 804 # extra_postargs can be either: 805 # - a dict mapping cxx/nvcc to extra flags 806 # - a list of extra flags. 807 if isinstance(extra_postargs, dict): 808 post_cflags = extra_postargs['cxx'] 809 else: 810 post_cflags = list(extra_postargs) 811 append_std17_if_no_std_present(post_cflags) 812 813 cuda_post_cflags = None 814 cuda_cflags = None 815 if with_cuda: 816 cuda_cflags = ['-std=c++17', '--use-local-env'] 817 for common_cflag in common_cflags: 818 cuda_cflags.append('-Xcompiler') 819 cuda_cflags.append(common_cflag) 820 for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS: 821 cuda_cflags.append('-Xcudafe') 822 cuda_cflags.append('--diag_suppress=' + ignore_warning) 823 cuda_cflags.extend(pp_opts) 824 if isinstance(extra_postargs, dict): 825 cuda_post_cflags = extra_postargs['nvcc'] 826 else: 827 cuda_post_cflags = list(extra_postargs) 828 cuda_post_cflags = win_cuda_flags(cuda_post_cflags) 829 830 cflags = _nt_quote_args(cflags) 831 post_cflags = _nt_quote_args(post_cflags) 832 if with_cuda: 833 cuda_cflags = _nt_quote_args(cuda_cflags) 834 cuda_post_cflags = _nt_quote_args(cuda_post_cflags) 835 if isinstance(extra_postargs, dict) and 'nvcc_dlink' in extra_postargs: 836 cuda_dlink_post_cflags = win_cuda_flags(extra_postargs['nvcc_dlink']) 837 else: 838 cuda_dlink_post_cflags = None 839 840 _write_ninja_file_and_compile_objects( 841 sources=sources, 842 objects=objects, 843 cflags=cflags, 844 post_cflags=post_cflags, 845 cuda_cflags=cuda_cflags, 846 cuda_post_cflags=cuda_post_cflags, 847 cuda_dlink_post_cflags=cuda_dlink_post_cflags, 848 build_directory=output_dir, 849 verbose=True, 850 with_cuda=with_cuda) 851 852 # Return *all* object filenames, not just the ones we just built. 853 return objects 854 855 # Monkey-patch the _compile or compile method. 856 # https://github.com/python/cpython/blob/dc0284ee8f7a270b6005467f26d8e5773d76e959/Lib/distutils/ccompiler.py#L511 857 if self.compiler.compiler_type == 'msvc': 858 if self.use_ninja: 859 self.compiler.compile = win_wrap_ninja_compile 860 else: 861 self.compiler.compile = win_wrap_single_compile 862 else: 863 if self.use_ninja: 864 self.compiler.compile = unix_wrap_ninja_compile 865 else: 866 self.compiler._compile = unix_wrap_single_compile 867 868 build_ext.build_extensions(self) 869 870 def get_ext_filename(self, ext_name): 871 # Get the original shared library name. For Python 3, this name will be 872 # suffixed with "<SOABI>.so", where <SOABI> will be something like 873 # cpython-37m-x86_64-linux-gnu. 874 ext_filename = super().get_ext_filename(ext_name) 875 # If `no_python_abi_suffix` is `True`, we omit the Python 3 ABI 876 # component. This makes building shared libraries with setuptools that 877 # aren't Python modules nicer. 878 if self.no_python_abi_suffix: 879 # The parts will be e.g. ["my_extension", "cpython-37m-x86_64-linux-gnu", "so"]. 880 ext_filename_parts = ext_filename.split('.') 881 # Omit the second to last element. 882 without_abi = ext_filename_parts[:-2] + ext_filename_parts[-1:] 883 ext_filename = '.'.join(without_abi) 884 return ext_filename 885 886 def _check_abi(self) -> Tuple[str, TorchVersion]: 887 # On some platforms, like Windows, compiler_cxx is not available. 888 if hasattr(self.compiler, 'compiler_cxx'): 889 compiler = self.compiler.compiler_cxx[0] 890 else: 891 compiler = get_cxx_compiler() 892 _, version = get_compiler_abi_compatibility_and_version(compiler) 893 # Warn user if VC env is activated but `DISTUILS_USE_SDK` is not set. 894 if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' in os.environ and 'DISTUTILS_USE_SDK' not in os.environ: 895 msg = ('It seems that the VC environment is activated but DISTUTILS_USE_SDK is not set.' 896 'This may lead to multiple activations of the VC env.' 897 'Please set `DISTUTILS_USE_SDK=1` and try again.') 898 raise UserWarning(msg) 899 return compiler, version 900 901 def _add_compile_flag(self, extension, flag): 902 extension.extra_compile_args = copy.deepcopy(extension.extra_compile_args) 903 if isinstance(extension.extra_compile_args, dict): 904 for args in extension.extra_compile_args.values(): 905 args.append(flag) 906 else: 907 extension.extra_compile_args.append(flag) 908 909 def _define_torch_extension_name(self, extension): 910 # pybind11 doesn't support dots in the names 911 # so in order to support extensions in the packages 912 # like torch._C, we take the last part of the string 913 # as the library name 914 names = extension.name.split('.') 915 name = names[-1] 916 define = f'-DTORCH_EXTENSION_NAME={name}' 917 self._add_compile_flag(extension, define) 918 919 def _add_gnu_cpp_abi_flag(self, extension): 920 # use the same CXX ABI as what PyTorch was compiled with 921 self._add_compile_flag(extension, '-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))) 922 923 924def CppExtension(name, sources, *args, **kwargs): 925 """ 926 Create a :class:`setuptools.Extension` for C++. 927 928 Convenience method that creates a :class:`setuptools.Extension` with the 929 bare minimum (but often sufficient) arguments to build a C++ extension. 930 931 All arguments are forwarded to the :class:`setuptools.Extension` 932 constructor. Full list arguments can be found at 933 https://setuptools.pypa.io/en/latest/userguide/ext_modules.html#extension-api-reference 934 935 Example: 936 >>> # xdoctest: +SKIP 937 >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT) 938 >>> from setuptools import setup 939 >>> from torch.utils.cpp_extension import BuildExtension, CppExtension 940 >>> setup( 941 ... name='extension', 942 ... ext_modules=[ 943 ... CppExtension( 944 ... name='extension', 945 ... sources=['extension.cpp'], 946 ... extra_compile_args=['-g'], 947 ... extra_link_args=['-Wl,--no-as-needed', '-lm']) 948 ... ], 949 ... cmdclass={ 950 ... 'build_ext': BuildExtension 951 ... }) 952 """ 953 include_dirs = kwargs.get('include_dirs', []) 954 include_dirs += include_paths() 955 kwargs['include_dirs'] = include_dirs 956 957 library_dirs = kwargs.get('library_dirs', []) 958 library_dirs += library_paths() 959 kwargs['library_dirs'] = library_dirs 960 961 libraries = kwargs.get('libraries', []) 962 libraries.append('c10') 963 libraries.append('torch') 964 libraries.append('torch_cpu') 965 libraries.append('torch_python') 966 if IS_WINDOWS: 967 libraries.append("sleef") 968 969 kwargs['libraries'] = libraries 970 971 kwargs['language'] = 'c++' 972 return setuptools.Extension(name, sources, *args, **kwargs) 973 974 975def CUDAExtension(name, sources, *args, **kwargs): 976 """ 977 Create a :class:`setuptools.Extension` for CUDA/C++. 978 979 Convenience method that creates a :class:`setuptools.Extension` with the 980 bare minimum (but often sufficient) arguments to build a CUDA/C++ 981 extension. This includes the CUDA include path, library path and runtime 982 library. 983 984 All arguments are forwarded to the :class:`setuptools.Extension` 985 constructor. Full list arguments can be found at 986 https://setuptools.pypa.io/en/latest/userguide/ext_modules.html#extension-api-reference 987 988 Example: 989 >>> # xdoctest: +SKIP 990 >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT) 991 >>> from setuptools import setup 992 >>> from torch.utils.cpp_extension import BuildExtension, CUDAExtension 993 >>> setup( 994 ... name='cuda_extension', 995 ... ext_modules=[ 996 ... CUDAExtension( 997 ... name='cuda_extension', 998 ... sources=['extension.cpp', 'extension_kernel.cu'], 999 ... extra_compile_args={'cxx': ['-g'], 1000 ... 'nvcc': ['-O2']}, 1001 ... extra_link_args=['-Wl,--no-as-needed', '-lcuda']) 1002 ... ], 1003 ... cmdclass={ 1004 ... 'build_ext': BuildExtension 1005 ... }) 1006 1007 Compute capabilities: 1008 1009 By default the extension will be compiled to run on all archs of the cards visible during the 1010 building process of the extension, plus PTX. If down the road a new card is installed the 1011 extension may need to be recompiled. If a visible card has a compute capability (CC) that's 1012 newer than the newest version for which your nvcc can build fully-compiled binaries, Pytorch 1013 will make nvcc fall back to building kernels with the newest version of PTX your nvcc does 1014 support (see below for details on PTX). 1015 1016 You can override the default behavior using `TORCH_CUDA_ARCH_LIST` to explicitly specify which 1017 CCs you want the extension to support: 1018 1019 ``TORCH_CUDA_ARCH_LIST="6.1 8.6" python build_my_extension.py`` 1020 ``TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" python build_my_extension.py`` 1021 1022 The +PTX option causes extension kernel binaries to include PTX instructions for the specified 1023 CC. PTX is an intermediate representation that allows kernels to runtime-compile for any CC >= 1024 the specified CC (for example, 8.6+PTX generates PTX that can runtime-compile for any GPU with 1025 CC >= 8.6). This improves your binary's forward compatibility. However, relying on older PTX to 1026 provide forward compat by runtime-compiling for newer CCs can modestly reduce performance on 1027 those newer CCs. If you know exact CC(s) of the GPUs you want to target, you're always better 1028 off specifying them individually. For example, if you want your extension to run on 8.0 and 8.6, 1029 "8.0+PTX" would work functionally because it includes PTX that can runtime-compile for 8.6, but 1030 "8.0 8.6" would be better. 1031 1032 Note that while it's possible to include all supported archs, the more archs get included the 1033 slower the building process will be, as it will build a separate kernel image for each arch. 1034 1035 Note that CUDA-11.5 nvcc will hit internal compiler error while parsing torch/extension.h on Windows. 1036 To workaround the issue, move python binding logic to pure C++ file. 1037 1038 Example use: 1039 #include <ATen/ATen.h> 1040 at::Tensor SigmoidAlphaBlendForwardCuda(....) 1041 1042 Instead of: 1043 #include <torch/extension.h> 1044 torch::Tensor SigmoidAlphaBlendForwardCuda(...) 1045 1046 Currently open issue for nvcc bug: https://github.com/pytorch/pytorch/issues/69460 1047 Complete workaround code example: https://github.com/facebookresearch/pytorch3d/commit/cb170ac024a949f1f9614ffe6af1c38d972f7d48 1048 1049 Relocatable device code linking: 1050 1051 If you want to reference device symbols across compilation units (across object files), 1052 the object files need to be built with `relocatable device code` (-rdc=true or -dc). 1053 An exception to this rule is "dynamic parallelism" (nested kernel launches) which is not used a lot anymore. 1054 `Relocatable device code` is less optimized so it needs to be used only on object files that need it. 1055 Using `-dlto` (Device Link Time Optimization) at the device code compilation step and `dlink` step 1056 help reduce the protentional perf degradation of `-rdc`. 1057 Note that it needs to be used at both steps to be useful. 1058 1059 If you have `rdc` objects you need to have an extra `-dlink` (device linking) step before the CPU symbol linking step. 1060 There is also a case where `-dlink` is used without `-rdc`: 1061 when an extension is linked against a static lib containing rdc-compiled objects 1062 like the [NVSHMEM library](https://developer.nvidia.com/nvshmem). 1063 1064 Note: Ninja is required to build a CUDA Extension with RDC linking. 1065 1066 Example: 1067 >>> # xdoctest: +SKIP 1068 >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT) 1069 >>> CUDAExtension( 1070 ... name='cuda_extension', 1071 ... sources=['extension.cpp', 'extension_kernel.cu'], 1072 ... dlink=True, 1073 ... dlink_libraries=["dlink_lib"], 1074 ... extra_compile_args={'cxx': ['-g'], 1075 ... 'nvcc': ['-O2', '-rdc=true']}) 1076 """ 1077 library_dirs = kwargs.get('library_dirs', []) 1078 library_dirs += library_paths(cuda=True) 1079 kwargs['library_dirs'] = library_dirs 1080 1081 libraries = kwargs.get('libraries', []) 1082 libraries.append('c10') 1083 libraries.append('torch') 1084 libraries.append('torch_cpu') 1085 libraries.append('torch_python') 1086 if IS_HIP_EXTENSION: 1087 libraries.append('amdhip64') 1088 libraries.append('c10_hip') 1089 libraries.append('torch_hip') 1090 else: 1091 libraries.append('cudart') 1092 libraries.append('c10_cuda') 1093 libraries.append('torch_cuda') 1094 kwargs['libraries'] = libraries 1095 1096 include_dirs = kwargs.get('include_dirs', []) 1097 1098 if IS_HIP_EXTENSION: 1099 build_dir = os.getcwd() 1100 hipify_result = hipify_python.hipify( 1101 project_directory=build_dir, 1102 output_directory=build_dir, 1103 header_include_dirs=include_dirs, 1104 includes=[os.path.join(build_dir, '*')], # limit scope to build_dir only 1105 extra_files=[os.path.abspath(s) for s in sources], 1106 show_detailed=True, 1107 is_pytorch_extension=True, 1108 hipify_extra_files_only=True, # don't hipify everything in includes path 1109 ) 1110 1111 hipified_sources = set() 1112 for source in sources: 1113 s_abs = os.path.abspath(source) 1114 hipified_s_abs = (hipify_result[s_abs].hipified_path if (s_abs in hipify_result and 1115 hipify_result[s_abs].hipified_path is not None) else s_abs) 1116 # setup() arguments must *always* be /-separated paths relative to the setup.py directory, 1117 # *never* absolute paths 1118 hipified_sources.add(os.path.relpath(hipified_s_abs, build_dir)) 1119 1120 sources = list(hipified_sources) 1121 1122 include_dirs += include_paths(cuda=True) 1123 kwargs['include_dirs'] = include_dirs 1124 1125 kwargs['language'] = 'c++' 1126 1127 dlink_libraries = kwargs.get('dlink_libraries', []) 1128 dlink = kwargs.get('dlink', False) or dlink_libraries 1129 if dlink: 1130 extra_compile_args = kwargs.get('extra_compile_args', {}) 1131 1132 extra_compile_args_dlink = extra_compile_args.get('nvcc_dlink', []) 1133 extra_compile_args_dlink += ['-dlink'] 1134 extra_compile_args_dlink += [f'-L{x}' for x in library_dirs] 1135 extra_compile_args_dlink += [f'-l{x}' for x in dlink_libraries] 1136 1137 if (torch.version.cuda is not None) and TorchVersion(torch.version.cuda) >= '11.2': 1138 extra_compile_args_dlink += ['-dlto'] # Device Link Time Optimization started from cuda 11.2 1139 1140 extra_compile_args['nvcc_dlink'] = extra_compile_args_dlink 1141 1142 kwargs['extra_compile_args'] = extra_compile_args 1143 1144 return setuptools.Extension(name, sources, *args, **kwargs) 1145 1146 1147def include_paths(cuda: bool = False) -> List[str]: 1148 """ 1149 Get the include paths required to build a C++ or CUDA extension. 1150 1151 Args: 1152 cuda: If `True`, includes CUDA-specific include paths. 1153 1154 Returns: 1155 A list of include path strings. 1156 """ 1157 lib_include = os.path.join(_TORCH_PATH, 'include') 1158 paths = [ 1159 lib_include, 1160 # Remove this once torch/torch.h is officially no longer supported for C++ extensions. 1161 os.path.join(lib_include, 'torch', 'csrc', 'api', 'include'), 1162 # Some internal (old) Torch headers don't properly prefix their includes, 1163 # so we need to pass -Itorch/lib/include/TH as well. 1164 os.path.join(lib_include, 'TH'), 1165 os.path.join(lib_include, 'THC') 1166 ] 1167 if cuda and IS_HIP_EXTENSION: 1168 paths.append(os.path.join(lib_include, 'THH')) 1169 paths.append(_join_rocm_home('include')) 1170 elif cuda: 1171 cuda_home_include = _join_cuda_home('include') 1172 # if we have the Debian/Ubuntu packages for cuda, we get /usr as cuda home. 1173 # but gcc doesn't like having /usr/include passed explicitly 1174 if cuda_home_include != '/usr/include': 1175 paths.append(cuda_home_include) 1176 1177 # Support CUDA_INC_PATH env variable supported by CMake files 1178 if (cuda_inc_path := os.environ.get("CUDA_INC_PATH", None)) and \ 1179 cuda_inc_path != '/usr/include': 1180 paths.append(cuda_inc_path) 1181 if CUDNN_HOME is not None: 1182 paths.append(os.path.join(CUDNN_HOME, 'include')) 1183 return paths 1184 1185 1186def library_paths(cuda: bool = False) -> List[str]: 1187 """ 1188 Get the library paths required to build a C++ or CUDA extension. 1189 1190 Args: 1191 cuda: If `True`, includes CUDA-specific library paths. 1192 1193 Returns: 1194 A list of library path strings. 1195 """ 1196 # We need to link against libtorch.so 1197 paths = [TORCH_LIB_PATH] 1198 1199 if cuda and IS_HIP_EXTENSION: 1200 lib_dir = 'lib' 1201 paths.append(_join_rocm_home(lib_dir)) 1202 if HIP_HOME is not None: 1203 paths.append(os.path.join(HIP_HOME, 'lib')) 1204 elif cuda: 1205 if IS_WINDOWS: 1206 lib_dir = os.path.join('lib', 'x64') 1207 else: 1208 lib_dir = 'lib64' 1209 if (not os.path.exists(_join_cuda_home(lib_dir)) and 1210 os.path.exists(_join_cuda_home('lib'))): 1211 # 64-bit CUDA may be installed in 'lib' (see e.g. gh-16955) 1212 # Note that it's also possible both don't exist (see 1213 # _find_cuda_home) - in that case we stay with 'lib64'. 1214 lib_dir = 'lib' 1215 1216 paths.append(_join_cuda_home(lib_dir)) 1217 if CUDNN_HOME is not None: 1218 paths.append(os.path.join(CUDNN_HOME, lib_dir)) 1219 return paths 1220 1221 1222def load(name, 1223 sources: Union[str, List[str]], 1224 extra_cflags=None, 1225 extra_cuda_cflags=None, 1226 extra_ldflags=None, 1227 extra_include_paths=None, 1228 build_directory=None, 1229 verbose=False, 1230 with_cuda: Optional[bool] = None, 1231 is_python_module=True, 1232 is_standalone=False, 1233 keep_intermediates=True): 1234 """ 1235 Load a PyTorch C++ extension just-in-time (JIT). 1236 1237 To load an extension, a Ninja build file is emitted, which is used to 1238 compile the given sources into a dynamic library. This library is 1239 subsequently loaded into the current Python process as a module and 1240 returned from this function, ready for use. 1241 1242 By default, the directory to which the build file is emitted and the 1243 resulting library compiled to is ``<tmp>/torch_extensions/<name>``, where 1244 ``<tmp>`` is the temporary folder on the current platform and ``<name>`` 1245 the name of the extension. This location can be overridden in two ways. 1246 First, if the ``TORCH_EXTENSIONS_DIR`` environment variable is set, it 1247 replaces ``<tmp>/torch_extensions`` and all extensions will be compiled 1248 into subfolders of this directory. Second, if the ``build_directory`` 1249 argument to this function is supplied, it overrides the entire path, i.e. 1250 the library will be compiled into that folder directly. 1251 1252 To compile the sources, the default system compiler (``c++``) is used, 1253 which can be overridden by setting the ``CXX`` environment variable. To pass 1254 additional arguments to the compilation process, ``extra_cflags`` or 1255 ``extra_ldflags`` can be provided. For example, to compile your extension 1256 with optimizations, pass ``extra_cflags=['-O3']``. You can also use 1257 ``extra_cflags`` to pass further include directories. 1258 1259 CUDA support with mixed compilation is provided. Simply pass CUDA source 1260 files (``.cu`` or ``.cuh``) along with other sources. Such files will be 1261 detected and compiled with nvcc rather than the C++ compiler. This includes 1262 passing the CUDA lib64 directory as a library directory, and linking 1263 ``cudart``. You can pass additional flags to nvcc via 1264 ``extra_cuda_cflags``, just like with ``extra_cflags`` for C++. Various 1265 heuristics for finding the CUDA install directory are used, which usually 1266 work fine. If not, setting the ``CUDA_HOME`` environment variable is the 1267 safest option. 1268 1269 Args: 1270 name: The name of the extension to build. This MUST be the same as the 1271 name of the pybind11 module! 1272 sources: A list of relative or absolute paths to C++ source files. 1273 extra_cflags: optional list of compiler flags to forward to the build. 1274 extra_cuda_cflags: optional list of compiler flags to forward to nvcc 1275 when building CUDA sources. 1276 extra_ldflags: optional list of linker flags to forward to the build. 1277 extra_include_paths: optional list of include directories to forward 1278 to the build. 1279 build_directory: optional path to use as build workspace. 1280 verbose: If ``True``, turns on verbose logging of load steps. 1281 with_cuda: Determines whether CUDA headers and libraries are added to 1282 the build. If set to ``None`` (default), this value is 1283 automatically determined based on the existence of ``.cu`` or 1284 ``.cuh`` in ``sources``. Set it to `True`` to force CUDA headers 1285 and libraries to be included. 1286 is_python_module: If ``True`` (default), imports the produced shared 1287 library as a Python module. If ``False``, behavior depends on 1288 ``is_standalone``. 1289 is_standalone: If ``False`` (default) loads the constructed extension 1290 into the process as a plain dynamic library. If ``True``, build a 1291 standalone executable. 1292 1293 Returns: 1294 If ``is_python_module`` is ``True``: 1295 Returns the loaded PyTorch extension as a Python module. 1296 1297 If ``is_python_module`` is ``False`` and ``is_standalone`` is ``False``: 1298 Returns nothing. (The shared library is loaded into the process as 1299 a side effect.) 1300 1301 If ``is_standalone`` is ``True``. 1302 Return the path to the executable. (On Windows, TORCH_LIB_PATH is 1303 added to the PATH environment variable as a side effect.) 1304 1305 Example: 1306 >>> # xdoctest: +SKIP 1307 >>> from torch.utils.cpp_extension import load 1308 >>> module = load( 1309 ... name='extension', 1310 ... sources=['extension.cpp', 'extension_kernel.cu'], 1311 ... extra_cflags=['-O2'], 1312 ... verbose=True) 1313 """ 1314 return _jit_compile( 1315 name, 1316 [sources] if isinstance(sources, str) else sources, 1317 extra_cflags, 1318 extra_cuda_cflags, 1319 extra_ldflags, 1320 extra_include_paths, 1321 build_directory or _get_build_directory(name, verbose), 1322 verbose, 1323 with_cuda, 1324 is_python_module, 1325 is_standalone, 1326 keep_intermediates=keep_intermediates) 1327 1328def _get_pybind11_abi_build_flags(): 1329 # Note [Pybind11 ABI constants] 1330 # 1331 # Pybind11 before 2.4 used to build an ABI strings using the following pattern: 1332 # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_BUILD_TYPE}__" 1333 # Since 2.4 compier type, stdlib and build abi parameters are also encoded like this: 1334 # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_COMPILER_TYPE}{PYBIND11_STDLIB}{PYBIND11_BUILD_ABI}{PYBIND11_BUILD_TYPE}__" 1335 # 1336 # This was done in order to further narrow down the chances of compiler ABI incompatibility 1337 # that can cause a hard to debug segfaults. 1338 # For PyTorch extensions we want to relax those restrictions and pass compiler, stdlib and abi properties 1339 # captured during PyTorch native library compilation in torch/csrc/Module.cpp 1340 1341 abi_cflags = [] 1342 for pname in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]: 1343 pval = getattr(torch._C, f"_PYBIND11_{pname}") 1344 if pval is not None and not IS_WINDOWS: 1345 abi_cflags.append(f'-DPYBIND11_{pname}=\\"{pval}\\"') 1346 return abi_cflags 1347 1348def _get_glibcxx_abi_build_flags(): 1349 glibcxx_abi_cflags = ['-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))] 1350 return glibcxx_abi_cflags 1351 1352def check_compiler_is_gcc(compiler): 1353 if not IS_LINUX: 1354 return False 1355 1356 env = os.environ.copy() 1357 env['LC_ALL'] = 'C' # Don't localize output 1358 try: 1359 version_string = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS) 1360 except Exception as e: 1361 try: 1362 version_string = subprocess.check_output([compiler, '--version'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS) 1363 except Exception as e: 1364 return False 1365 # Check for 'gcc' or 'g++' for sccache wrapper 1366 pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE) 1367 results = re.findall(pattern, version_string) 1368 if len(results) != 1: 1369 return False 1370 compiler_path = os.path.realpath(results[0].strip()) 1371 # On RHEL/CentOS c++ is a gcc compiler wrapper 1372 if os.path.basename(compiler_path) == 'c++' and 'gcc version' in version_string: 1373 return True 1374 return False 1375 1376def _check_and_build_extension_h_precompiler_headers( 1377 extra_cflags, 1378 extra_include_paths, 1379 is_standalone=False): 1380 r''' 1381 Precompiled Headers(PCH) can pre-build the same headers and reduce build time for pytorch load_inline modules. 1382 GCC offical manual: https://gcc.gnu.org/onlinedocs/gcc-4.0.4/gcc/Precompiled-Headers.html 1383 PCH only works when built pch file(header.h.gch) and build target have the same build parameters. So, We need 1384 add a signature file to record PCH file parameters. If the build parameters(signature) changed, it should rebuild 1385 PCH file. 1386 1387 Note: 1388 1. Windows and MacOS have different PCH mechanism. We only support Linux currently. 1389 2. It only works on GCC/G++. 1390 ''' 1391 if not IS_LINUX: 1392 return 1393 1394 compiler = get_cxx_compiler() 1395 1396 b_is_gcc = check_compiler_is_gcc(compiler) 1397 if b_is_gcc is False: 1398 return 1399 1400 head_file = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h') 1401 head_file_pch = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h.gch') 1402 head_file_signature = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h.sign') 1403 1404 def listToString(s): 1405 # initialize an empty string 1406 string = "" 1407 if s is None: 1408 return string 1409 1410 # traverse in the string 1411 for element in s: 1412 string += (element + ' ') 1413 # return string 1414 return string 1415 1416 def format_precompiler_header_cmd(compiler, head_file, head_file_pch, common_cflags, torch_include_dirs, extra_cflags, extra_include_paths): 1417 return re.sub( 1418 r"[ \n]+", 1419 " ", 1420 f""" 1421 {compiler} -x c++-header {head_file} -o {head_file_pch} {torch_include_dirs} {extra_include_paths} {extra_cflags} {common_cflags} 1422 """, 1423 ).strip() 1424 1425 def command_to_signature(cmd): 1426 signature = cmd.replace(' ', '_') 1427 return signature 1428 1429 def check_pch_signature_in_file(file_path, signature): 1430 b_exist = os.path.isfile(file_path) 1431 if b_exist is False: 1432 return False 1433 1434 with open(file_path) as file: 1435 # read all content of a file 1436 content = file.read() 1437 # check if string present in a file 1438 return signature == content 1439 1440 def _create_if_not_exist(path_dir): 1441 if not os.path.exists(path_dir): 1442 try: 1443 Path(path_dir).mkdir(parents=True, exist_ok=True) 1444 except OSError as exc: # Guard against race condition 1445 if exc.errno != errno.EEXIST: 1446 raise RuntimeError(f"Fail to create path {path_dir}") from exc 1447 1448 def write_pch_signature_to_file(file_path, pch_sign): 1449 _create_if_not_exist(os.path.dirname(file_path)) 1450 with open(file_path, "w") as f: 1451 f.write(pch_sign) 1452 f.close() 1453 1454 def build_precompile_header(pch_cmd): 1455 try: 1456 subprocess.check_output(pch_cmd, shell=True, stderr=subprocess.STDOUT) 1457 except subprocess.CalledProcessError as e: 1458 raise RuntimeError(f"Compile PreCompile Header fail, command: {pch_cmd}") from e 1459 1460 extra_cflags_str = listToString(extra_cflags) 1461 extra_include_paths_str = " ".join( 1462 [f"-I{include}" for include in extra_include_paths] if extra_include_paths else [] 1463 ) 1464 1465 lib_include = os.path.join(_TORCH_PATH, 'include') 1466 torch_include_dirs = [ 1467 f"-I {lib_include}", 1468 # Python.h 1469 "-I {}".format(sysconfig.get_path("include")), 1470 # torch/all.h 1471 "-I {}".format(os.path.join(lib_include, 'torch', 'csrc', 'api', 'include')), 1472 ] 1473 1474 torch_include_dirs_str = listToString(torch_include_dirs) 1475 1476 common_cflags = [] 1477 if not is_standalone: 1478 common_cflags += ['-DTORCH_API_INCLUDE_EXTENSION_H'] 1479 1480 common_cflags += ['-std=c++17', '-fPIC'] 1481 common_cflags += [f"{x}" for x in _get_pybind11_abi_build_flags()] 1482 common_cflags += [f"{x}" for x in _get_glibcxx_abi_build_flags()] 1483 common_cflags_str = listToString(common_cflags) 1484 1485 pch_cmd = format_precompiler_header_cmd(compiler, head_file, head_file_pch, common_cflags_str, torch_include_dirs_str, extra_cflags_str, extra_include_paths_str) 1486 pch_sign = command_to_signature(pch_cmd) 1487 1488 if os.path.isfile(head_file_pch) is not True: 1489 build_precompile_header(pch_cmd) 1490 write_pch_signature_to_file(head_file_signature, pch_sign) 1491 else: 1492 b_same_sign = check_pch_signature_in_file(head_file_signature, pch_sign) 1493 if b_same_sign is False: 1494 build_precompile_header(pch_cmd) 1495 write_pch_signature_to_file(head_file_signature, pch_sign) 1496 1497def remove_extension_h_precompiler_headers(): 1498 def _remove_if_file_exists(path_file): 1499 if os.path.exists(path_file): 1500 os.remove(path_file) 1501 1502 head_file_pch = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h.gch') 1503 head_file_signature = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h.sign') 1504 1505 _remove_if_file_exists(head_file_pch) 1506 _remove_if_file_exists(head_file_signature) 1507 1508def load_inline(name, 1509 cpp_sources, 1510 cuda_sources=None, 1511 functions=None, 1512 extra_cflags=None, 1513 extra_cuda_cflags=None, 1514 extra_ldflags=None, 1515 extra_include_paths=None, 1516 build_directory=None, 1517 verbose=False, 1518 with_cuda=None, 1519 is_python_module=True, 1520 with_pytorch_error_handling=True, 1521 keep_intermediates=True, 1522 use_pch=False): 1523 r''' 1524 Load a PyTorch C++ extension just-in-time (JIT) from string sources. 1525 1526 This function behaves exactly like :func:`load`, but takes its sources as 1527 strings rather than filenames. These strings are stored to files in the 1528 build directory, after which the behavior of :func:`load_inline` is 1529 identical to :func:`load`. 1530 1531 See `the 1532 tests <https://github.com/pytorch/pytorch/blob/master/test/test_cpp_extensions_jit.py>`_ 1533 for good examples of using this function. 1534 1535 Sources may omit two required parts of a typical non-inline C++ extension: 1536 the necessary header includes, as well as the (pybind11) binding code. More 1537 precisely, strings passed to ``cpp_sources`` are first concatenated into a 1538 single ``.cpp`` file. This file is then prepended with ``#include 1539 <torch/extension.h>``. 1540 1541 Furthermore, if the ``functions`` argument is supplied, bindings will be 1542 automatically generated for each function specified. ``functions`` can 1543 either be a list of function names, or a dictionary mapping from function 1544 names to docstrings. If a list is given, the name of each function is used 1545 as its docstring. 1546 1547 The sources in ``cuda_sources`` are concatenated into a separate ``.cu`` 1548 file and prepended with ``torch/types.h``, ``cuda.h`` and 1549 ``cuda_runtime.h`` includes. The ``.cpp`` and ``.cu`` files are compiled 1550 separately, but ultimately linked into a single library. Note that no 1551 bindings are generated for functions in ``cuda_sources`` per se. To bind 1552 to a CUDA kernel, you must create a C++ function that calls it, and either 1553 declare or define this C++ function in one of the ``cpp_sources`` (and 1554 include its name in ``functions``). 1555 1556 See :func:`load` for a description of arguments omitted below. 1557 1558 Args: 1559 cpp_sources: A string, or list of strings, containing C++ source code. 1560 cuda_sources: A string, or list of strings, containing CUDA source code. 1561 functions: A list of function names for which to generate function 1562 bindings. If a dictionary is given, it should map function names to 1563 docstrings (which are otherwise just the function names). 1564 with_cuda: Determines whether CUDA headers and libraries are added to 1565 the build. If set to ``None`` (default), this value is 1566 automatically determined based on whether ``cuda_sources`` is 1567 provided. Set it to ``True`` to force CUDA headers 1568 and libraries to be included. 1569 with_pytorch_error_handling: Determines whether pytorch error and 1570 warning macros are handled by pytorch instead of pybind. To do 1571 this, each function ``foo`` is called via an intermediary ``_safe_foo`` 1572 function. This redirection might cause issues in obscure cases 1573 of cpp. This flag should be set to ``False`` when this redirect 1574 causes issues. 1575 1576 Example: 1577 >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT) 1578 >>> from torch.utils.cpp_extension import load_inline 1579 >>> source = """ 1580 at::Tensor sin_add(at::Tensor x, at::Tensor y) { 1581 return x.sin() + y.sin(); 1582 } 1583 """ 1584 >>> module = load_inline(name='inline_extension', 1585 ... cpp_sources=[source], 1586 ... functions=['sin_add']) 1587 1588 .. note:: 1589 By default, the Ninja backend uses #CPUS + 2 workers to build the 1590 extension. This may use up too many resources on some systems. One 1591 can control the number of workers by setting the `MAX_JOBS` environment 1592 variable to a non-negative number. 1593 ''' 1594 build_directory = build_directory or _get_build_directory(name, verbose) 1595 1596 if isinstance(cpp_sources, str): 1597 cpp_sources = [cpp_sources] 1598 cuda_sources = cuda_sources or [] 1599 if isinstance(cuda_sources, str): 1600 cuda_sources = [cuda_sources] 1601 1602 cpp_sources.insert(0, '#include <torch/extension.h>') 1603 1604 if use_pch is True: 1605 # Using PreCompile Header('torch/extension.h') to reduce compile time. 1606 _check_and_build_extension_h_precompiler_headers(extra_cflags, extra_include_paths) 1607 else: 1608 remove_extension_h_precompiler_headers() 1609 1610 # If `functions` is supplied, we create the pybind11 bindings for the user. 1611 # Here, `functions` is (or becomes, after some processing) a map from 1612 # function names to function docstrings. 1613 if functions is not None: 1614 module_def = [] 1615 module_def.append('PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {') 1616 if isinstance(functions, str): 1617 functions = [functions] 1618 if isinstance(functions, list): 1619 # Make the function docstring the same as the function name. 1620 functions = {f: f for f in functions} 1621 elif not isinstance(functions, dict): 1622 raise ValueError(f"Expected 'functions' to be a list or dict, but was {type(functions)}") 1623 for function_name, docstring in functions.items(): 1624 if with_pytorch_error_handling: 1625 module_def.append(f'm.def("{function_name}", torch::wrap_pybind_function({function_name}), "{docstring}");') 1626 else: 1627 module_def.append(f'm.def("{function_name}", {function_name}, "{docstring}");') 1628 module_def.append('}') 1629 cpp_sources += module_def 1630 1631 cpp_source_path = os.path.join(build_directory, 'main.cpp') 1632 _maybe_write(cpp_source_path, "\n".join(cpp_sources)) 1633 1634 sources = [cpp_source_path] 1635 1636 if cuda_sources: 1637 cuda_sources.insert(0, '#include <torch/types.h>') 1638 cuda_sources.insert(1, '#include <cuda.h>') 1639 cuda_sources.insert(2, '#include <cuda_runtime.h>') 1640 1641 cuda_source_path = os.path.join(build_directory, 'cuda.cu') 1642 _maybe_write(cuda_source_path, "\n".join(cuda_sources)) 1643 1644 sources.append(cuda_source_path) 1645 1646 return _jit_compile( 1647 name, 1648 sources, 1649 extra_cflags, 1650 extra_cuda_cflags, 1651 extra_ldflags, 1652 extra_include_paths, 1653 build_directory, 1654 verbose, 1655 with_cuda, 1656 is_python_module, 1657 is_standalone=False, 1658 keep_intermediates=keep_intermediates) 1659 1660 1661def _jit_compile(name, 1662 sources, 1663 extra_cflags, 1664 extra_cuda_cflags, 1665 extra_ldflags, 1666 extra_include_paths, 1667 build_directory: str, 1668 verbose: bool, 1669 with_cuda: Optional[bool], 1670 is_python_module, 1671 is_standalone, 1672 keep_intermediates=True) -> None: 1673 if is_python_module and is_standalone: 1674 raise ValueError("`is_python_module` and `is_standalone` are mutually exclusive.") 1675 1676 if with_cuda is None: 1677 with_cuda = any(map(_is_cuda_file, sources)) 1678 with_cudnn = any('cudnn' in f for f in extra_ldflags or []) 1679 old_version = JIT_EXTENSION_VERSIONER.get_version(name) 1680 version = JIT_EXTENSION_VERSIONER.bump_version_if_changed( 1681 name, 1682 sources, 1683 build_arguments=[extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths], 1684 build_directory=build_directory, 1685 with_cuda=with_cuda, 1686 is_python_module=is_python_module, 1687 is_standalone=is_standalone, 1688 ) 1689 if version > 0: 1690 if version != old_version and verbose: 1691 print(f'The input conditions for extension module {name} have changed. ' + 1692 f'Bumping to version {version} and re-building as {name}_v{version}...', 1693 file=sys.stderr) 1694 name = f'{name}_v{version}' 1695 1696 baton = FileBaton(os.path.join(build_directory, 'lock')) 1697 if baton.try_acquire(): 1698 try: 1699 if version != old_version: 1700 with GeneratedFileCleaner(keep_intermediates=keep_intermediates) as clean_ctx: 1701 if IS_HIP_EXTENSION and (with_cuda or with_cudnn): 1702 hipify_result = hipify_python.hipify( 1703 project_directory=build_directory, 1704 output_directory=build_directory, 1705 header_include_dirs=(extra_include_paths if extra_include_paths is not None else []), 1706 extra_files=[os.path.abspath(s) for s in sources], 1707 ignores=[_join_rocm_home('*'), os.path.join(_TORCH_PATH, '*')], # no need to hipify ROCm or PyTorch headers 1708 show_detailed=verbose, 1709 show_progress=verbose, 1710 is_pytorch_extension=True, 1711 clean_ctx=clean_ctx 1712 ) 1713 1714 hipified_sources = set() 1715 for source in sources: 1716 s_abs = os.path.abspath(source) 1717 hipified_sources.add(hipify_result[s_abs].hipified_path if s_abs in hipify_result else s_abs) 1718 1719 sources = list(hipified_sources) 1720 1721 _write_ninja_file_and_build_library( 1722 name=name, 1723 sources=sources, 1724 extra_cflags=extra_cflags or [], 1725 extra_cuda_cflags=extra_cuda_cflags or [], 1726 extra_ldflags=extra_ldflags or [], 1727 extra_include_paths=extra_include_paths or [], 1728 build_directory=build_directory, 1729 verbose=verbose, 1730 with_cuda=with_cuda, 1731 is_standalone=is_standalone) 1732 elif verbose: 1733 print('No modifications detected for re-loaded extension ' 1734 f'module {name}, skipping build step...', file=sys.stderr) 1735 finally: 1736 baton.release() 1737 else: 1738 baton.wait() 1739 1740 if verbose: 1741 print(f'Loading extension module {name}...', file=sys.stderr) 1742 1743 if is_standalone: 1744 return _get_exec_path(name, build_directory) 1745 1746 return _import_module_from_library(name, build_directory, is_python_module) 1747 1748 1749def _write_ninja_file_and_compile_objects( 1750 sources: List[str], 1751 objects, 1752 cflags, 1753 post_cflags, 1754 cuda_cflags, 1755 cuda_post_cflags, 1756 cuda_dlink_post_cflags, 1757 build_directory: str, 1758 verbose: bool, 1759 with_cuda: Optional[bool]) -> None: 1760 verify_ninja_availability() 1761 1762 compiler = get_cxx_compiler() 1763 1764 get_compiler_abi_compatibility_and_version(compiler) 1765 if with_cuda is None: 1766 with_cuda = any(map(_is_cuda_file, sources)) 1767 build_file_path = os.path.join(build_directory, 'build.ninja') 1768 if verbose: 1769 print(f'Emitting ninja build file {build_file_path}...', file=sys.stderr) 1770 _write_ninja_file( 1771 path=build_file_path, 1772 cflags=cflags, 1773 post_cflags=post_cflags, 1774 cuda_cflags=cuda_cflags, 1775 cuda_post_cflags=cuda_post_cflags, 1776 cuda_dlink_post_cflags=cuda_dlink_post_cflags, 1777 sources=sources, 1778 objects=objects, 1779 ldflags=None, 1780 library_target=None, 1781 with_cuda=with_cuda) 1782 if verbose: 1783 print('Compiling objects...', file=sys.stderr) 1784 _run_ninja_build( 1785 build_directory, 1786 verbose, 1787 # It would be better if we could tell users the name of the extension 1788 # that failed to build but there isn't a good way to get it here. 1789 error_prefix='Error compiling objects for extension') 1790 1791 1792def _write_ninja_file_and_build_library( 1793 name, 1794 sources: List[str], 1795 extra_cflags, 1796 extra_cuda_cflags, 1797 extra_ldflags, 1798 extra_include_paths, 1799 build_directory: str, 1800 verbose: bool, 1801 with_cuda: Optional[bool], 1802 is_standalone: bool = False) -> None: 1803 verify_ninja_availability() 1804 1805 compiler = get_cxx_compiler() 1806 1807 get_compiler_abi_compatibility_and_version(compiler) 1808 if with_cuda is None: 1809 with_cuda = any(map(_is_cuda_file, sources)) 1810 extra_ldflags = _prepare_ldflags( 1811 extra_ldflags or [], 1812 with_cuda, 1813 verbose, 1814 is_standalone) 1815 build_file_path = os.path.join(build_directory, 'build.ninja') 1816 if verbose: 1817 print(f'Emitting ninja build file {build_file_path}...', file=sys.stderr) 1818 # NOTE: Emitting a new ninja build file does not cause re-compilation if 1819 # the sources did not change, so it's ok to re-emit (and it's fast). 1820 _write_ninja_file_to_build_library( 1821 path=build_file_path, 1822 name=name, 1823 sources=sources, 1824 extra_cflags=extra_cflags or [], 1825 extra_cuda_cflags=extra_cuda_cflags or [], 1826 extra_ldflags=extra_ldflags or [], 1827 extra_include_paths=extra_include_paths or [], 1828 with_cuda=with_cuda, 1829 is_standalone=is_standalone) 1830 1831 if verbose: 1832 print(f'Building extension module {name}...', file=sys.stderr) 1833 _run_ninja_build( 1834 build_directory, 1835 verbose, 1836 error_prefix=f"Error building extension '{name}'") 1837 1838 1839def is_ninja_available(): 1840 """Return ``True`` if the `ninja <https://ninja-build.org/>`_ build system is available on the system, ``False`` otherwise.""" 1841 try: 1842 subprocess.check_output('ninja --version'.split()) 1843 except Exception: 1844 return False 1845 else: 1846 return True 1847 1848 1849def verify_ninja_availability(): 1850 """Raise ``RuntimeError`` if `ninja <https://ninja-build.org/>`_ build system is not available on the system, does nothing otherwise.""" 1851 if not is_ninja_available(): 1852 raise RuntimeError("Ninja is required to load C++ extensions") 1853 1854 1855def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone): 1856 if IS_WINDOWS: 1857 python_lib_path = os.path.join(sys.base_exec_prefix, 'libs') 1858 1859 extra_ldflags.append('c10.lib') 1860 if with_cuda: 1861 extra_ldflags.append('c10_cuda.lib') 1862 extra_ldflags.append('torch_cpu.lib') 1863 if with_cuda: 1864 extra_ldflags.append('torch_cuda.lib') 1865 # /INCLUDE is used to ensure torch_cuda is linked against in a project that relies on it. 1866 # Related issue: https://github.com/pytorch/pytorch/issues/31611 1867 extra_ldflags.append('-INCLUDE:?warp_size@cuda@at@@YAHXZ') 1868 extra_ldflags.append('torch.lib') 1869 extra_ldflags.append(f'/LIBPATH:{TORCH_LIB_PATH}') 1870 if not is_standalone: 1871 extra_ldflags.append('torch_python.lib') 1872 extra_ldflags.append(f'/LIBPATH:{python_lib_path}') 1873 1874 else: 1875 extra_ldflags.append(f'-L{TORCH_LIB_PATH}') 1876 extra_ldflags.append('-lc10') 1877 if with_cuda: 1878 extra_ldflags.append('-lc10_hip' if IS_HIP_EXTENSION else '-lc10_cuda') 1879 extra_ldflags.append('-ltorch_cpu') 1880 if with_cuda: 1881 extra_ldflags.append('-ltorch_hip' if IS_HIP_EXTENSION else '-ltorch_cuda') 1882 extra_ldflags.append('-ltorch') 1883 if not is_standalone: 1884 extra_ldflags.append('-ltorch_python') 1885 1886 if is_standalone: 1887 extra_ldflags.append(f"-Wl,-rpath,{TORCH_LIB_PATH}") 1888 1889 if with_cuda: 1890 if verbose: 1891 print('Detected CUDA files, patching ldflags', file=sys.stderr) 1892 if IS_WINDOWS: 1893 extra_ldflags.append(f'/LIBPATH:{_join_cuda_home("lib", "x64")}') 1894 extra_ldflags.append('cudart.lib') 1895 if CUDNN_HOME is not None: 1896 extra_ldflags.append(f'/LIBPATH:{os.path.join(CUDNN_HOME, "lib", "x64")}') 1897 elif not IS_HIP_EXTENSION: 1898 extra_lib_dir = "lib64" 1899 if (not os.path.exists(_join_cuda_home(extra_lib_dir)) and 1900 os.path.exists(_join_cuda_home("lib"))): 1901 # 64-bit CUDA may be installed in "lib" 1902 # Note that it's also possible both don't exist (see _find_cuda_home) - in that case we stay with "lib64" 1903 extra_lib_dir = "lib" 1904 extra_ldflags.append(f'-L{_join_cuda_home(extra_lib_dir)}') 1905 extra_ldflags.append('-lcudart') 1906 if CUDNN_HOME is not None: 1907 extra_ldflags.append(f'-L{os.path.join(CUDNN_HOME, "lib64")}') 1908 elif IS_HIP_EXTENSION: 1909 extra_ldflags.append(f'-L{_join_rocm_home("lib")}') 1910 extra_ldflags.append('-lamdhip64') 1911 return extra_ldflags 1912 1913 1914def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]: 1915 """ 1916 Determine CUDA arch flags to use. 1917 1918 For an arch, say "6.1", the added compile flag will be 1919 ``-gencode=arch=compute_61,code=sm_61``. 1920 For an added "+PTX", an additional 1921 ``-gencode=arch=compute_xx,code=compute_xx`` is added. 1922 1923 See select_compute_arch.cmake for corresponding named and supported arches 1924 when building with CMake. 1925 """ 1926 # If cflags is given, there may already be user-provided arch flags in it 1927 # (from `extra_compile_args`) 1928 if cflags is not None: 1929 for flag in cflags: 1930 if 'TORCH_EXTENSION_NAME' in flag: 1931 continue 1932 if 'arch' in flag: 1933 return [] 1934 1935 # Note: keep combined names ("arch1+arch2") above single names, otherwise 1936 # string replacement may not do the right thing 1937 named_arches = collections.OrderedDict([ 1938 ('Kepler+Tesla', '3.7'), 1939 ('Kepler', '3.5+PTX'), 1940 ('Maxwell+Tegra', '5.3'), 1941 ('Maxwell', '5.0;5.2+PTX'), 1942 ('Pascal', '6.0;6.1+PTX'), 1943 ('Volta+Tegra', '7.2'), 1944 ('Volta', '7.0+PTX'), 1945 ('Turing', '7.5+PTX'), 1946 ('Ampere+Tegra', '8.7'), 1947 ('Ampere', '8.0;8.6+PTX'), 1948 ('Ada', '8.9+PTX'), 1949 ('Hopper', '9.0+PTX'), 1950 ]) 1951 1952 supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2', 1953 '7.0', '7.2', '7.5', '8.0', '8.6', '8.7', '8.9', '9.0', '9.0a'] 1954 valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches] 1955 1956 # The default is sm_30 for CUDA 9.x and 10.x 1957 # First check for an env var (same as used by the main setup.py) 1958 # Can be one or more architectures, e.g. "6.1" or "3.5;5.2;6.0;6.1;7.0+PTX" 1959 # See cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake 1960 _arch_list = os.environ.get('TORCH_CUDA_ARCH_LIST', None) 1961 1962 # If not given, determine what's best for the GPU / CUDA version that can be found 1963 if not _arch_list: 1964 warnings.warn( 1965 "TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. \n" 1966 "If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].") 1967 arch_list = [] 1968 # the assumption is that the extension should run on any of the currently visible cards, 1969 # which could be of different types - therefore all archs for visible cards should be included 1970 for i in range(torch.cuda.device_count()): 1971 capability = torch.cuda.get_device_capability(i) 1972 supported_sm = [int(arch.split('_')[1]) 1973 for arch in torch.cuda.get_arch_list() if 'sm_' in arch] 1974 max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm) 1975 # Capability of the device may be higher than what's supported by the user's 1976 # NVCC, causing compilation error. User's NVCC is expected to match the one 1977 # used to build pytorch, so we use the maximum supported capability of pytorch 1978 # to clamp the capability. 1979 capability = min(max_supported_sm, capability) 1980 arch = f'{capability[0]}.{capability[1]}' 1981 if arch not in arch_list: 1982 arch_list.append(arch) 1983 arch_list = sorted(arch_list) 1984 arch_list[-1] += '+PTX' 1985 else: 1986 # Deal with lists that are ' ' separated (only deal with ';' after) 1987 _arch_list = _arch_list.replace(' ', ';') 1988 # Expand named arches 1989 for named_arch, archval in named_arches.items(): 1990 _arch_list = _arch_list.replace(named_arch, archval) 1991 1992 arch_list = _arch_list.split(';') 1993 1994 flags = [] 1995 for arch in arch_list: 1996 if arch not in valid_arch_strings: 1997 raise ValueError(f"Unknown CUDA arch ({arch}) or GPU not supported") 1998 else: 1999 num = arch[0] + arch[2:].split("+")[0] 2000 flags.append(f'-gencode=arch=compute_{num},code=sm_{num}') 2001 if arch.endswith('+PTX'): 2002 flags.append(f'-gencode=arch=compute_{num},code=compute_{num}') 2003 2004 return sorted(set(flags)) 2005 2006 2007def _get_rocm_arch_flags(cflags: Optional[List[str]] = None) -> List[str]: 2008 # If cflags is given, there may already be user-provided arch flags in it 2009 # (from `extra_compile_args`) 2010 if cflags is not None: 2011 for flag in cflags: 2012 if 'amdgpu-target' in flag or 'offload-arch' in flag: 2013 return ['-fno-gpu-rdc'] 2014 # Use same defaults as used for building PyTorch 2015 # Allow env var to override, just like during initial cmake build. 2016 _archs = os.environ.get('PYTORCH_ROCM_ARCH', None) 2017 if not _archs: 2018 archFlags = torch._C._cuda_getArchFlags() 2019 if archFlags: 2020 archs = archFlags.split() 2021 else: 2022 archs = [] 2023 else: 2024 archs = _archs.replace(' ', ';').split(';') 2025 flags = [f'--offload-arch={arch}' for arch in archs] 2026 flags += ['-fno-gpu-rdc'] 2027 return flags 2028 2029def _get_build_directory(name: str, verbose: bool) -> str: 2030 root_extensions_directory = os.environ.get('TORCH_EXTENSIONS_DIR') 2031 if root_extensions_directory is None: 2032 root_extensions_directory = get_default_build_root() 2033 cu_str = ('cpu' if torch.version.cuda is None else 2034 f'cu{torch.version.cuda.replace(".", "")}') # type: ignore[attr-defined] 2035 python_version = f'py{sys.version_info.major}{sys.version_info.minor}' 2036 build_folder = f'{python_version}_{cu_str}' 2037 2038 root_extensions_directory = os.path.join( 2039 root_extensions_directory, build_folder) 2040 2041 if verbose: 2042 print(f'Using {root_extensions_directory} as PyTorch extensions root...', file=sys.stderr) 2043 2044 build_directory = os.path.join(root_extensions_directory, name) 2045 if not os.path.exists(build_directory): 2046 if verbose: 2047 print(f'Creating extension directory {build_directory}...', file=sys.stderr) 2048 # This is like mkdir -p, i.e. will also create parent directories. 2049 os.makedirs(build_directory, exist_ok=True) 2050 2051 return build_directory 2052 2053 2054def _get_num_workers(verbose: bool) -> Optional[int]: 2055 max_jobs = os.environ.get('MAX_JOBS') 2056 if max_jobs is not None and max_jobs.isdigit(): 2057 if verbose: 2058 print(f'Using envvar MAX_JOBS ({max_jobs}) as the number of workers...', 2059 file=sys.stderr) 2060 return int(max_jobs) 2061 if verbose: 2062 print('Allowing ninja to set a default number of workers... ' 2063 '(overridable by setting the environment variable MAX_JOBS=N)', 2064 file=sys.stderr) 2065 return None 2066 2067 2068def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) -> None: 2069 command = ['ninja', '-v'] 2070 num_workers = _get_num_workers(verbose) 2071 if num_workers is not None: 2072 command.extend(['-j', str(num_workers)]) 2073 env = os.environ.copy() 2074 # Try to activate the vc env for the users 2075 if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' not in env: 2076 from setuptools import distutils 2077 2078 plat_name = distutils.util.get_platform() 2079 plat_spec = PLAT_TO_VCVARS[plat_name] 2080 2081 vc_env = distutils._msvccompiler._get_vc_env(plat_spec) 2082 vc_env = {k.upper(): v for k, v in vc_env.items()} 2083 for k, v in env.items(): 2084 uk = k.upper() 2085 if uk not in vc_env: 2086 vc_env[uk] = v 2087 env = vc_env 2088 try: 2089 sys.stdout.flush() 2090 sys.stderr.flush() 2091 # Warning: don't pass stdout=None to subprocess.run to get output. 2092 # subprocess.run assumes that sys.__stdout__ has not been modified and 2093 # attempts to write to it by default. However, when we call _run_ninja_build 2094 # from ahead-of-time cpp extensions, the following happens: 2095 # 1) If the stdout encoding is not utf-8, setuptools detachs __stdout__. 2096 # https://github.com/pypa/setuptools/blob/7e97def47723303fafabe48b22168bbc11bb4821/setuptools/dist.py#L1110 2097 # (it probably shouldn't do this) 2098 # 2) subprocess.run (on POSIX, with no stdout override) relies on 2099 # __stdout__ not being detached: 2100 # https://github.com/python/cpython/blob/c352e6c7446c894b13643f538db312092b351789/Lib/subprocess.py#L1214 2101 # To work around this, we pass in the fileno directly and hope that 2102 # it is valid. 2103 stdout_fileno = 1 2104 subprocess.run( 2105 command, 2106 stdout=stdout_fileno if verbose else subprocess.PIPE, 2107 stderr=subprocess.STDOUT, 2108 cwd=build_directory, 2109 check=True, 2110 env=env) 2111 except subprocess.CalledProcessError as e: 2112 # Python 2 and 3 compatible way of getting the error object. 2113 _, error, _ = sys.exc_info() 2114 # error.output contains the stdout and stderr of the build attempt. 2115 message = error_prefix 2116 # `error` is a CalledProcessError (which has an `output`) attribute, but 2117 # mypy thinks it's Optional[BaseException] and doesn't narrow 2118 if hasattr(error, 'output') and error.output: # type: ignore[union-attr] 2119 message += f": {error.output.decode(*SUBPROCESS_DECODE_ARGS)}" # type: ignore[union-attr] 2120 raise RuntimeError(message) from e 2121 2122 2123def _get_exec_path(module_name, path): 2124 if IS_WINDOWS and TORCH_LIB_PATH not in os.getenv('PATH', '').split(';'): 2125 torch_lib_in_path = any( 2126 os.path.exists(p) and os.path.samefile(p, TORCH_LIB_PATH) 2127 for p in os.getenv('PATH', '').split(';') 2128 ) 2129 if not torch_lib_in_path: 2130 os.environ['PATH'] = f"{TORCH_LIB_PATH};{os.getenv('PATH', '')}" 2131 return os.path.join(path, f'{module_name}{EXEC_EXT}') 2132 2133 2134def _import_module_from_library(module_name, path, is_python_module): 2135 filepath = os.path.join(path, f"{module_name}{LIB_EXT}") 2136 if is_python_module: 2137 # https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path 2138 spec = importlib.util.spec_from_file_location(module_name, filepath) 2139 assert spec is not None 2140 module = importlib.util.module_from_spec(spec) 2141 assert isinstance(spec.loader, importlib.abc.Loader) 2142 spec.loader.exec_module(module) 2143 return module 2144 else: 2145 torch.ops.load_library(filepath) 2146 return filepath 2147 2148 2149def _write_ninja_file_to_build_library(path, 2150 name, 2151 sources, 2152 extra_cflags, 2153 extra_cuda_cflags, 2154 extra_ldflags, 2155 extra_include_paths, 2156 with_cuda, 2157 is_standalone) -> None: 2158 extra_cflags = [flag.strip() for flag in extra_cflags] 2159 extra_cuda_cflags = [flag.strip() for flag in extra_cuda_cflags] 2160 extra_ldflags = [flag.strip() for flag in extra_ldflags] 2161 extra_include_paths = [flag.strip() for flag in extra_include_paths] 2162 2163 # Turn into absolute paths so we can emit them into the ninja build 2164 # file wherever it is. 2165 user_includes = [os.path.abspath(file) for file in extra_include_paths] 2166 2167 # include_paths() gives us the location of torch/extension.h 2168 system_includes = include_paths(with_cuda) 2169 # sysconfig.get_path('include') gives us the location of Python.h 2170 # Explicitly specify 'posix_prefix' scheme on non-Windows platforms to workaround error on some MacOS 2171 # installations where default `get_path` points to non-existing `/Library/Python/M.m/include` folder 2172 python_include_path = sysconfig.get_path('include', scheme='nt' if IS_WINDOWS else 'posix_prefix') 2173 if python_include_path is not None: 2174 system_includes.append(python_include_path) 2175 2176 common_cflags = [] 2177 if not is_standalone: 2178 common_cflags.append(f'-DTORCH_EXTENSION_NAME={name}') 2179 common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H') 2180 2181 common_cflags += [f"{x}" for x in _get_pybind11_abi_build_flags()] 2182 2183 # Windows does not understand `-isystem` and quotes flags later. 2184 if IS_WINDOWS: 2185 common_cflags += [f'-I{include}' for include in user_includes + system_includes] 2186 else: 2187 common_cflags += [f'-I{shlex.quote(include)}' for include in user_includes] 2188 common_cflags += [f'-isystem {shlex.quote(include)}' for include in system_includes] 2189 2190 common_cflags += [f"{x}" for x in _get_glibcxx_abi_build_flags()] 2191 2192 if IS_WINDOWS: 2193 cflags = common_cflags + COMMON_MSVC_FLAGS + ['/std:c++17'] + extra_cflags 2194 cflags = _nt_quote_args(cflags) 2195 else: 2196 cflags = common_cflags + ['-fPIC', '-std=c++17'] + extra_cflags 2197 2198 if with_cuda and IS_HIP_EXTENSION: 2199 cuda_flags = ['-DWITH_HIP'] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS 2200 cuda_flags += extra_cuda_cflags 2201 cuda_flags += _get_rocm_arch_flags(cuda_flags) 2202 elif with_cuda: 2203 cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags() 2204 if IS_WINDOWS: 2205 for flag in COMMON_MSVC_FLAGS: 2206 cuda_flags = ['-Xcompiler', flag] + cuda_flags 2207 for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS: 2208 cuda_flags = ['-Xcudafe', '--diag_suppress=' + ignore_warning] + cuda_flags 2209 cuda_flags = cuda_flags + ['-std=c++17'] 2210 cuda_flags = _nt_quote_args(cuda_flags) 2211 cuda_flags += _nt_quote_args(extra_cuda_cflags) 2212 else: 2213 cuda_flags += ['--compiler-options', "'-fPIC'"] 2214 cuda_flags += extra_cuda_cflags 2215 if not any(flag.startswith('-std=') for flag in cuda_flags): 2216 cuda_flags.append('-std=c++17') 2217 cc_env = os.getenv("CC") 2218 if cc_env is not None: 2219 cuda_flags = ['-ccbin', cc_env] + cuda_flags 2220 else: 2221 cuda_flags = None 2222 2223 def object_file_path(source_file: str) -> str: 2224 # '/path/to/file.cpp' -> 'file' 2225 file_name = os.path.splitext(os.path.basename(source_file))[0] 2226 if _is_cuda_file(source_file) and with_cuda: 2227 # Use a different object filename in case a C++ and CUDA file have 2228 # the same filename but different extension (.cpp vs. .cu). 2229 target = f'{file_name}.cuda.o' 2230 else: 2231 target = f'{file_name}.o' 2232 return target 2233 2234 objects = [object_file_path(src) for src in sources] 2235 ldflags = ([] if is_standalone else [SHARED_FLAG]) + extra_ldflags 2236 2237 # The darwin linker needs explicit consent to ignore unresolved symbols. 2238 if IS_MACOS: 2239 ldflags.append('-undefined dynamic_lookup') 2240 elif IS_WINDOWS: 2241 ldflags = _nt_quote_args(ldflags) 2242 2243 ext = EXEC_EXT if is_standalone else LIB_EXT 2244 library_target = f'{name}{ext}' 2245 2246 _write_ninja_file( 2247 path=path, 2248 cflags=cflags, 2249 post_cflags=None, 2250 cuda_cflags=cuda_flags, 2251 cuda_post_cflags=None, 2252 cuda_dlink_post_cflags=None, 2253 sources=sources, 2254 objects=objects, 2255 ldflags=ldflags, 2256 library_target=library_target, 2257 with_cuda=with_cuda) 2258 2259 2260def _write_ninja_file(path, 2261 cflags, 2262 post_cflags, 2263 cuda_cflags, 2264 cuda_post_cflags, 2265 cuda_dlink_post_cflags, 2266 sources, 2267 objects, 2268 ldflags, 2269 library_target, 2270 with_cuda) -> None: 2271 r"""Write a ninja file that does the desired compiling and linking. 2272 2273 `path`: Where to write this file 2274 `cflags`: list of flags to pass to $cxx. Can be None. 2275 `post_cflags`: list of flags to append to the $cxx invocation. Can be None. 2276 `cuda_cflags`: list of flags to pass to $nvcc. Can be None. 2277 `cuda_postflags`: list of flags to append to the $nvcc invocation. Can be None. 2278 `sources`: list of paths to source files 2279 `objects`: list of desired paths to objects, one per source. 2280 `ldflags`: list of flags to pass to linker. Can be None. 2281 `library_target`: Name of the output library. Can be None; in that case, 2282 we do no linking. 2283 `with_cuda`: If we should be compiling with CUDA. 2284 """ 2285 def sanitize_flags(flags): 2286 if flags is None: 2287 return [] 2288 else: 2289 return [flag.strip() for flag in flags] 2290 2291 cflags = sanitize_flags(cflags) 2292 post_cflags = sanitize_flags(post_cflags) 2293 cuda_cflags = sanitize_flags(cuda_cflags) 2294 cuda_post_cflags = sanitize_flags(cuda_post_cflags) 2295 cuda_dlink_post_cflags = sanitize_flags(cuda_dlink_post_cflags) 2296 ldflags = sanitize_flags(ldflags) 2297 2298 # Sanity checks... 2299 assert len(sources) == len(objects) 2300 assert len(sources) > 0 2301 2302 compiler = get_cxx_compiler() 2303 2304 # Version 1.3 is required for the `deps` directive. 2305 config = ['ninja_required_version = 1.3'] 2306 config.append(f'cxx = {compiler}') 2307 if with_cuda or cuda_dlink_post_cflags: 2308 if "PYTORCH_NVCC" in os.environ: 2309 nvcc = os.getenv("PYTORCH_NVCC") # user can set nvcc compiler with ccache using the environment variable here 2310 else: 2311 if IS_HIP_EXTENSION: 2312 nvcc = _join_rocm_home('bin', 'hipcc') 2313 else: 2314 nvcc = _join_cuda_home('bin', 'nvcc') 2315 config.append(f'nvcc = {nvcc}') 2316 2317 if IS_HIP_EXTENSION: 2318 post_cflags = COMMON_HIP_FLAGS + post_cflags 2319 flags = [f'cflags = {" ".join(cflags)}'] 2320 flags.append(f'post_cflags = {" ".join(post_cflags)}') 2321 if with_cuda: 2322 flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}') 2323 flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}') 2324 flags.append(f'cuda_dlink_post_cflags = {" ".join(cuda_dlink_post_cflags)}') 2325 flags.append(f'ldflags = {" ".join(ldflags)}') 2326 2327 # Turn into absolute paths so we can emit them into the ninja build 2328 # file wherever it is. 2329 sources = [os.path.abspath(file) for file in sources] 2330 2331 # See https://ninja-build.org/build.ninja.html for reference. 2332 compile_rule = ['rule compile'] 2333 if IS_WINDOWS: 2334 compile_rule.append( 2335 ' command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags') 2336 compile_rule.append(' deps = msvc') 2337 else: 2338 compile_rule.append( 2339 ' command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags') 2340 compile_rule.append(' depfile = $out.d') 2341 compile_rule.append(' deps = gcc') 2342 2343 if with_cuda: 2344 cuda_compile_rule = ['rule cuda_compile'] 2345 nvcc_gendeps = '' 2346 # --generate-dependencies-with-compile is not supported by ROCm 2347 # Nvcc flag `--generate-dependencies-with-compile` is not supported by sccache, which may increase build time. 2348 if torch.version.cuda is not None and os.getenv('TORCH_EXTENSION_SKIP_NVCC_GEN_DEPENDENCIES', '0') != '1': 2349 cuda_compile_rule.append(' depfile = $out.d') 2350 cuda_compile_rule.append(' deps = gcc') 2351 # Note: non-system deps with nvcc are only supported 2352 # on Linux so use --generate-dependencies-with-compile 2353 # to make this work on Windows too. 2354 nvcc_gendeps = '--generate-dependencies-with-compile --dependency-output $out.d' 2355 cuda_compile_rule.append( 2356 f' command = $nvcc {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags') 2357 2358 # Emit one build rule per source to enable incremental build. 2359 build = [] 2360 for source_file, object_file in zip(sources, objects): 2361 is_cuda_source = _is_cuda_file(source_file) and with_cuda 2362 rule = 'cuda_compile' if is_cuda_source else 'compile' 2363 if IS_WINDOWS: 2364 source_file = source_file.replace(':', '$:') 2365 object_file = object_file.replace(':', '$:') 2366 source_file = source_file.replace(" ", "$ ") 2367 object_file = object_file.replace(" ", "$ ") 2368 build.append(f'build {object_file}: {rule} {source_file}') 2369 2370 if cuda_dlink_post_cflags: 2371 devlink_out = os.path.join(os.path.dirname(objects[0]), 'dlink.o') 2372 devlink_rule = ['rule cuda_devlink'] 2373 devlink_rule.append(' command = $nvcc $in -o $out $cuda_dlink_post_cflags') 2374 devlink = [f'build {devlink_out}: cuda_devlink {" ".join(objects)}'] 2375 objects += [devlink_out] 2376 else: 2377 devlink_rule, devlink = [], [] 2378 2379 if library_target is not None: 2380 link_rule = ['rule link'] 2381 if IS_WINDOWS: 2382 cl_paths = subprocess.check_output(['where', 2383 'cl']).decode(*SUBPROCESS_DECODE_ARGS).split('\r\n') 2384 if len(cl_paths) >= 1: 2385 cl_path = os.path.dirname(cl_paths[0]).replace(':', '$:') 2386 else: 2387 raise RuntimeError("MSVC is required to load C++ extensions") 2388 link_rule.append(f' command = "{cl_path}/link.exe" $in /nologo $ldflags /out:$out') 2389 else: 2390 link_rule.append(' command = $cxx $in $ldflags -o $out') 2391 2392 link = [f'build {library_target}: link {" ".join(objects)}'] 2393 2394 default = [f'default {library_target}'] 2395 else: 2396 link_rule, link, default = [], [], [] 2397 2398 # 'Blocks' should be separated by newlines, for visual benefit. 2399 blocks = [config, flags, compile_rule] 2400 if with_cuda: 2401 blocks.append(cuda_compile_rule) # type: ignore[possibly-undefined] 2402 blocks += [devlink_rule, link_rule, build, devlink, link, default] 2403 content = "\n\n".join("\n".join(b) for b in blocks) 2404 # Ninja requires a new lines at the end of the .ninja file 2405 content += "\n" 2406 _maybe_write(path, content) 2407 2408def _join_cuda_home(*paths) -> str: 2409 """ 2410 Join paths with CUDA_HOME, or raises an error if it CUDA_HOME is not set. 2411 2412 This is basically a lazy way of raising an error for missing $CUDA_HOME 2413 only once we need to get any CUDA-specific path. 2414 """ 2415 if CUDA_HOME is None: 2416 raise OSError('CUDA_HOME environment variable is not set. ' 2417 'Please set it to your CUDA install root.') 2418 return os.path.join(CUDA_HOME, *paths) 2419 2420 2421def _is_cuda_file(path: str) -> bool: 2422 valid_ext = ['.cu', '.cuh'] 2423 if IS_HIP_EXTENSION: 2424 valid_ext.append('.hip') 2425 return os.path.splitext(path)[1] in valid_ext 2426