# mypy: allow-untyped-defs
import copy
import glob
import importlib
import importlib.abc
import os
import re
import shlex
import shutil
import setuptools
import subprocess
import sys
import sysconfig
import warnings
import collections
from pathlib import Path
import errno

import torch
import torch._appdirs
from .file_baton import FileBaton
from ._cpp_extension_versioner import ExtensionVersioner
from .hipify import hipify_python
from .hipify.hipify_python import GeneratedFileCleaner
from typing import Dict, List, Optional, Union, Tuple
from torch.torch_version import TorchVersion, Version

from setuptools.command.build_ext import build_ext

IS_WINDOWS = sys.platform == 'win32'
IS_MACOS = sys.platform.startswith('darwin')
IS_LINUX = sys.platform.startswith('linux')
LIB_EXT = '.pyd' if IS_WINDOWS else '.so'
EXEC_EXT = '.exe' if IS_WINDOWS else ''
CLIB_PREFIX = '' if IS_WINDOWS else 'lib'
CLIB_EXT = '.dll' if IS_WINDOWS else '.so'
SHARED_FLAG = '/DLL' if IS_WINDOWS else '-shared'

_HERE = os.path.abspath(__file__)
_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE))
TORCH_LIB_PATH = os.path.join(_TORCH_PATH, 'lib')


SUBPROCESS_DECODE_ARGS = ('oem',) if IS_WINDOWS else ()
MINIMUM_GCC_VERSION = (5, 0, 0)
MINIMUM_MSVC_VERSION = (19, 0, 24215)

VersionRange = Tuple[Tuple[int, ...], Tuple[int, ...]]
VersionMap = Dict[str, VersionRange]
# The following values were taken from the following GitHub gist that
# summarizes the minimum valid major versions of g++/clang++ for each supported
# CUDA version: https://gist.github.com/ax3l/9489132
# Or from include/crt/host_config.h in the CUDA SDK
# The second value is the exclusive(!) upper bound, i.e. min <= version < max
CUDA_GCC_VERSIONS: VersionMap = {
    '11.0': (MINIMUM_GCC_VERSION, (10, 0)),
    '11.1': (MINIMUM_GCC_VERSION, (11, 0)),
    '11.2': (MINIMUM_GCC_VERSION, (11, 0)),
    '11.3': (MINIMUM_GCC_VERSION, (11, 0)),
    '11.4': ((6, 0, 0), (12, 0)),
    '11.5': ((6, 0, 0), (12, 0)),
    '11.6': ((6, 0, 0), (12, 0)),
    '11.7': ((6, 0, 0), (12, 0)),
}
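# For example, per the table above, building extensions for CUDA 11.4 requires
# a g++ satisfying (6, 0, 0) <= version < (12, 0).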

MINIMUM_CLANG_VERSION = (3, 3, 0)
CUDA_CLANG_VERSIONS: VersionMap = {
    '11.1': (MINIMUM_CLANG_VERSION, (11, 0)),
    '11.2': (MINIMUM_CLANG_VERSION, (12, 0)),
    '11.3': (MINIMUM_CLANG_VERSION, (12, 0)),
    '11.4': (MINIMUM_CLANG_VERSION, (13, 0)),
    '11.5': (MINIMUM_CLANG_VERSION, (13, 0)),
    '11.6': (MINIMUM_CLANG_VERSION, (14, 0)),
    '11.7': (MINIMUM_CLANG_VERSION, (14, 0)),
}

__all__ = ["get_default_build_root", "check_compiler_ok_for_platform", "get_compiler_abi_compatibility_and_version", "BuildExtension",
           "CppExtension", "CUDAExtension", "include_paths", "library_paths", "load", "load_inline", "is_ninja_available",
           "verify_ninja_availability", "remove_extension_h_precompiler_headers", "get_cxx_compiler", "check_compiler_is_gcc"]
# Taken directly from python stdlib < 3.9
# See https://github.com/pytorch/pytorch/issues/48617
def _nt_quote_args(args: Optional[List[str]]) -> List[str]:
    """Quote command-line arguments for DOS/Windows conventions.

    Just wraps every argument which contains blanks in double quotes, and
    returns a new argument list.
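
    Example (illustrative flags):
        >>> _nt_quote_args(['/Ox', '/I', 'C:/Program Files/include'])
        ['/Ox', '/I', '"C:/Program Files/include"']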
    """
    # Cover None-type
    if not args:
        return []
    return [f'"{arg}"' if ' ' in arg else arg for arg in args]

def _find_cuda_home() -> Optional[str]:
    """Find the CUDA install path."""
    # Guess #1
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if cuda_home is None:
        # Guess #2
        nvcc_path = shutil.which("nvcc")
        if nvcc_path is not None:
            cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
        else:
            # Guess #3
            if IS_WINDOWS:
                cuda_homes = glob.glob(
                    'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
                if len(cuda_homes) == 0:
                    cuda_home = ''
                else:
                    cuda_home = cuda_homes[0]
            else:
                cuda_home = '/usr/local/cuda'
            if not os.path.exists(cuda_home):
                cuda_home = None
    if cuda_home and not torch.cuda.is_available():
        print(f"No CUDA runtime was found, using CUDA_HOME='{cuda_home}'",
              file=sys.stderr)
    return cuda_home

def _find_rocm_home() -> Optional[str]:
    """Find the ROCm install path."""
    # Guess #1
    rocm_home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH')
    if rocm_home is None:
        # Guess #2
        hipcc_path = shutil.which('hipcc')
        if hipcc_path is not None:
            rocm_home = os.path.dirname(os.path.dirname(
                os.path.realpath(hipcc_path)))
            # can be either <ROCM_HOME>/hip/bin/hipcc or <ROCM_HOME>/bin/hipcc
            if os.path.basename(rocm_home) == 'hip':
                rocm_home = os.path.dirname(rocm_home)
        else:
            # Guess #3
            fallback_path = '/opt/rocm'
            if os.path.exists(fallback_path):
                rocm_home = fallback_path
    if rocm_home and torch.version.hip is None:
        print(f"No ROCm runtime was found, using ROCM_HOME='{rocm_home}'",
              file=sys.stderr)
    return rocm_home


def _join_rocm_home(*paths) -> str:
    """
    Join paths with ROCM_HOME, or raise an error if ROCM_HOME is not set.

    This is basically a lazy way of raising an error for a missing $ROCM_HOME
    only once we need to get any ROCm-specific path.
    """
    if ROCM_HOME is None:
        raise OSError('ROCM_HOME environment variable is not set. '
                      'Please set it to your ROCm install root.')
    elif IS_WINDOWS:
        raise OSError('Building PyTorch extensions using '
                      'ROCm and Windows is not supported.')
    return os.path.join(ROCM_HOME, *paths)


ABI_INCOMPATIBILITY_WARNING = '''

                               !! WARNING !!

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Your compiler ({}) may be ABI-incompatible with PyTorch!
Please use a compiler that is ABI-compatible with GCC 5.0 and above.
See https://gcc.gnu.org/onlinedocs/libstdc++/manual/abi.html.

See https://gist.github.com/goldsborough/d466f43e8ffc948ff92de7486c5216d6
for instructions on how to install GCC 5 or higher.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

                              !! WARNING !!
'''
WRONG_COMPILER_WARNING = '''

                               !! WARNING !!

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Your compiler ({user_compiler}) is not compatible with the compiler PyTorch was
built with for this platform, which is {pytorch_compiler} on {platform}. Please
use {pytorch_compiler} to compile your extension. Alternatively, you may
compile PyTorch from source using {user_compiler}, and then you can also use
{user_compiler} to compile your extension.

See https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md for help
with compiling PyTorch from source.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

                              !! WARNING !!
'''
CUDA_MISMATCH_MESSAGE = '''
The detected CUDA version ({0}) mismatches the version that was used to compile
PyTorch ({1}). Please make sure to use the same CUDA versions.
'''
CUDA_MISMATCH_WARN = "The detected CUDA version ({0}) has a minor version mismatch with the version that was used to compile PyTorch ({1}). Most likely this shouldn't be a problem."
CUDA_NOT_FOUND_MESSAGE = '''
CUDA was not found on the system. Please set the CUDA_HOME or the CUDA_PATH
environment variable, or add NVCC to your system PATH. The extension compilation will fail.
'''
ROCM_HOME = _find_rocm_home()
HIP_HOME = _join_rocm_home('hip') if ROCM_HOME else None
IS_HIP_EXTENSION = (ROCM_HOME is not None) and (torch.version.hip is not None)
ROCM_VERSION = None
if torch.version.hip is not None:
    ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])

CUDA_HOME = _find_cuda_home() if torch.cuda._is_compiled() else None
CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH')
# PyTorch releases have the version pattern major.minor.patch, whereas when
# PyTorch is built from source, we append the git commit hash, which gives
# it the below pattern.
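# For example, "2.1.0a0+git1a2b3c4" (an illustrative version string) matches the
# pattern below, while a release version like "2.1.0" does not.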
BUILT_FROM_SOURCE_VERSION_PATTERN = re.compile(r'\d+\.\d+\.\d+\w+\+\w+')

COMMON_MSVC_FLAGS = ['/MD', '/wd4819', '/wd4251', '/wd4244', '/wd4267', '/wd4275', '/wd4018', '/wd4190', '/wd4624', '/wd4067', '/wd4068', '/EHsc']

MSVC_IGNORE_CUDAFE_WARNINGS = [
    'base_class_has_different_dll_interface',
    'field_without_dll_interface',
    'dll_interface_conflict_none_assumed',
    'dll_interface_conflict_dllexport_assumed'
]

COMMON_NVCC_FLAGS = [
    '-D__CUDA_NO_HALF_OPERATORS__',
    '-D__CUDA_NO_HALF_CONVERSIONS__',
    '-D__CUDA_NO_BFLOAT16_CONVERSIONS__',
    '-D__CUDA_NO_HALF2_OPERATORS__',
    '--expt-relaxed-constexpr'
]

COMMON_HIP_FLAGS = [
    '-fPIC',
    '-D__HIP_PLATFORM_AMD__=1',
    '-DUSE_ROCM=1',
    '-DHIPBLAS_V2',
]

COMMON_HIPCC_FLAGS = [
    '-DCUDA_HAS_FP16=1',
    '-D__HIP_NO_HALF_OPERATORS__=1',
    '-D__HIP_NO_HALF_CONVERSIONS__=1',
]

JIT_EXTENSION_VERSIONER = ExtensionVersioner()

PLAT_TO_VCVARS = {
    'win32' : 'x86',
    'win-amd64' : 'x86_amd64',
}

def get_cxx_compiler():
    if IS_WINDOWS:
        compiler = os.environ.get('CXX', 'cl')
    else:
        compiler = os.environ.get('CXX', 'c++')
    return compiler

def _is_binary_build() -> bool:
    return not BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__)


def _accepted_compilers_for_platform() -> List[str]:
    # gnu-c++ and gnu-cc are the conda gcc compilers
    return ['clang++', 'clang'] if IS_MACOS else ['g++', 'gcc', 'gnu-c++', 'gnu-cc', 'clang++', 'clang']

def _maybe_write(filename, new_content):
    r'''
    Equivalent to writing the content into the file but will not touch the file
    if it already had the right content (to avoid triggering recompile).
    '''
    if os.path.exists(filename):
        with open(filename) as f:
            content = f.read()

        if content == new_content:
            # The file already contains the right thing!
            return

    with open(filename, 'w') as source_file:
        source_file.write(new_content)

def get_default_build_root() -> str:
    """
    Return the path to the root folder under which extensions will be built.

    For each extension module built, there will be one folder underneath the
    folder returned by this function. For example, if ``p`` is the path
    returned by this function and ``ext`` the name of an extension, the build
    folder for the extension will be ``p/ext``.

    This directory is **user-specific** so that multiple users on the same
    machine won't run into permission issues.
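
    Example (illustrative; the actual path is platform- and user-specific):
        >>> # xdoctest: +SKIP
        >>> torch.utils.cpp_extension.get_default_build_root()
        '/home/user/.cache/torch_extensions'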
    """
    return os.path.realpath(torch._appdirs.user_cache_dir(appname='torch_extensions'))


def check_compiler_ok_for_platform(compiler: str) -> bool:
    """
    Verify that the compiler is the expected one for the current platform.

    Args:
        compiler (str): The compiler executable to check.

    Returns:
        True if the compiler is gcc/g++ on Linux or clang/clang++ on macOS,
        and always True for Windows.
    """
    if IS_WINDOWS:
        return True
    compiler_path = shutil.which(compiler)
    if compiler_path is None:
        return False
    # Use os.path.realpath to resolve any symlinks, in particular from 'c++' to e.g. 'g++'.
    compiler_path = os.path.realpath(compiler_path)
    # Check the compiler name
    if any(name in compiler_path for name in _accepted_compilers_for_platform()):
        return True
    # If a compiler wrapper is used, try to infer the actual compiler by invoking it with the -v flag.
    env = os.environ.copy()
    env['LC_ALL'] = 'C'  # Don't localize output
    version_string = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS)
    if IS_LINUX:
        # Check for 'gcc' or 'g++' for sccache wrapper
        pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE)
        results = re.findall(pattern, version_string)
        if len(results) != 1:
            # Clang is also a supported compiler on Linux
            # Though on Ubuntu it's sometimes called "Ubuntu clang version"
            return 'clang version' in version_string
        compiler_path = os.path.realpath(results[0].strip())
        # On RHEL/CentOS c++ is a gcc compiler wrapper
        if os.path.basename(compiler_path) == 'c++' and 'gcc version' in version_string:
            return True
        return any(name in compiler_path for name in _accepted_compilers_for_platform())
    if IS_MACOS:
        # Check for 'clang' or 'clang++'
        return version_string.startswith("Apple clang")
    return False


def get_compiler_abi_compatibility_and_version(compiler) -> Tuple[bool, TorchVersion]:
    """
    Determine if the given compiler is ABI-compatible with PyTorch alongside its version.

    Args:
        compiler (str): The compiler executable name to check (e.g. ``g++``).
            Must be executable in a shell process.

    Returns:
        A tuple that contains a boolean that defines if the compiler is (likely) ABI-compatible with PyTorch,
        followed by a `TorchVersion` string that contains the compiler version separated by dots.
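
    Example (illustrative; the result depends on the local toolchain):
        >>> # xdoctest: +SKIP
        >>> compatible, version = get_compiler_abi_compatibility_and_version('g++')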
    """
    if not _is_binary_build():
        return (True, TorchVersion('0.0.0'))
    if os.environ.get('TORCH_DONT_CHECK_COMPILER_ABI') in ['ON', '1', 'YES', 'TRUE', 'Y']:
        return (True, TorchVersion('0.0.0'))

    # First check if the compiler is one of the expected ones for the particular platform.
    if not check_compiler_ok_for_platform(compiler):
        warnings.warn(WRONG_COMPILER_WARNING.format(
            user_compiler=compiler,
            pytorch_compiler=_accepted_compilers_for_platform()[0],
            platform=sys.platform))
        return (False, TorchVersion('0.0.0'))

    if IS_MACOS:
        # There is no particular minimum version we need for clang, so we're good here.
        return (True, TorchVersion('0.0.0'))
    try:
        if IS_LINUX:
            minimum_required_version = MINIMUM_GCC_VERSION
            versionstr = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion'])
            version = versionstr.decode(*SUBPROCESS_DECODE_ARGS).strip().split('.')
        else:
            minimum_required_version = MINIMUM_MSVC_VERSION
            compiler_info = subprocess.check_output(compiler, stderr=subprocess.STDOUT)
            match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode(*SUBPROCESS_DECODE_ARGS).strip())
            version = ['0', '0', '0'] if match is None else list(match.groups())
    except Exception:
        _, error, _ = sys.exc_info()
        warnings.warn(f'Error checking compiler version for {compiler}: {error}')
        return (False, TorchVersion('0.0.0'))

    if tuple(map(int, version)) >= minimum_required_version:
        return (True, TorchVersion('.'.join(version)))

    compiler = f'{compiler} {".".join(version)}'
    warnings.warn(ABI_INCOMPATIBILITY_WARNING.format(compiler))

    return (False, TorchVersion('.'.join(version)))


def _check_cuda_version(compiler_name: str, compiler_version: TorchVersion) -> None:
    if not CUDA_HOME:
        raise RuntimeError(CUDA_NOT_FOUND_MESSAGE)

    nvcc = os.path.join(CUDA_HOME, 'bin', 'nvcc')
    cuda_version_str = subprocess.check_output([nvcc, '--version']).strip().decode(*SUBPROCESS_DECODE_ARGS)
    cuda_version = re.search(r'release (\d+[.]\d+)', cuda_version_str)
    if cuda_version is None:
        return

    cuda_str_version = cuda_version.group(1)
    cuda_ver = Version(cuda_str_version)
    if torch.version.cuda is None:
        return

    torch_cuda_version = Version(torch.version.cuda)
    if cuda_ver != torch_cuda_version:
        # major/minor attributes are only available in setuptools>=49.4.0
        if getattr(cuda_ver, "major", None) is None:
            raise ValueError("setuptools>=49.4.0 is required")
        if cuda_ver.major != torch_cuda_version.major:
            raise RuntimeError(CUDA_MISMATCH_MESSAGE.format(cuda_str_version, torch.version.cuda))
        warnings.warn(CUDA_MISMATCH_WARN.format(cuda_str_version, torch.version.cuda))

    if not (sys.platform.startswith('linux') and
            os.environ.get('TORCH_DONT_CHECK_COMPILER_ABI') not in ['ON', '1', 'YES', 'TRUE', 'Y'] and
            _is_binary_build()):
        return

    cuda_compiler_bounds: VersionMap = CUDA_CLANG_VERSIONS if compiler_name.startswith('clang') else CUDA_GCC_VERSIONS

    if cuda_str_version not in cuda_compiler_bounds:
        warnings.warn(f'There are no {compiler_name} version bounds defined for CUDA version {cuda_str_version}')
    else:
        min_compiler_version, max_excl_compiler_version = cuda_compiler_bounds[cuda_str_version]
        # Special case for 11.4.0, which has lower compiler bounds than 11.4.1
        if "V11.4.48" in cuda_version_str and cuda_compiler_bounds == CUDA_GCC_VERSIONS:
            max_excl_compiler_version = (11, 0)
        min_compiler_version_str = '.'.join(map(str, min_compiler_version))
        max_excl_compiler_version_str = '.'.join(map(str, max_excl_compiler_version))

        version_bound_str = f'>={min_compiler_version_str}, <{max_excl_compiler_version_str}'

        if compiler_version < TorchVersion(min_compiler_version_str):
            raise RuntimeError(
                f'The current installed version of {compiler_name} ({compiler_version}) is less '
                f'than the minimum required version by CUDA {cuda_str_version} ({min_compiler_version_str}). '
                f'Please make sure to use an adequate version of {compiler_name} ({version_bound_str}).'
            )
        if compiler_version >= TorchVersion(max_excl_compiler_version_str):
            raise RuntimeError(
                f'The current installed version of {compiler_name} ({compiler_version}) is greater '
                f'than the maximum supported version for CUDA {cuda_str_version}. '
                f'Please make sure to use an adequate version of {compiler_name} ({version_bound_str}).'
            )


class BuildExtension(build_ext):
    """
    A custom :mod:`setuptools` build extension.

    This :class:`setuptools.build_ext` subclass takes care of passing the
    minimum required compiler flags (e.g. ``-std=c++17``) as well as mixed
    C++/CUDA compilation (and support for CUDA files in general).

    When using :class:`BuildExtension`, it is allowed to supply a dictionary
    for ``extra_compile_args`` (rather than the usual list) that maps from
    languages (``cxx`` or ``nvcc``) to a list of additional compiler flags to
    supply to the compiler. This makes it possible to supply different flags to
    the C++ and CUDA compiler during mixed compilation.

    ``use_ninja`` (bool): If ``use_ninja`` is ``True`` (default), then we
    attempt to build using the Ninja backend. Ninja greatly speeds up
    compilation compared to the standard ``setuptools.build_ext``.
    It falls back to the standard distutils backend if Ninja is not available.

    .. note::
        By default, the Ninja backend uses #CPUS + 2 workers to build the
        extension. This may use up too many resources on some systems. One
        can control the number of workers by setting the `MAX_JOBS` environment
        variable to a non-negative number.
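
    Example (a minimal sketch of forwarding constructor options via ``with_options``):
        >>> # xdoctest: +SKIP
        >>> from torch.utils.cpp_extension import BuildExtension
        >>> cmdclass = {'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)}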
    """

    @classmethod
    def with_options(cls, **options):
479        """Return a subclass with alternative constructor that extends any original keyword arguments to the original constructor with the given options."""
        class cls_with_options(cls):  # type: ignore[misc, valid-type]
            def __init__(self, *args, **kwargs):
                kwargs.update(options)
                super().__init__(*args, **kwargs)

        return cls_with_options

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", False)

        self.use_ninja = kwargs.get('use_ninja', True)
        if self.use_ninja:
            # Test if we can use ninja. Fall back otherwise.
            msg = ('Attempted to use ninja as the BuildExtension backend but '
                   '{}. Falling back to using the slow distutils backend.')
            if not is_ninja_available():
                warnings.warn(msg.format('we could not find ninja.'))
                self.use_ninja = False

    def finalize_options(self) -> None:
        super().finalize_options()
        if self.use_ninja:
            self.force = True

    def build_extensions(self) -> None:
        compiler_name, compiler_version = self._check_abi()

        cuda_ext = False
        extension_iter = iter(self.extensions)
        extension = next(extension_iter, None)
        while not cuda_ext and extension:
            for source in extension.sources:
                _, ext = os.path.splitext(source)
                if ext == '.cu':
                    cuda_ext = True
                    break
            extension = next(extension_iter, None)

        if cuda_ext and not IS_HIP_EXTENSION:
            _check_cuda_version(compiler_name, compiler_version)

        for extension in self.extensions:
            # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
            # extra_compile_args is a dict. Otherwise, default torch flags do
            # not get passed. Necessary when only one of 'cxx' and 'nvcc' is
            # passed to extra_compile_args in CUDAExtension, i.e.
            #   CUDAExtension(..., extra_compile_args={'cxx': [...]})
            # or
            #   CUDAExtension(..., extra_compile_args={'nvcc': [...]})
            if isinstance(extension.extra_compile_args, dict):
                for ext in ['cxx', 'nvcc']:
                    if ext not in extension.extra_compile_args:
                        extension.extra_compile_args[ext] = []

            self._add_compile_flag(extension, '-DTORCH_API_INCLUDE_EXTENSION_H')
            # See note [Pybind11 ABI constants]
            for name in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]:
                val = getattr(torch._C, f"_PYBIND11_{name}")
                if val is not None and not IS_WINDOWS:
                    self._add_compile_flag(extension, f'-DPYBIND11_{name}="{val}"')
            self._define_torch_extension_name(extension)
            self._add_gnu_cpp_abi_flag(extension)

            if 'nvcc_dlink' in extension.extra_compile_args:
                assert self.use_ninja, f"With dlink=True, ninja is required to build cuda extension {extension.name}."

        # Register .cu, .cuh, .hip, and .mm as valid source extensions.
        self.compiler.src_extensions += ['.cu', '.cuh', '.hip']
        if torch.backends.mps.is_built():
            self.compiler.src_extensions += ['.mm']
        # Save the original _compile method for later.
        if self.compiler.compiler_type == 'msvc':
            self.compiler._cpp_extensions += ['.cu', '.cuh']
            original_compile = self.compiler.compile
            original_spawn = self.compiler.spawn
        else:
            original_compile = self.compiler._compile

        def append_std17_if_no_std_present(cflags) -> None:
            # NVCC does not allow multiple -std to be passed, so we avoid
            # overriding the option if the user explicitly passed it.
            cpp_format_prefix = '/{}:' if self.compiler.compiler_type == 'msvc' else '-{}='
            cpp_flag_prefix = cpp_format_prefix.format('std')
            cpp_flag = cpp_flag_prefix + 'c++17'
            if not any(flag.startswith(cpp_flag_prefix) for flag in cflags):
                cflags.append(cpp_flag)

        def unix_cuda_flags(cflags):
            cflags = (COMMON_NVCC_FLAGS +
                      ['--compiler-options', "'-fPIC'"] +
                      cflags + _get_cuda_arch_flags(cflags))

            # NVCC does not allow multiple -ccbin/--compiler-bindir to be passed, so we avoid
            # overriding the option if the user explicitly passed it.
            _ccbin = os.getenv("CC")
            if (
                _ccbin is not None
                and not any(flag.startswith(('-ccbin', '--compiler-bindir')) for flag in cflags)
            ):
                cflags.extend(['-ccbin', _ccbin])

            return cflags

        def convert_to_absolute_paths_inplace(paths):
            # Helper function. See Note [Absolute include_dirs]
            if paths is not None:
                for i in range(len(paths)):
                    if not os.path.isabs(paths[i]):
                        paths[i] = os.path.abspath(paths[i])

        def unix_wrap_single_compile(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
            # Copy before we make any modifications.
            cflags = copy.deepcopy(extra_postargs)
            try:
                original_compiler = self.compiler.compiler_so
                if _is_cuda_file(src):
                    nvcc = [_join_rocm_home('bin', 'hipcc') if IS_HIP_EXTENSION else _join_cuda_home('bin', 'nvcc')]
                    self.compiler.set_executable('compiler_so', nvcc)
                    if isinstance(cflags, dict):
                        cflags = cflags['nvcc']
                    if IS_HIP_EXTENSION:
                        cflags = COMMON_HIPCC_FLAGS + cflags + _get_rocm_arch_flags(cflags)
                    else:
                        cflags = unix_cuda_flags(cflags)
                elif isinstance(cflags, dict):
                    cflags = cflags['cxx']
                if IS_HIP_EXTENSION:
                    cflags = COMMON_HIP_FLAGS + cflags
                append_std17_if_no_std_present(cflags)

                original_compile(obj, src, ext, cc_args, cflags, pp_opts)
            finally:
                # Put the original compiler back in place.
                self.compiler.set_executable('compiler_so', original_compiler)

        def unix_wrap_ninja_compile(sources,
                                    output_dir=None,
                                    macros=None,
                                    include_dirs=None,
                                    debug=0,
                                    extra_preargs=None,
                                    extra_postargs=None,
                                    depends=None):
            r"""Compiles sources by outputting a ninja file and running it."""
            # NB: I copied some lines from self.compiler (which is an instance
            # of distutils.UnixCCompiler). See the following link.
            # https://github.com/python/cpython/blob/f03a8f8d5001963ad5b5b28dbd95497e9cc15596/Lib/distutils/ccompiler.py#L564-L567
            # This can be fragile, but a lot of other repos also do this
            # (see https://github.com/search?q=_setup_compile&type=Code)
            # so it is probably OK; we'll also get CI signal if/when
            # we update our python version (which is when distutils can be
            # upgraded)

            # Use absolute path for output_dir so that the object file paths
            # (`objects`) get generated with absolute paths.
            output_dir = os.path.abspath(output_dir)

            # See Note [Absolute include_dirs]
            convert_to_absolute_paths_inplace(self.compiler.include_dirs)

            _, objects, extra_postargs, pp_opts, _ = \
                self.compiler._setup_compile(output_dir, macros,
                                             include_dirs, sources,
                                             depends, extra_postargs)
            common_cflags = self.compiler._get_cc_args(pp_opts, debug, extra_preargs)
            extra_cc_cflags = self.compiler.compiler_so[1:]
            with_cuda = any(map(_is_cuda_file, sources))

            # extra_postargs can be either:
            # - a dict mapping cxx/nvcc to extra flags
            # - a list of extra flags.
            if isinstance(extra_postargs, dict):
                post_cflags = extra_postargs['cxx']
            else:
                post_cflags = list(extra_postargs)
            if IS_HIP_EXTENSION:
                post_cflags = COMMON_HIP_FLAGS + post_cflags
            append_std17_if_no_std_present(post_cflags)

            cuda_post_cflags = None
            cuda_cflags = None
            if with_cuda:
                cuda_cflags = common_cflags
                if isinstance(extra_postargs, dict):
                    cuda_post_cflags = extra_postargs['nvcc']
                else:
                    cuda_post_cflags = list(extra_postargs)
                if IS_HIP_EXTENSION:
                    cuda_post_cflags = cuda_post_cflags + _get_rocm_arch_flags(cuda_post_cflags)
                    cuda_post_cflags = COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS + cuda_post_cflags
                else:
                    cuda_post_cflags = unix_cuda_flags(cuda_post_cflags)
                append_std17_if_no_std_present(cuda_post_cflags)
                cuda_cflags = [shlex.quote(f) for f in cuda_cflags]
                cuda_post_cflags = [shlex.quote(f) for f in cuda_post_cflags]

            if isinstance(extra_postargs, dict) and 'nvcc_dlink' in extra_postargs:
                cuda_dlink_post_cflags = unix_cuda_flags(extra_postargs['nvcc_dlink'])
            else:
                cuda_dlink_post_cflags = None
            _write_ninja_file_and_compile_objects(
                sources=sources,
                objects=objects,
                cflags=[shlex.quote(f) for f in extra_cc_cflags + common_cflags],
                post_cflags=[shlex.quote(f) for f in post_cflags],
                cuda_cflags=cuda_cflags,
                cuda_post_cflags=cuda_post_cflags,
                cuda_dlink_post_cflags=cuda_dlink_post_cflags,
                build_directory=output_dir,
                verbose=True,
                with_cuda=with_cuda)

            # Return *all* object filenames, not just the ones we just built.
            return objects

        def win_cuda_flags(cflags):
            return (COMMON_NVCC_FLAGS +
                    cflags + _get_cuda_arch_flags(cflags))

        def win_wrap_single_compile(sources,
                                    output_dir=None,
                                    macros=None,
                                    include_dirs=None,
                                    debug=0,
                                    extra_preargs=None,
                                    extra_postargs=None,
                                    depends=None):

            self.cflags = copy.deepcopy(extra_postargs)
            extra_postargs = None

            def spawn(cmd):
                # Using regex to match src, obj and include files
                src_regex = re.compile('/T(p|c)(.*)')
                src_list = [
                    m.group(2) for m in (src_regex.match(elem) for elem in cmd)
                    if m
                ]

                obj_regex = re.compile('/Fo(.*)')
                obj_list = [
                    m.group(1) for m in (obj_regex.match(elem) for elem in cmd)
                    if m
                ]

                include_regex = re.compile(r'((\-|\/)I.*)')
                include_list = [
                    m.group(1)
                    for m in (include_regex.match(elem) for elem in cmd) if m
                ]

                if len(src_list) >= 1 and len(obj_list) >= 1:
                    src = src_list[0]
                    obj = obj_list[0]
                    if _is_cuda_file(src):
                        nvcc = _join_cuda_home('bin', 'nvcc')
                        if isinstance(self.cflags, dict):
                            cflags = self.cflags['nvcc']
                        elif isinstance(self.cflags, list):
                            cflags = self.cflags
                        else:
                            cflags = []

                        cflags = win_cuda_flags(cflags) + ['-std=c++17', '--use-local-env']
                        for flag in COMMON_MSVC_FLAGS:
                            cflags = ['-Xcompiler', flag] + cflags
                        for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
                            cflags = ['-Xcudafe', '--diag_suppress=' + ignore_warning] + cflags
                        cmd = [nvcc, '-c', src, '-o', obj] + include_list + cflags
                    elif isinstance(self.cflags, dict):
                        cflags = COMMON_MSVC_FLAGS + self.cflags['cxx']
                        append_std17_if_no_std_present(cflags)
                        cmd += cflags
                    elif isinstance(self.cflags, list):
                        cflags = COMMON_MSVC_FLAGS + self.cflags
                        append_std17_if_no_std_present(cflags)
                        cmd += cflags

                return original_spawn(cmd)

            try:
                self.compiler.spawn = spawn
                return original_compile(sources, output_dir, macros,
                                        include_dirs, debug, extra_preargs,
                                        extra_postargs, depends)
            finally:
                self.compiler.spawn = original_spawn

        def win_wrap_ninja_compile(sources,
                                   output_dir=None,
                                   macros=None,
                                   include_dirs=None,
                                   debug=0,
                                   extra_preargs=None,
                                   extra_postargs=None,
                                   depends=None):

            if not self.compiler.initialized:
                self.compiler.initialize()
            output_dir = os.path.abspath(output_dir)

            # Note [Absolute include_dirs]
            # Convert any relative paths in self.compiler.include_dirs to absolute paths.
            # For a ninja build, compilation does not happen in the local directory but in
            # a build folder created by the script, so relative paths lose their meaning.
            # To be consistent with JIT extensions, we allow the user to pass relative
            # include_dirs to setuptools.setup, and we convert them to absolute paths here.
            convert_to_absolute_paths_inplace(self.compiler.include_dirs)

            _, objects, extra_postargs, pp_opts, _ = \
                self.compiler._setup_compile(output_dir, macros,
                                             include_dirs, sources,
                                             depends, extra_postargs)
            common_cflags = extra_preargs or []
            cflags = []
            if debug:
                cflags.extend(self.compiler.compile_options_debug)
            else:
                cflags.extend(self.compiler.compile_options)
            common_cflags.extend(COMMON_MSVC_FLAGS)
            cflags = cflags + common_cflags + pp_opts
            with_cuda = any(map(_is_cuda_file, sources))

            # extra_postargs can be either:
            # - a dict mapping cxx/nvcc to extra flags
            # - a list of extra flags.
            if isinstance(extra_postargs, dict):
                post_cflags = extra_postargs['cxx']
            else:
                post_cflags = list(extra_postargs)
            append_std17_if_no_std_present(post_cflags)

            cuda_post_cflags = None
            cuda_cflags = None
            if with_cuda:
                cuda_cflags = ['-std=c++17', '--use-local-env']
                for common_cflag in common_cflags:
                    cuda_cflags.append('-Xcompiler')
                    cuda_cflags.append(common_cflag)
                for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
                    cuda_cflags.append('-Xcudafe')
                    cuda_cflags.append('--diag_suppress=' + ignore_warning)
                cuda_cflags.extend(pp_opts)
                if isinstance(extra_postargs, dict):
                    cuda_post_cflags = extra_postargs['nvcc']
                else:
                    cuda_post_cflags = list(extra_postargs)
                cuda_post_cflags = win_cuda_flags(cuda_post_cflags)

            cflags = _nt_quote_args(cflags)
            post_cflags = _nt_quote_args(post_cflags)
            if with_cuda:
                cuda_cflags = _nt_quote_args(cuda_cflags)
                cuda_post_cflags = _nt_quote_args(cuda_post_cflags)
            if isinstance(extra_postargs, dict) and 'nvcc_dlink' in extra_postargs:
                cuda_dlink_post_cflags = win_cuda_flags(extra_postargs['nvcc_dlink'])
            else:
                cuda_dlink_post_cflags = None

            _write_ninja_file_and_compile_objects(
                sources=sources,
                objects=objects,
                cflags=cflags,
                post_cflags=post_cflags,
                cuda_cflags=cuda_cflags,
                cuda_post_cflags=cuda_post_cflags,
                cuda_dlink_post_cflags=cuda_dlink_post_cflags,
                build_directory=output_dir,
                verbose=True,
                with_cuda=with_cuda)

            # Return *all* object filenames, not just the ones we just built.
            return objects

        # Monkey-patch the _compile or compile method.
        # https://github.com/python/cpython/blob/dc0284ee8f7a270b6005467f26d8e5773d76e959/Lib/distutils/ccompiler.py#L511
        if self.compiler.compiler_type == 'msvc':
            if self.use_ninja:
                self.compiler.compile = win_wrap_ninja_compile
            else:
                self.compiler.compile = win_wrap_single_compile
        else:
            if self.use_ninja:
                self.compiler.compile = unix_wrap_ninja_compile
            else:
                self.compiler._compile = unix_wrap_single_compile

        build_ext.build_extensions(self)

    def get_ext_filename(self, ext_name):
        # Get the original shared library name. For Python 3, this name will be
        # suffixed with "<SOABI>.so", where <SOABI> will be something like
        # cpython-37m-x86_64-linux-gnu.
        ext_filename = super().get_ext_filename(ext_name)
        # If `no_python_abi_suffix` is `True`, we omit the Python 3 ABI
        # component. This makes building shared libraries with setuptools that
        # aren't Python modules nicer.
        if self.no_python_abi_suffix:
            # The parts will be e.g. ["my_extension", "cpython-37m-x86_64-linux-gnu", "so"].
            ext_filename_parts = ext_filename.split('.')
            # Omit the second to last element.
            without_abi = ext_filename_parts[:-2] + ext_filename_parts[-1:]
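            # e.g. "my_extension.cpython-37m-x86_64-linux-gnu.so" becomes "my_extension.so".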
            ext_filename = '.'.join(without_abi)
        return ext_filename

    def _check_abi(self) -> Tuple[str, TorchVersion]:
        # On some platforms, like Windows, compiler_cxx is not available.
        if hasattr(self.compiler, 'compiler_cxx'):
            compiler = self.compiler.compiler_cxx[0]
        else:
            compiler = get_cxx_compiler()
        _, version = get_compiler_abi_compatibility_and_version(compiler)
        # Warn user if VC env is activated but `DISTUTILS_USE_SDK` is not set.
        if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' in os.environ and 'DISTUTILS_USE_SDK' not in os.environ:
            msg = ('It seems that the VC environment is activated but DISTUTILS_USE_SDK is not set. '
                   'This may lead to multiple activations of the VC env. '
                   'Please set `DISTUTILS_USE_SDK=1` and try again.')
            raise UserWarning(msg)
        return compiler, version

    def _add_compile_flag(self, extension, flag):
        extension.extra_compile_args = copy.deepcopy(extension.extra_compile_args)
        if isinstance(extension.extra_compile_args, dict):
            for args in extension.extra_compile_args.values():
                args.append(flag)
        else:
            extension.extra_compile_args.append(flag)

    def _define_torch_extension_name(self, extension):
        # pybind11 doesn't support dots in the names
        # so in order to support extensions in the packages
        # like torch._C, we take the last part of the string
        # as the library name
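        # e.g. an extension named "mypkg.my_ext" (a hypothetical name) yields
        # -DTORCH_EXTENSION_NAME=my_ext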
        names = extension.name.split('.')
        name = names[-1]
        define = f'-DTORCH_EXTENSION_NAME={name}'
        self._add_compile_flag(extension, define)

    def _add_gnu_cpp_abi_flag(self, extension):
        # use the same CXX ABI as what PyTorch was compiled with
        self._add_compile_flag(extension, '-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI)))


def CppExtension(name, sources, *args, **kwargs):
    """
    Create a :class:`setuptools.Extension` for C++.

    Convenience method that creates a :class:`setuptools.Extension` with the
    bare minimum (but often sufficient) arguments to build a C++ extension.

    All arguments are forwarded to the :class:`setuptools.Extension`
    constructor. The full list of arguments can be found at
    https://setuptools.pypa.io/en/latest/userguide/ext_modules.html#extension-api-reference

    Example:
        >>> # xdoctest: +SKIP
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
        >>> from setuptools import setup
        >>> from torch.utils.cpp_extension import BuildExtension, CppExtension
        >>> setup(
        ...     name='extension',
        ...     ext_modules=[
        ...         CppExtension(
        ...             name='extension',
        ...             sources=['extension.cpp'],
        ...             extra_compile_args=['-g'],
        ...             extra_link_args=['-Wl,--no-as-needed', '-lm'])
        ...     ],
        ...     cmdclass={
        ...         'build_ext': BuildExtension
        ...     })
    """
    include_dirs = kwargs.get('include_dirs', [])
    include_dirs += include_paths()
    kwargs['include_dirs'] = include_dirs

    library_dirs = kwargs.get('library_dirs', [])
    library_dirs += library_paths()
    kwargs['library_dirs'] = library_dirs

    libraries = kwargs.get('libraries', [])
    libraries.append('c10')
    libraries.append('torch')
    libraries.append('torch_cpu')
    libraries.append('torch_python')
    if IS_WINDOWS:
        libraries.append("sleef")

    kwargs['libraries'] = libraries

    kwargs['language'] = 'c++'
    return setuptools.Extension(name, sources, *args, **kwargs)


def CUDAExtension(name, sources, *args, **kwargs):
    """
    Create a :class:`setuptools.Extension` for CUDA/C++.

    Convenience method that creates a :class:`setuptools.Extension` with the
    bare minimum (but often sufficient) arguments to build a CUDA/C++
    extension. This includes the CUDA include path, library path and runtime
    library.

    All arguments are forwarded to the :class:`setuptools.Extension`
    constructor. The full list of arguments can be found at
    https://setuptools.pypa.io/en/latest/userguide/ext_modules.html#extension-api-reference

    Example:
        >>> # xdoctest: +SKIP
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
        >>> from setuptools import setup
        >>> from torch.utils.cpp_extension import BuildExtension, CUDAExtension
        >>> setup(
        ...     name='cuda_extension',
        ...     ext_modules=[
        ...         CUDAExtension(
        ...                 name='cuda_extension',
        ...                 sources=['extension.cpp', 'extension_kernel.cu'],
        ...                 extra_compile_args={'cxx': ['-g'],
        ...                                     'nvcc': ['-O2']},
        ...                 extra_link_args=['-Wl,--no-as-needed', '-lcuda'])
        ...     ],
        ...     cmdclass={
        ...         'build_ext': BuildExtension
        ...     })

    Compute capabilities:

    By default the extension will be compiled to run on all archs of the cards visible during the
    building process of the extension, plus PTX. If down the road a new card is installed the
    extension may need to be recompiled. If a visible card has a compute capability (CC) that's
    newer than the newest version for which your nvcc can build fully-compiled binaries, PyTorch
    will make nvcc fall back to building kernels with the newest version of PTX your nvcc does
    support (see below for details on PTX).

    You can override the default behavior using `TORCH_CUDA_ARCH_LIST` to explicitly specify which
    CCs you want the extension to support:

    ``TORCH_CUDA_ARCH_LIST="6.1 8.6" python build_my_extension.py``
    ``TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" python build_my_extension.py``

    The +PTX option causes extension kernel binaries to include PTX instructions for the specified
    CC. PTX is an intermediate representation that allows kernels to runtime-compile for any CC >=
    the specified CC (for example, 8.6+PTX generates PTX that can runtime-compile for any GPU with
    CC >= 8.6). This improves your binary's forward compatibility. However, relying on older PTX to
    provide forward compat by runtime-compiling for newer CCs can modestly reduce performance on
    those newer CCs. If you know exact CC(s) of the GPUs you want to target, you're always better
    off specifying them individually. For example, if you want your extension to run on 8.0 and 8.6,
    "8.0+PTX" would work functionally because it includes PTX that can runtime-compile for 8.6, but
    "8.0 8.6" would be better.

    Note that while it's possible to include all supported archs, the more archs get included the
    slower the building process will be, as it will build a separate kernel image for each arch.

    Note that CUDA 11.5 nvcc will hit an internal compiler error while parsing torch/extension.h on Windows.
    To work around the issue, move the Python binding logic to a pure C++ file.

    Example use:
        #include <ATen/ATen.h>
        at::Tensor SigmoidAlphaBlendForwardCuda(....)

    Instead of:
        #include <torch/extension.h>
        torch::Tensor SigmoidAlphaBlendForwardCuda(...)

    Currently open issue for nvcc bug: https://github.com/pytorch/pytorch/issues/69460
    Complete workaround code example: https://github.com/facebookresearch/pytorch3d/commit/cb170ac024a949f1f9614ffe6af1c38d972f7d48

    Relocatable device code linking:

    If you want to reference device symbols across compilation units (across object files),
    the object files need to be built with `relocatable device code` (-rdc=true or -dc).
    An exception to this rule is "dynamic parallelism" (nested kernel launches), which is not used much anymore.
    `Relocatable device code` is less optimized so it needs to be used only on object files that need it.
    Using `-dlto` (Device Link Time Optimization) at the device code compilation step and `dlink` step
    helps reduce the potential perf degradation of `-rdc`.
    Note that it needs to be used at both steps to be useful.

    If you have `rdc` objects you need to have an extra `-dlink` (device linking) step before the CPU symbol linking step.
    There is also a case where `-dlink` is used without `-rdc`:
    when an extension is linked against a static lib containing rdc-compiled objects
    like the [NVSHMEM library](https://developer.nvidia.com/nvshmem).

    Note: Ninja is required to build a CUDA Extension with RDC linking.

    Example:
        >>> # xdoctest: +SKIP
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
        >>> CUDAExtension(
        ...        name='cuda_extension',
        ...        sources=['extension.cpp', 'extension_kernel.cu'],
        ...        dlink=True,
        ...        dlink_libraries=["dlink_lib"],
        ...        extra_compile_args={'cxx': ['-g'],
        ...                            'nvcc': ['-O2', '-rdc=true']})
    """
    library_dirs = kwargs.get('library_dirs', [])
    library_dirs += library_paths(cuda=True)
    kwargs['library_dirs'] = library_dirs

    libraries = kwargs.get('libraries', [])
    libraries.append('c10')
    libraries.append('torch')
    libraries.append('torch_cpu')
    libraries.append('torch_python')
    if IS_HIP_EXTENSION:
        libraries.append('amdhip64')
        libraries.append('c10_hip')
        libraries.append('torch_hip')
    else:
        libraries.append('cudart')
        libraries.append('c10_cuda')
        libraries.append('torch_cuda')
    kwargs['libraries'] = libraries

    include_dirs = kwargs.get('include_dirs', [])

    if IS_HIP_EXTENSION:
        build_dir = os.getcwd()
        hipify_result = hipify_python.hipify(
            project_directory=build_dir,
            output_directory=build_dir,
            header_include_dirs=include_dirs,
            includes=[os.path.join(build_dir, '*')],  # limit scope to build_dir only
            extra_files=[os.path.abspath(s) for s in sources],
            show_detailed=True,
            is_pytorch_extension=True,
            hipify_extra_files_only=True,  # don't hipify everything in includes path
        )

        hipified_sources = set()
        for source in sources:
            s_abs = os.path.abspath(source)
            hipified_s_abs = (hipify_result[s_abs].hipified_path if (s_abs in hipify_result and
                              hipify_result[s_abs].hipified_path is not None) else s_abs)
            # setup() arguments must *always* be /-separated paths relative to the setup.py directory,
            # *never* absolute paths
            hipified_sources.add(os.path.relpath(hipified_s_abs, build_dir))

        sources = list(hipified_sources)

    include_dirs += include_paths(cuda=True)
    kwargs['include_dirs'] = include_dirs

    kwargs['language'] = 'c++'

    dlink_libraries = kwargs.get('dlink_libraries', [])
    dlink = kwargs.get('dlink', False) or dlink_libraries
    if dlink:
        extra_compile_args = kwargs.get('extra_compile_args', {})

        extra_compile_args_dlink = extra_compile_args.get('nvcc_dlink', [])
        extra_compile_args_dlink += ['-dlink']
        extra_compile_args_dlink += [f'-L{x}' for x in library_dirs]
        extra_compile_args_dlink += [f'-l{x}' for x in dlink_libraries]

        if (torch.version.cuda is not None) and TorchVersion(torch.version.cuda) >= '11.2':
            extra_compile_args_dlink += ['-dlto']   # Device Link Time Optimization was introduced in CUDA 11.2

        extra_compile_args['nvcc_dlink'] = extra_compile_args_dlink

        kwargs['extra_compile_args'] = extra_compile_args

    return setuptools.Extension(name, sources, *args, **kwargs)


def include_paths(cuda: bool = False) -> List[str]:
    """
    Get the include paths required to build a C++ or CUDA extension.

    Args:
        cuda: If `True`, includes CUDA-specific include paths.

    Returns:
        A list of include path strings.
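
    Example (a minimal sketch; the returned paths depend on the installation):
        >>> # xdoctest: +SKIP
        >>> from torch.utils.cpp_extension import include_paths
        >>> paths = include_paths(cuda=True)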
    """
    lib_include = os.path.join(_TORCH_PATH, 'include')
    paths = [
        lib_include,
        # Remove this once torch/torch.h is officially no longer supported for C++ extensions.
        os.path.join(lib_include, 'torch', 'csrc', 'api', 'include'),
        # Some internal (old) Torch headers don't properly prefix their includes,
        # so we need to pass -Itorch/lib/include/TH as well.
        os.path.join(lib_include, 'TH'),
        os.path.join(lib_include, 'THC')
    ]
    if cuda and IS_HIP_EXTENSION:
        paths.append(os.path.join(lib_include, 'THH'))
        paths.append(_join_rocm_home('include'))
    elif cuda:
        cuda_home_include = _join_cuda_home('include')
        # If we have the Debian/Ubuntu packages for CUDA, we get /usr as CUDA home,
        # but gcc doesn't like having /usr/include passed explicitly.
        if cuda_home_include != '/usr/include':
            paths.append(cuda_home_include)

        # Support CUDA_INC_PATH env variable supported by CMake files
        if (cuda_inc_path := os.environ.get("CUDA_INC_PATH", None)) and \
                cuda_inc_path != '/usr/include':
            paths.append(cuda_inc_path)
        if CUDNN_HOME is not None:
            paths.append(os.path.join(CUDNN_HOME, 'include'))
    return paths


def library_paths(cuda: bool = False) -> List[str]:
    """
    Get the library paths required to build a C++ or CUDA extension.

    Args:
        cuda: If `True`, includes CUDA-specific library paths.

    Returns:
        A list of library path strings.
    """
    # We need to link against libtorch.so
    paths = [TORCH_LIB_PATH]

    if cuda and IS_HIP_EXTENSION:
        lib_dir = 'lib'
        paths.append(_join_rocm_home(lib_dir))
        if HIP_HOME is not None:
            paths.append(os.path.join(HIP_HOME, 'lib'))
    elif cuda:
        if IS_WINDOWS:
            lib_dir = os.path.join('lib', 'x64')
        else:
            lib_dir = 'lib64'
            if (not os.path.exists(_join_cuda_home(lib_dir)) and
                    os.path.exists(_join_cuda_home('lib'))):
                # 64-bit CUDA may be installed in 'lib' (see e.g. gh-16955)
                # Note that it's also possible both don't exist (see
                # _find_cuda_home) - in that case we stay with 'lib64'.
                lib_dir = 'lib'

        paths.append(_join_cuda_home(lib_dir))
        if CUDNN_HOME is not None:
            paths.append(os.path.join(CUDNN_HOME, lib_dir))
    return paths


def load(name,
         sources: Union[str, List[str]],
         extra_cflags=None,
         extra_cuda_cflags=None,
         extra_ldflags=None,
         extra_include_paths=None,
         build_directory=None,
         verbose=False,
         with_cuda: Optional[bool] = None,
         is_python_module=True,
         is_standalone=False,
         keep_intermediates=True):
    """
    Load a PyTorch C++ extension just-in-time (JIT).

    To load an extension, a Ninja build file is emitted, which is used to
    compile the given sources into a dynamic library. This library is
    subsequently loaded into the current Python process as a module and
    returned from this function, ready for use.

1242    By default, the directory to which the build file is emitted and the
1243    resulting library compiled to is ``<tmp>/torch_extensions/<name>``, where
1244    ``<tmp>`` is the temporary folder on the current platform and ``<name>``
1245    the name of the extension. This location can be overridden in two ways.
1246    First, if the ``TORCH_EXTENSIONS_DIR`` environment variable is set, it
1247    replaces ``<tmp>/torch_extensions`` and all extensions will be compiled
1248    into subfolders of this directory. Second, if the ``build_directory``
1249    argument to this function is supplied, it overrides the entire path, i.e.
1250    the library will be compiled into that folder directly.
1251
1252    To compile the sources, the default system compiler (``c++``) is used,
1253    which can be overridden by setting the ``CXX`` environment variable. To pass
1254    additional arguments to the compilation process, ``extra_cflags`` or
1255    ``extra_ldflags`` can be provided. For example, to compile your extension
1256    with optimizations, pass ``extra_cflags=['-O3']``. You can also use
1257    ``extra_cflags`` to pass further include directories.
1258
1259    CUDA support with mixed compilation is provided. Simply pass CUDA source
1260    files (``.cu`` or ``.cuh``) along with other sources. Such files will be
1261    detected and compiled with nvcc rather than the C++ compiler. This includes
1262    passing the CUDA lib64 directory as a library directory, and linking
1263    ``cudart``. You can pass additional flags to nvcc via
1264    ``extra_cuda_cflags``, just like with ``extra_cflags`` for C++. Various
1265    heuristics for finding the CUDA install directory are used, which usually
1266    work fine. If not, setting the ``CUDA_HOME`` environment variable is the
1267    safest option.
1268
1269    Args:
1270        name: The name of the extension to build. This MUST be the same as the
1271            name of the pybind11 module!
1272        sources: A list of relative or absolute paths to C++ source files.
1273        extra_cflags: optional list of compiler flags to forward to the build.
1274        extra_cuda_cflags: optional list of compiler flags to forward to nvcc
1275            when building CUDA sources.
1276        extra_ldflags: optional list of linker flags to forward to the build.
1277        extra_include_paths: optional list of include directories to forward
1278            to the build.
1279        build_directory: optional path to use as build workspace.
1280        verbose: If ``True``, turns on verbose logging of load steps.
1281        with_cuda: Determines whether CUDA headers and libraries are added to
1282            the build. If set to ``None`` (default), this value is
1283            automatically determined based on the existence of ``.cu`` or
            ``.cuh`` in ``sources``. Set it to ``True`` to force CUDA headers
1285            and libraries to be included.
1286        is_python_module: If ``True`` (default), imports the produced shared
1287            library as a Python module. If ``False``, behavior depends on
1288            ``is_standalone``.
1289        is_standalone: If ``False`` (default) loads the constructed extension
1290            into the process as a plain dynamic library. If ``True``, build a
1291            standalone executable.
1292
1293    Returns:
1294        If ``is_python_module`` is ``True``:
1295            Returns the loaded PyTorch extension as a Python module.
1296
1297        If ``is_python_module`` is ``False`` and ``is_standalone`` is ``False``:
1298            Returns nothing. (The shared library is loaded into the process as
1299            a side effect.)
1300
        If ``is_standalone`` is ``True``:
            Returns the path to the executable. (On Windows, TORCH_LIB_PATH is
1303            added to the PATH environment variable as a side effect.)
1304
1305    Example:
1306        >>> # xdoctest: +SKIP
1307        >>> from torch.utils.cpp_extension import load
1308        >>> module = load(
1309        ...     name='extension',
1310        ...     sources=['extension.cpp', 'extension_kernel.cu'],
1311        ...     extra_cflags=['-O2'],
1312        ...     verbose=True)
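        >>> # Sketch of building a standalone executable instead
        >>> # (assumes main.cpp defines a main() entry point):
        >>> exe_path = load(
        ...     name='standalone_tool',
        ...     sources=['main.cpp'],
        ...     is_python_module=False,
        ...     is_standalone=True)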
1313    """
1314    return _jit_compile(
1315        name,
1316        [sources] if isinstance(sources, str) else sources,
1317        extra_cflags,
1318        extra_cuda_cflags,
1319        extra_ldflags,
1320        extra_include_paths,
1321        build_directory or _get_build_directory(name, verbose),
1322        verbose,
1323        with_cuda,
1324        is_python_module,
1325        is_standalone,
1326        keep_intermediates=keep_intermediates)
1327
1328def _get_pybind11_abi_build_flags():
1329    # Note [Pybind11 ABI constants]
1330    #
    # Pybind11 before 2.4 used to build an ABI string using the following pattern:
    # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_BUILD_TYPE}__"
    # Since 2.4, compiler type, stdlib and build ABI parameters are also encoded, like this:
    # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_COMPILER_TYPE}{PYBIND11_STDLIB}{PYBIND11_BUILD_ABI}{PYBIND11_BUILD_TYPE}__"
    #
    # This was done in order to further narrow down the chances of compiler ABI incompatibility
    # that can cause hard-to-debug segfaults.
1338    # For PyTorch extensions we want to relax those restrictions and pass compiler, stdlib and abi properties
1339    # captured during PyTorch native library compilation in torch/csrc/Module.cpp
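    #
    # A minimal sketch of the flags this produces (illustrative values):
    #   ['-DPYBIND11_COMPILER_TYPE=\\"_gcc\\"',
    #    '-DPYBIND11_STDLIB=\\"_libstdcpp\\"',
    #    '-DPYBIND11_BUILD_ABI=\\"_cxxabi1011\\"']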
1340
1341    abi_cflags = []
1342    for pname in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]:
1343        pval = getattr(torch._C, f"_PYBIND11_{pname}")
1344        if pval is not None and not IS_WINDOWS:
1345            abi_cflags.append(f'-DPYBIND11_{pname}=\\"{pval}\\"')
1346    return abi_cflags
1347
1348def _get_glibcxx_abi_build_flags():
1349    glibcxx_abi_cflags = ['-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))]
1350    return glibcxx_abi_cflags
1351
1352def check_compiler_is_gcc(compiler):
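    """Return ``True`` if the given compiler binary is GCC on Linux, else ``False``.

    Example (sketch; the result depends on the system toolchain):
        >>> # xdoctest: +SKIP
        >>> from torch.utils.cpp_extension import check_compiler_is_gcc
        >>> check_compiler_is_gcc('c++')  # True where c++ is a gcc wrapper
    """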
1353    if not IS_LINUX:
1354        return False
1355
1356    env = os.environ.copy()
1357    env['LC_ALL'] = 'C'  # Don't localize output
    try:
        version_string = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS)
    except Exception:
        try:
            version_string = subprocess.check_output([compiler, '--version'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS)
        except Exception:
            return False
    # Look for COLLECT_GCC in the verbose output; it identifies gcc/g++ even behind wrappers such as sccache
1366    pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE)
1367    results = re.findall(pattern, version_string)
1368    if len(results) != 1:
1369        return False
1370    compiler_path = os.path.realpath(results[0].strip())
1371    # On RHEL/CentOS c++ is a gcc compiler wrapper
1372    if os.path.basename(compiler_path) == 'c++' and 'gcc version' in version_string:
1373        return True
1374    return False
1375
1376def _check_and_build_extension_h_precompiler_headers(
1377        extra_cflags,
1378        extra_include_paths,
1379        is_standalone=False):
    r'''
    Precompiled Headers (PCH) pre-build common headers to reduce build time for pytorch load_inline modules.
    GCC official manual: https://gcc.gnu.org/onlinedocs/gcc-4.0.4/gcc/Precompiled-Headers.html
    A PCH only works when the pre-built file (header.h.gch) and the build target use the same build parameters,
    so we add a signature file to record the PCH file's parameters. If the build parameters (signature) change,
    the PCH file is rebuilt.

    Note:
    1. Windows and macOS have different PCH mechanisms. We only support Linux currently.
    2. It only works with GCC/G++.
    '''
1391    if not IS_LINUX:
1392        return
1393
1394    compiler = get_cxx_compiler()
1395
    if not check_compiler_is_gcc(compiler):
        return
1399
1400    head_file = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h')
1401    head_file_pch = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h.gch')
1402    head_file_signature = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h.sign')
1403
    def listToString(s):
        # Join a list of flags into a single space-separated string.
        string = ""
        if s is None:
            return string

        for element in s:
            string += (element + ' ')
        return string
1415
1416    def format_precompiler_header_cmd(compiler, head_file, head_file_pch, common_cflags, torch_include_dirs, extra_cflags, extra_include_paths):
1417        return re.sub(
1418            r"[ \n]+",
1419            " ",
1420            f"""
1421                {compiler} -x c++-header {head_file} -o {head_file_pch} {torch_include_dirs} {extra_include_paths} {extra_cflags} {common_cflags}
1422            """,
1423        ).strip()
1424
1425    def command_to_signature(cmd):
1426        signature = cmd.replace(' ', '_')
1427        return signature
1428
    def check_pch_signature_in_file(file_path, signature):
        if not os.path.isfile(file_path):
            return False

        with open(file_path) as file:
            # Compare the recorded signature with the expected one.
            return signature == file.read()
1439
1440    def _create_if_not_exist(path_dir):
1441        if not os.path.exists(path_dir):
1442            try:
1443                Path(path_dir).mkdir(parents=True, exist_ok=True)
1444            except OSError as exc:  # Guard against race condition
1445                if exc.errno != errno.EEXIST:
                    raise RuntimeError(f"Failed to create path {path_dir}") from exc
1447
1448    def write_pch_signature_to_file(file_path, pch_sign):
1449        _create_if_not_exist(os.path.dirname(file_path))
        with open(file_path, "w") as f:
            f.write(pch_sign)
1453
1454    def build_precompile_header(pch_cmd):
1455        try:
1456            subprocess.check_output(pch_cmd, shell=True, stderr=subprocess.STDOUT)
1457        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Failed to compile precompiled header, command: {pch_cmd}") from e
1459
1460    extra_cflags_str = listToString(extra_cflags)
1461    extra_include_paths_str = " ".join(
1462        [f"-I{include}" for include in extra_include_paths] if extra_include_paths else []
1463    )
1464
    lib_include = os.path.join(_TORCH_PATH, 'include')
    torch_include_dirs = [
        f"-I {lib_include}",
        # Python.h
        f"-I {sysconfig.get_path('include')}",
        # torch/all.h
        f"-I {os.path.join(lib_include, 'torch', 'csrc', 'api', 'include')}",
    ]
1473
1474    torch_include_dirs_str = listToString(torch_include_dirs)
1475
1476    common_cflags = []
1477    if not is_standalone:
1478        common_cflags += ['-DTORCH_API_INCLUDE_EXTENSION_H']
1479
1480    common_cflags += ['-std=c++17', '-fPIC']
    common_cflags += _get_pybind11_abi_build_flags()
    common_cflags += _get_glibcxx_abi_build_flags()
1483    common_cflags_str = listToString(common_cflags)
1484
1485    pch_cmd = format_precompiler_header_cmd(compiler, head_file, head_file_pch, common_cflags_str, torch_include_dirs_str, extra_cflags_str, extra_include_paths_str)
1486    pch_sign = command_to_signature(pch_cmd)
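    # Illustrative: the signature is just the pch command with spaces replaced by
    # underscores, e.g. 'g++_-x_c++-header_.../torch/include/torch/extension.h_-o_...'.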
1487
    if not os.path.isfile(head_file_pch):
        build_precompile_header(pch_cmd)
        write_pch_signature_to_file(head_file_signature, pch_sign)
    else:
        b_same_sign = check_pch_signature_in_file(head_file_signature, pch_sign)
        if not b_same_sign:
            build_precompile_header(pch_cmd)
            write_pch_signature_to_file(head_file_signature, pch_sign)
1496
1497def remove_extension_h_precompiler_headers():
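    """Remove the precompiled ``torch/extension.h`` header (``.gch``) and its signature file, if present."""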
1498    def _remove_if_file_exists(path_file):
1499        if os.path.exists(path_file):
1500            os.remove(path_file)
1501
1502    head_file_pch = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h.gch')
1503    head_file_signature = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h.sign')
1504
1505    _remove_if_file_exists(head_file_pch)
1506    _remove_if_file_exists(head_file_signature)
1507
1508def load_inline(name,
1509                cpp_sources,
1510                cuda_sources=None,
1511                functions=None,
1512                extra_cflags=None,
1513                extra_cuda_cflags=None,
1514                extra_ldflags=None,
1515                extra_include_paths=None,
1516                build_directory=None,
1517                verbose=False,
1518                with_cuda=None,
1519                is_python_module=True,
1520                with_pytorch_error_handling=True,
1521                keep_intermediates=True,
1522                use_pch=False):
1523    r'''
1524    Load a PyTorch C++ extension just-in-time (JIT) from string sources.
1525
1526    This function behaves exactly like :func:`load`, but takes its sources as
1527    strings rather than filenames. These strings are stored to files in the
1528    build directory, after which the behavior of :func:`load_inline` is
1529    identical to :func:`load`.
1530
1531    See `the
1532    tests <https://github.com/pytorch/pytorch/blob/master/test/test_cpp_extensions_jit.py>`_
1533    for good examples of using this function.
1534
1535    Sources may omit two required parts of a typical non-inline C++ extension:
1536    the necessary header includes, as well as the (pybind11) binding code. More
1537    precisely, strings passed to ``cpp_sources`` are first concatenated into a
1538    single ``.cpp`` file. This file is then prepended with ``#include
1539    <torch/extension.h>``.
1540
1541    Furthermore, if the ``functions`` argument is supplied, bindings will be
1542    automatically generated for each function specified. ``functions`` can
1543    either be a list of function names, or a dictionary mapping from function
1544    names to docstrings. If a list is given, the name of each function is used
1545    as its docstring.
1546
1547    The sources in ``cuda_sources`` are concatenated into a separate ``.cu``
    file and prepended with ``torch/types.h``, ``cuda.h`` and
1549    ``cuda_runtime.h`` includes. The ``.cpp`` and ``.cu`` files are compiled
1550    separately, but ultimately linked into a single library. Note that no
    bindings are generated for functions in ``cuda_sources`` per se. To bind
1552    to a CUDA kernel, you must create a C++ function that calls it, and either
1553    declare or define this C++ function in one of the ``cpp_sources`` (and
1554    include its name in ``functions``).
1555
1556    See :func:`load` for a description of arguments omitted below.
1557
1558    Args:
1559        cpp_sources: A string, or list of strings, containing C++ source code.
1560        cuda_sources: A string, or list of strings, containing CUDA source code.
1561        functions: A list of function names for which to generate function
1562            bindings. If a dictionary is given, it should map function names to
1563            docstrings (which are otherwise just the function names).
1564        with_cuda: Determines whether CUDA headers and libraries are added to
1565            the build. If set to ``None`` (default), this value is
1566            automatically determined based on whether ``cuda_sources`` is
1567            provided. Set it to ``True`` to force CUDA headers
1568            and libraries to be included.
1569        with_pytorch_error_handling: Determines whether pytorch error and
1570            warning macros are handled by pytorch instead of pybind. To do
1571            this, each function ``foo`` is called via an intermediary ``_safe_foo``
1572            function. This redirection might cause issues in obscure cases
1573            of cpp. This flag should be set to ``False`` when this redirect
1574            causes issues.
1575
1576    Example:
1577        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
1578        >>> from torch.utils.cpp_extension import load_inline
1579        >>> source = """
1580        at::Tensor sin_add(at::Tensor x, at::Tensor y) {
1581          return x.sin() + y.sin();
1582        }
1583        """
1584        >>> module = load_inline(name='inline_extension',
1585        ...                      cpp_sources=[source],
1586        ...                      functions=['sin_add'])
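        >>> # CUDA sketch (hypothetical source strings): kernels go in
        >>> # ``cuda_sources``; a C++ wrapper declared in ``cpp_sources``
        >>> # and listed in ``functions`` exposes them to Python.
        >>> # xdoctest: +SKIP
        >>> cuda_module = load_inline(name='inline_cuda_extension',
        ...                           cpp_sources=[cpp_wrapper_source],
        ...                           cuda_sources=[cuda_kernel_source],
        ...                           functions=['my_cuda_op'])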
1587
1588    .. note::
1589        By default, the Ninja backend uses #CPUS + 2 workers to build the
1590        extension. This may use up too many resources on some systems. One
        can control the number of workers by setting the ``MAX_JOBS`` environment
1592        variable to a non-negative number.
1593    '''
1594    build_directory = build_directory or _get_build_directory(name, verbose)
1595
1596    if isinstance(cpp_sources, str):
1597        cpp_sources = [cpp_sources]
1598    cuda_sources = cuda_sources or []
1599    if isinstance(cuda_sources, str):
1600        cuda_sources = [cuda_sources]
1601
1602    cpp_sources.insert(0, '#include <torch/extension.h>')
1603
    if use_pch:
        # Use the precompiled header ('torch/extension.h') to reduce compile time.
1606        _check_and_build_extension_h_precompiler_headers(extra_cflags, extra_include_paths)
1607    else:
1608        remove_extension_h_precompiler_headers()
1609
1610    # If `functions` is supplied, we create the pybind11 bindings for the user.
1611    # Here, `functions` is (or becomes, after some processing) a map from
1612    # function names to function docstrings.
1613    if functions is not None:
1614        module_def = []
1615        module_def.append('PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {')
1616        if isinstance(functions, str):
1617            functions = [functions]
1618        if isinstance(functions, list):
1619            # Make the function docstring the same as the function name.
1620            functions = {f: f for f in functions}
1621        elif not isinstance(functions, dict):
1622            raise ValueError(f"Expected 'functions' to be a list or dict, but was {type(functions)}")
1623        for function_name, docstring in functions.items():
1624            if with_pytorch_error_handling:
1625                module_def.append(f'm.def("{function_name}", torch::wrap_pybind_function({function_name}), "{docstring}");')
1626            else:
1627                module_def.append(f'm.def("{function_name}", {function_name}, "{docstring}");')
1628        module_def.append('}')
1629        cpp_sources += module_def
1630
1631    cpp_source_path = os.path.join(build_directory, 'main.cpp')
1632    _maybe_write(cpp_source_path, "\n".join(cpp_sources))
1633
1634    sources = [cpp_source_path]
1635
1636    if cuda_sources:
1637        cuda_sources.insert(0, '#include <torch/types.h>')
1638        cuda_sources.insert(1, '#include <cuda.h>')
1639        cuda_sources.insert(2, '#include <cuda_runtime.h>')
1640
1641        cuda_source_path = os.path.join(build_directory, 'cuda.cu')
1642        _maybe_write(cuda_source_path, "\n".join(cuda_sources))
1643
1644        sources.append(cuda_source_path)
1645
1646    return _jit_compile(
1647        name,
1648        sources,
1649        extra_cflags,
1650        extra_cuda_cflags,
1651        extra_ldflags,
1652        extra_include_paths,
1653        build_directory,
1654        verbose,
1655        with_cuda,
1656        is_python_module,
1657        is_standalone=False,
1658        keep_intermediates=keep_intermediates)
1659
1660
1661def _jit_compile(name,
1662                 sources,
1663                 extra_cflags,
1664                 extra_cuda_cflags,
1665                 extra_ldflags,
1666                 extra_include_paths,
1667                 build_directory: str,
1668                 verbose: bool,
1669                 with_cuda: Optional[bool],
1670                 is_python_module,
1671                 is_standalone,
                 keep_intermediates=True):
1673    if is_python_module and is_standalone:
1674        raise ValueError("`is_python_module` and `is_standalone` are mutually exclusive.")
1675
1676    if with_cuda is None:
1677        with_cuda = any(map(_is_cuda_file, sources))
1678    with_cudnn = any('cudnn' in f for f in extra_ldflags or [])
1679    old_version = JIT_EXTENSION_VERSIONER.get_version(name)
1680    version = JIT_EXTENSION_VERSIONER.bump_version_if_changed(
1681        name,
1682        sources,
1683        build_arguments=[extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths],
1684        build_directory=build_directory,
1685        with_cuda=with_cuda,
1686        is_python_module=is_python_module,
1687        is_standalone=is_standalone,
1688    )
1689    if version > 0:
1690        if version != old_version and verbose:
1691            print(f'The input conditions for extension module {name} have changed. ' +
1692                  f'Bumping to version {version} and re-building as {name}_v{version}...',
1693                  file=sys.stderr)
1694        name = f'{name}_v{version}'
1695
1696    baton = FileBaton(os.path.join(build_directory, 'lock'))
1697    if baton.try_acquire():
1698        try:
1699            if version != old_version:
1700                with GeneratedFileCleaner(keep_intermediates=keep_intermediates) as clean_ctx:
1701                    if IS_HIP_EXTENSION and (with_cuda or with_cudnn):
1702                        hipify_result = hipify_python.hipify(
1703                            project_directory=build_directory,
1704                            output_directory=build_directory,
1705                            header_include_dirs=(extra_include_paths if extra_include_paths is not None else []),
1706                            extra_files=[os.path.abspath(s) for s in sources],
1707                            ignores=[_join_rocm_home('*'), os.path.join(_TORCH_PATH, '*')],  # no need to hipify ROCm or PyTorch headers
1708                            show_detailed=verbose,
1709                            show_progress=verbose,
1710                            is_pytorch_extension=True,
1711                            clean_ctx=clean_ctx
1712                        )
1713
1714                        hipified_sources = set()
1715                        for source in sources:
1716                            s_abs = os.path.abspath(source)
1717                            hipified_sources.add(hipify_result[s_abs].hipified_path if s_abs in hipify_result else s_abs)
1718
1719                        sources = list(hipified_sources)
1720
1721                    _write_ninja_file_and_build_library(
1722                        name=name,
1723                        sources=sources,
1724                        extra_cflags=extra_cflags or [],
1725                        extra_cuda_cflags=extra_cuda_cflags or [],
1726                        extra_ldflags=extra_ldflags or [],
1727                        extra_include_paths=extra_include_paths or [],
1728                        build_directory=build_directory,
1729                        verbose=verbose,
1730                        with_cuda=with_cuda,
1731                        is_standalone=is_standalone)
1732            elif verbose:
1733                print('No modifications detected for re-loaded extension '
1734                      f'module {name}, skipping build step...', file=sys.stderr)
1735        finally:
1736            baton.release()
1737    else:
1738        baton.wait()
1739
1740    if verbose:
1741        print(f'Loading extension module {name}...', file=sys.stderr)
1742
1743    if is_standalone:
1744        return _get_exec_path(name, build_directory)
1745
1746    return _import_module_from_library(name, build_directory, is_python_module)
1747
1748
1749def _write_ninja_file_and_compile_objects(
1750        sources: List[str],
1751        objects,
1752        cflags,
1753        post_cflags,
1754        cuda_cflags,
1755        cuda_post_cflags,
1756        cuda_dlink_post_cflags,
1757        build_directory: str,
1758        verbose: bool,
1759        with_cuda: Optional[bool]) -> None:
1760    verify_ninja_availability()
1761
1762    compiler = get_cxx_compiler()
1763
1764    get_compiler_abi_compatibility_and_version(compiler)
1765    if with_cuda is None:
1766        with_cuda = any(map(_is_cuda_file, sources))
1767    build_file_path = os.path.join(build_directory, 'build.ninja')
1768    if verbose:
1769        print(f'Emitting ninja build file {build_file_path}...', file=sys.stderr)
1770    _write_ninja_file(
1771        path=build_file_path,
1772        cflags=cflags,
1773        post_cflags=post_cflags,
1774        cuda_cflags=cuda_cflags,
1775        cuda_post_cflags=cuda_post_cflags,
1776        cuda_dlink_post_cflags=cuda_dlink_post_cflags,
1777        sources=sources,
1778        objects=objects,
1779        ldflags=None,
1780        library_target=None,
1781        with_cuda=with_cuda)
1782    if verbose:
1783        print('Compiling objects...', file=sys.stderr)
1784    _run_ninja_build(
1785        build_directory,
1786        verbose,
1787        # It would be better if we could tell users the name of the extension
1788        # that failed to build but there isn't a good way to get it here.
1789        error_prefix='Error compiling objects for extension')
1790
1791
1792def _write_ninja_file_and_build_library(
1793        name,
1794        sources: List[str],
1795        extra_cflags,
1796        extra_cuda_cflags,
1797        extra_ldflags,
1798        extra_include_paths,
1799        build_directory: str,
1800        verbose: bool,
1801        with_cuda: Optional[bool],
1802        is_standalone: bool = False) -> None:
1803    verify_ninja_availability()
1804
1805    compiler = get_cxx_compiler()
1806
1807    get_compiler_abi_compatibility_and_version(compiler)
1808    if with_cuda is None:
1809        with_cuda = any(map(_is_cuda_file, sources))
1810    extra_ldflags = _prepare_ldflags(
1811        extra_ldflags or [],
1812        with_cuda,
1813        verbose,
1814        is_standalone)
1815    build_file_path = os.path.join(build_directory, 'build.ninja')
1816    if verbose:
1817        print(f'Emitting ninja build file {build_file_path}...', file=sys.stderr)
1818    # NOTE: Emitting a new ninja build file does not cause re-compilation if
1819    # the sources did not change, so it's ok to re-emit (and it's fast).
1820    _write_ninja_file_to_build_library(
1821        path=build_file_path,
1822        name=name,
1823        sources=sources,
1824        extra_cflags=extra_cflags or [],
1825        extra_cuda_cflags=extra_cuda_cflags or [],
1826        extra_ldflags=extra_ldflags or [],
1827        extra_include_paths=extra_include_paths or [],
1828        with_cuda=with_cuda,
1829        is_standalone=is_standalone)
1830
1831    if verbose:
1832        print(f'Building extension module {name}...', file=sys.stderr)
1833    _run_ninja_build(
1834        build_directory,
1835        verbose,
1836        error_prefix=f"Error building extension '{name}'")
1837
1838
1839def is_ninja_available():
1840    """Return ``True`` if the `ninja <https://ninja-build.org/>`_ build system is available on the system, ``False`` otherwise."""
1841    try:
1842        subprocess.check_output('ninja --version'.split())
1843    except Exception:
1844        return False
1845    else:
1846        return True
1847
1848
1849def verify_ninja_availability():
1850    """Raise ``RuntimeError`` if `ninja <https://ninja-build.org/>`_ build system is not available on the system, does nothing otherwise."""
1851    if not is_ninja_available():
1852        raise RuntimeError("Ninja is required to load C++ extensions")
1853
1854
1855def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone):
1856    if IS_WINDOWS:
1857        python_lib_path = os.path.join(sys.base_exec_prefix, 'libs')
1858
1859        extra_ldflags.append('c10.lib')
1860        if with_cuda:
1861            extra_ldflags.append('c10_cuda.lib')
1862        extra_ldflags.append('torch_cpu.lib')
1863        if with_cuda:
1864            extra_ldflags.append('torch_cuda.lib')
1865            # /INCLUDE is used to ensure torch_cuda is linked against in a project that relies on it.
1866            # Related issue: https://github.com/pytorch/pytorch/issues/31611
1867            extra_ldflags.append('-INCLUDE:?warp_size@cuda@at@@YAHXZ')
1868        extra_ldflags.append('torch.lib')
1869        extra_ldflags.append(f'/LIBPATH:{TORCH_LIB_PATH}')
1870        if not is_standalone:
1871            extra_ldflags.append('torch_python.lib')
1872            extra_ldflags.append(f'/LIBPATH:{python_lib_path}')
1873
1874    else:
1875        extra_ldflags.append(f'-L{TORCH_LIB_PATH}')
1876        extra_ldflags.append('-lc10')
1877        if with_cuda:
1878            extra_ldflags.append('-lc10_hip' if IS_HIP_EXTENSION else '-lc10_cuda')
1879        extra_ldflags.append('-ltorch_cpu')
1880        if with_cuda:
1881            extra_ldflags.append('-ltorch_hip' if IS_HIP_EXTENSION else '-ltorch_cuda')
1882        extra_ldflags.append('-ltorch')
1883        if not is_standalone:
1884            extra_ldflags.append('-ltorch_python')
1885
1886        if is_standalone:
1887            extra_ldflags.append(f"-Wl,-rpath,{TORCH_LIB_PATH}")
1888
1889    if with_cuda:
1890        if verbose:
1891            print('Detected CUDA files, patching ldflags', file=sys.stderr)
1892        if IS_WINDOWS:
1893            extra_ldflags.append(f'/LIBPATH:{_join_cuda_home("lib", "x64")}')
1894            extra_ldflags.append('cudart.lib')
1895            if CUDNN_HOME is not None:
1896                extra_ldflags.append(f'/LIBPATH:{os.path.join(CUDNN_HOME, "lib", "x64")}')
1897        elif not IS_HIP_EXTENSION:
1898            extra_lib_dir = "lib64"
1899            if (not os.path.exists(_join_cuda_home(extra_lib_dir)) and
1900                    os.path.exists(_join_cuda_home("lib"))):
1901                # 64-bit CUDA may be installed in "lib"
1902                # Note that it's also possible both don't exist (see _find_cuda_home) - in that case we stay with "lib64"
1903                extra_lib_dir = "lib"
1904            extra_ldflags.append(f'-L{_join_cuda_home(extra_lib_dir)}')
1905            extra_ldflags.append('-lcudart')
1906            if CUDNN_HOME is not None:
1907                extra_ldflags.append(f'-L{os.path.join(CUDNN_HOME, "lib64")}')
1908        elif IS_HIP_EXTENSION:
1909            extra_ldflags.append(f'-L{_join_rocm_home("lib")}')
1910            extra_ldflags.append('-lamdhip64')
1911    return extra_ldflags
1912
1913
1914def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
1915    """
1916    Determine CUDA arch flags to use.
1917
1918    For an arch, say "6.1", the added compile flag will be
1919    ``-gencode=arch=compute_61,code=sm_61``.
1920    For an added "+PTX", an additional
1921    ``-gencode=arch=compute_xx,code=compute_xx`` is added.
1922
1923    See select_compute_arch.cmake for corresponding named and supported arches
1924    when building with CMake.
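
    Example (illustrative):
        >>> # xdoctest: +SKIP
        >>> import os
        >>> os.environ['TORCH_CUDA_ARCH_LIST'] = '8.0;8.6+PTX'
        >>> _get_cuda_arch_flags()
        ['-gencode=arch=compute_80,code=sm_80',
         '-gencode=arch=compute_86,code=compute_86',
         '-gencode=arch=compute_86,code=sm_86']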
1925    """
1926    # If cflags is given, there may already be user-provided arch flags in it
1927    # (from `extra_compile_args`)
1928    if cflags is not None:
1929        for flag in cflags:
1930            if 'TORCH_EXTENSION_NAME' in flag:
1931                continue
1932            if 'arch' in flag:
1933                return []
1934
1935    # Note: keep combined names ("arch1+arch2") above single names, otherwise
1936    # string replacement may not do the right thing
1937    named_arches = collections.OrderedDict([
1938        ('Kepler+Tesla', '3.7'),
1939        ('Kepler', '3.5+PTX'),
1940        ('Maxwell+Tegra', '5.3'),
1941        ('Maxwell', '5.0;5.2+PTX'),
1942        ('Pascal', '6.0;6.1+PTX'),
1943        ('Volta+Tegra', '7.2'),
1944        ('Volta', '7.0+PTX'),
1945        ('Turing', '7.5+PTX'),
1946        ('Ampere+Tegra', '8.7'),
1947        ('Ampere', '8.0;8.6+PTX'),
1948        ('Ada', '8.9+PTX'),
1949        ('Hopper', '9.0+PTX'),
1950    ])
1951
1952    supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2',
1953                        '7.0', '7.2', '7.5', '8.0', '8.6', '8.7', '8.9', '9.0', '9.0a']
1954    valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches]
1955
1956    # The default is sm_30 for CUDA 9.x and 10.x
1957    # First check for an env var (same as used by the main setup.py)
1958    # Can be one or more architectures, e.g. "6.1" or "3.5;5.2;6.0;6.1;7.0+PTX"
1959    # See cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake
1960    _arch_list = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
1961
1962    # If not given, determine what's best for the GPU / CUDA version that can be found
1963    if not _arch_list:
1964        warnings.warn(
1965            "TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. \n"
1966            "If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].")
1967        arch_list = []
1968        # the assumption is that the extension should run on any of the currently visible cards,
1969        # which could be of different types - therefore all archs for visible cards should be included
1970        for i in range(torch.cuda.device_count()):
1971            capability = torch.cuda.get_device_capability(i)
1972            supported_sm = [int(arch.split('_')[1])
1973                            for arch in torch.cuda.get_arch_list() if 'sm_' in arch]
1974            max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)
1975            # Capability of the device may be higher than what's supported by the user's
1976            # NVCC, causing compilation error. User's NVCC is expected to match the one
1977            # used to build pytorch, so we use the maximum supported capability of pytorch
1978            # to clamp the capability.
1979            capability = min(max_supported_sm, capability)
1980            arch = f'{capability[0]}.{capability[1]}'
1981            if arch not in arch_list:
1982                arch_list.append(arch)
        arch_list = sorted(arch_list)
        if arch_list:
            arch_list[-1] += '+PTX'
1985    else:
1986        # Deal with lists that are ' ' separated (only deal with ';' after)
1987        _arch_list = _arch_list.replace(' ', ';')
1988        # Expand named arches
1989        for named_arch, archval in named_arches.items():
1990            _arch_list = _arch_list.replace(named_arch, archval)
1991
1992        arch_list = _arch_list.split(';')
1993
1994    flags = []
1995    for arch in arch_list:
1996        if arch not in valid_arch_strings:
1997            raise ValueError(f"Unknown CUDA arch ({arch}) or GPU not supported")
1998        else:
1999            num = arch[0] + arch[2:].split("+")[0]
2000            flags.append(f'-gencode=arch=compute_{num},code=sm_{num}')
2001            if arch.endswith('+PTX'):
2002                flags.append(f'-gencode=arch=compute_{num},code=compute_{num}')
2003
2004    return sorted(set(flags))
2005
2006
2007def _get_rocm_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
2008    # If cflags is given, there may already be user-provided arch flags in it
2009    # (from `extra_compile_args`)
2010    if cflags is not None:
2011        for flag in cflags:
2012            if 'amdgpu-target' in flag or 'offload-arch' in flag:
2013                return ['-fno-gpu-rdc']
2014    # Use same defaults as used for building PyTorch
2015    # Allow env var to override, just like during initial cmake build.
2016    _archs = os.environ.get('PYTORCH_ROCM_ARCH', None)
2017    if not _archs:
2018        archFlags = torch._C._cuda_getArchFlags()
2019        if archFlags:
2020            archs = archFlags.split()
2021        else:
2022            archs = []
2023    else:
2024        archs = _archs.replace(' ', ';').split(';')
2025    flags = [f'--offload-arch={arch}' for arch in archs]
2026    flags += ['-fno-gpu-rdc']
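    # e.g. (illustrative): ['--offload-arch=gfx906', '--offload-arch=gfx908', '-fno-gpu-rdc']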
2027    return flags
2028
2029def _get_build_directory(name: str, verbose: bool) -> str:
2030    root_extensions_directory = os.environ.get('TORCH_EXTENSIONS_DIR')
2031    if root_extensions_directory is None:
2032        root_extensions_directory = get_default_build_root()
2033        cu_str = ('cpu' if torch.version.cuda is None else
2034                  f'cu{torch.version.cuda.replace(".", "")}')  # type: ignore[attr-defined]
2035        python_version = f'py{sys.version_info.major}{sys.version_info.minor}'
2036        build_folder = f'{python_version}_{cu_str}'
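        # e.g. (illustrative): 'py310_cu117' for Python 3.10 with CUDA 11.7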
2037
2038        root_extensions_directory = os.path.join(
2039            root_extensions_directory, build_folder)
2040
2041    if verbose:
2042        print(f'Using {root_extensions_directory} as PyTorch extensions root...', file=sys.stderr)
2043
2044    build_directory = os.path.join(root_extensions_directory, name)
2045    if not os.path.exists(build_directory):
2046        if verbose:
2047            print(f'Creating extension directory {build_directory}...', file=sys.stderr)
2048        # This is like mkdir -p, i.e. will also create parent directories.
2049        os.makedirs(build_directory, exist_ok=True)
2050
2051    return build_directory
2052
2053
2054def _get_num_workers(verbose: bool) -> Optional[int]:
2055    max_jobs = os.environ.get('MAX_JOBS')
2056    if max_jobs is not None and max_jobs.isdigit():
2057        if verbose:
2058            print(f'Using envvar MAX_JOBS ({max_jobs}) as the number of workers...',
2059                  file=sys.stderr)
2060        return int(max_jobs)
2061    if verbose:
2062        print('Allowing ninja to set a default number of workers... '
2063              '(overridable by setting the environment variable MAX_JOBS=N)',
2064              file=sys.stderr)
2065    return None
2066
2067
2068def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) -> None:
2069    command = ['ninja', '-v']
2070    num_workers = _get_num_workers(verbose)
2071    if num_workers is not None:
2072        command.extend(['-j', str(num_workers)])
2073    env = os.environ.copy()
2074    # Try to activate the vc env for the users
2075    if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' not in env:
2076        from setuptools import distutils
2077
2078        plat_name = distutils.util.get_platform()
2079        plat_spec = PLAT_TO_VCVARS[plat_name]
2080
2081        vc_env = distutils._msvccompiler._get_vc_env(plat_spec)
2082        vc_env = {k.upper(): v for k, v in vc_env.items()}
2083        for k, v in env.items():
2084            uk = k.upper()
2085            if uk not in vc_env:
2086                vc_env[uk] = v
2087        env = vc_env
2088    try:
2089        sys.stdout.flush()
2090        sys.stderr.flush()
2091        # Warning: don't pass stdout=None to subprocess.run to get output.
2092        # subprocess.run assumes that sys.__stdout__ has not been modified and
2093        # attempts to write to it by default.  However, when we call _run_ninja_build
2094        # from ahead-of-time cpp extensions, the following happens:
        # 1) If the stdout encoding is not utf-8, setuptools detaches __stdout__.
2096        #    https://github.com/pypa/setuptools/blob/7e97def47723303fafabe48b22168bbc11bb4821/setuptools/dist.py#L1110
2097        #    (it probably shouldn't do this)
2098        # 2) subprocess.run (on POSIX, with no stdout override) relies on
2099        #    __stdout__ not being detached:
2100        #    https://github.com/python/cpython/blob/c352e6c7446c894b13643f538db312092b351789/Lib/subprocess.py#L1214
2101        # To work around this, we pass in the fileno directly and hope that
2102        # it is valid.
2103        stdout_fileno = 1
2104        subprocess.run(
2105            command,
2106            stdout=stdout_fileno if verbose else subprocess.PIPE,
2107            stderr=subprocess.STDOUT,
2108            cwd=build_directory,
2109            check=True,
2110            env=env)
2111    except subprocess.CalledProcessError as e:
2112        # Python 2 and 3 compatible way of getting the error object.
2113        _, error, _ = sys.exc_info()
2114        # error.output contains the stdout and stderr of the build attempt.
2115        message = error_prefix
2116        # `error` is a CalledProcessError (which has an `output`) attribute, but
2117        # mypy thinks it's Optional[BaseException] and doesn't narrow
2118        if hasattr(error, 'output') and error.output:  # type: ignore[union-attr]
2119            message += f": {error.output.decode(*SUBPROCESS_DECODE_ARGS)}"  # type: ignore[union-attr]
2120        raise RuntimeError(message) from e
2121
2122
2123def _get_exec_path(module_name, path):
2124    if IS_WINDOWS and TORCH_LIB_PATH not in os.getenv('PATH', '').split(';'):
2125        torch_lib_in_path = any(
2126            os.path.exists(p) and os.path.samefile(p, TORCH_LIB_PATH)
2127            for p in os.getenv('PATH', '').split(';')
2128        )
2129        if not torch_lib_in_path:
2130            os.environ['PATH'] = f"{TORCH_LIB_PATH};{os.getenv('PATH', '')}"
2131    return os.path.join(path, f'{module_name}{EXEC_EXT}')
2132
2133
2134def _import_module_from_library(module_name, path, is_python_module):
2135    filepath = os.path.join(path, f"{module_name}{LIB_EXT}")
2136    if is_python_module:
2137        # https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
2138        spec = importlib.util.spec_from_file_location(module_name, filepath)
2139        assert spec is not None
2140        module = importlib.util.module_from_spec(spec)
2141        assert isinstance(spec.loader, importlib.abc.Loader)
2142        spec.loader.exec_module(module)
2143        return module
2144    else:
2145        torch.ops.load_library(filepath)
2146        return filepath
2147
2148
2149def _write_ninja_file_to_build_library(path,
2150                                       name,
2151                                       sources,
2152                                       extra_cflags,
2153                                       extra_cuda_cflags,
2154                                       extra_ldflags,
2155                                       extra_include_paths,
2156                                       with_cuda,
2157                                       is_standalone) -> None:
2158    extra_cflags = [flag.strip() for flag in extra_cflags]
2159    extra_cuda_cflags = [flag.strip() for flag in extra_cuda_cflags]
2160    extra_ldflags = [flag.strip() for flag in extra_ldflags]
2161    extra_include_paths = [flag.strip() for flag in extra_include_paths]
2162
2163    # Turn into absolute paths so we can emit them into the ninja build
2164    # file wherever it is.
2165    user_includes = [os.path.abspath(file) for file in extra_include_paths]
2166
2167    # include_paths() gives us the location of torch/extension.h
2168    system_includes = include_paths(with_cuda)
2169    # sysconfig.get_path('include') gives us the location of Python.h
    # Explicitly specify the 'posix_prefix' scheme on non-Windows platforms to work around an error on some
    # macOS installations where the default `get_path` points to a non-existent `/Library/Python/M.m/include` folder
2172    python_include_path = sysconfig.get_path('include', scheme='nt' if IS_WINDOWS else 'posix_prefix')
2173    if python_include_path is not None:
2174        system_includes.append(python_include_path)
2175
2176    common_cflags = []
2177    if not is_standalone:
2178        common_cflags.append(f'-DTORCH_EXTENSION_NAME={name}')
2179        common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H')
2180
    common_cflags += _get_pybind11_abi_build_flags()
2182
2183    # Windows does not understand `-isystem` and quotes flags later.
2184    if IS_WINDOWS:
2185        common_cflags += [f'-I{include}' for include in user_includes + system_includes]
2186    else:
2187        common_cflags += [f'-I{shlex.quote(include)}' for include in user_includes]
2188        common_cflags += [f'-isystem {shlex.quote(include)}' for include in system_includes]
2189
    common_cflags += _get_glibcxx_abi_build_flags()
2191
2192    if IS_WINDOWS:
2193        cflags = common_cflags + COMMON_MSVC_FLAGS + ['/std:c++17'] + extra_cflags
2194        cflags = _nt_quote_args(cflags)
2195    else:
2196        cflags = common_cflags + ['-fPIC', '-std=c++17'] + extra_cflags
2197
2198    if with_cuda and IS_HIP_EXTENSION:
2199        cuda_flags = ['-DWITH_HIP'] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS
2200        cuda_flags += extra_cuda_cflags
2201        cuda_flags += _get_rocm_arch_flags(cuda_flags)
2202    elif with_cuda:
2203        cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags()
2204        if IS_WINDOWS:
2205            for flag in COMMON_MSVC_FLAGS:
2206                cuda_flags = ['-Xcompiler', flag] + cuda_flags
2207            for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
2208                cuda_flags = ['-Xcudafe', '--diag_suppress=' + ignore_warning] + cuda_flags
2209            cuda_flags = cuda_flags + ['-std=c++17']
2210            cuda_flags = _nt_quote_args(cuda_flags)
2211            cuda_flags += _nt_quote_args(extra_cuda_cflags)
2212        else:
2213            cuda_flags += ['--compiler-options', "'-fPIC'"]
2214            cuda_flags += extra_cuda_cflags
2215            if not any(flag.startswith('-std=') for flag in cuda_flags):
2216                cuda_flags.append('-std=c++17')
2217            cc_env = os.getenv("CC")
2218            if cc_env is not None:
2219                cuda_flags = ['-ccbin', cc_env] + cuda_flags
2220    else:
2221        cuda_flags = None
2222
2223    def object_file_path(source_file: str) -> str:
2224        # '/path/to/file.cpp' -> 'file'
2225        file_name = os.path.splitext(os.path.basename(source_file))[0]
2226        if _is_cuda_file(source_file) and with_cuda:
2227            # Use a different object filename in case a C++ and CUDA file have
2228            # the same filename but different extension (.cpp vs. .cu).
2229            target = f'{file_name}.cuda.o'
2230        else:
2231            target = f'{file_name}.o'
2232        return target
2233
2234    objects = [object_file_path(src) for src in sources]
2235    ldflags = ([] if is_standalone else [SHARED_FLAG]) + extra_ldflags
2236
2237    # The darwin linker needs explicit consent to ignore unresolved symbols.
2238    if IS_MACOS:
2239        ldflags.append('-undefined dynamic_lookup')
2240    elif IS_WINDOWS:
2241        ldflags = _nt_quote_args(ldflags)
2242
2243    ext = EXEC_EXT if is_standalone else LIB_EXT
2244    library_target = f'{name}{ext}'
2245
2246    _write_ninja_file(
2247        path=path,
2248        cflags=cflags,
2249        post_cflags=None,
2250        cuda_cflags=cuda_flags,
2251        cuda_post_cflags=None,
2252        cuda_dlink_post_cflags=None,
2253        sources=sources,
2254        objects=objects,
2255        ldflags=ldflags,
2256        library_target=library_target,
2257        with_cuda=with_cuda)
2258
2259
2260def _write_ninja_file(path,
2261                      cflags,
2262                      post_cflags,
2263                      cuda_cflags,
2264                      cuda_post_cflags,
2265                      cuda_dlink_post_cflags,
2266                      sources,
2267                      objects,
2268                      ldflags,
2269                      library_target,
2270                      with_cuda) -> None:
2271    r"""Write a ninja file that does the desired compiling and linking.
2272
2273    `path`: Where to write this file
2274    `cflags`: list of flags to pass to $cxx. Can be None.
2275    `post_cflags`: list of flags to append to the $cxx invocation. Can be None.
    `cuda_cflags`: list of flags to pass to $nvcc. Can be None.
    `cuda_post_cflags`: list of flags to append to the $nvcc invocation. Can be None.
    `cuda_dlink_post_cflags`: list of flags to pass to the nvcc device-link step. Can be None.
    `sources`: list of paths to source files
2279    `objects`: list of desired paths to objects, one per source.
2280    `ldflags`: list of flags to pass to linker. Can be None.
2281    `library_target`: Name of the output library. Can be None; in that case,
2282                      we do no linking.
2283    `with_cuda`: If we should be compiling with CUDA.
2284    """
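    # Rough shape of the emitted file (illustrative sketch):
    #   ninja_required_version = 1.3
    #   cxx = c++
    #   cflags = ...
    #
    #   rule compile
    #     command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
    #
    #   build my_ext.o: compile /abs/path/my_ext.cpp
    #   build my_ext.so: link my_ext.o
    #   default my_ext.so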
2285    def sanitize_flags(flags):
2286        if flags is None:
2287            return []
2288        else:
2289            return [flag.strip() for flag in flags]
2290
2291    cflags = sanitize_flags(cflags)
2292    post_cflags = sanitize_flags(post_cflags)
2293    cuda_cflags = sanitize_flags(cuda_cflags)
2294    cuda_post_cflags = sanitize_flags(cuda_post_cflags)
2295    cuda_dlink_post_cflags = sanitize_flags(cuda_dlink_post_cflags)
2296    ldflags = sanitize_flags(ldflags)
2297
2298    # Sanity checks...
2299    assert len(sources) == len(objects)
2300    assert len(sources) > 0
2301
2302    compiler = get_cxx_compiler()
2303
2304    # Version 1.3 is required for the `deps` directive.
2305    config = ['ninja_required_version = 1.3']
2306    config.append(f'cxx = {compiler}')
2307    if with_cuda or cuda_dlink_post_cflags:
2308        if "PYTORCH_NVCC" in os.environ:
            nvcc = os.getenv("PYTORCH_NVCC")    # user can override the nvcc compiler here, e.g. to wrap it with ccache
2310        else:
2311            if IS_HIP_EXTENSION:
2312                nvcc = _join_rocm_home('bin', 'hipcc')
2313            else:
2314                nvcc = _join_cuda_home('bin', 'nvcc')
2315        config.append(f'nvcc = {nvcc}')
2316
2317    if IS_HIP_EXTENSION:
2318        post_cflags = COMMON_HIP_FLAGS + post_cflags
2319    flags = [f'cflags = {" ".join(cflags)}']
2320    flags.append(f'post_cflags = {" ".join(post_cflags)}')
2321    if with_cuda:
2322        flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}')
2323        flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}')
2324    flags.append(f'cuda_dlink_post_cflags = {" ".join(cuda_dlink_post_cflags)}')
2325    flags.append(f'ldflags = {" ".join(ldflags)}')
2326
2327    # Turn into absolute paths so we can emit them into the ninja build
2328    # file wherever it is.
2329    sources = [os.path.abspath(file) for file in sources]
2330
2331    # See https://ninja-build.org/build.ninja.html for reference.
2332    compile_rule = ['rule compile']
2333    if IS_WINDOWS:
2334        compile_rule.append(
2335            '  command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags')
2336        compile_rule.append('  deps = msvc')
2337    else:
2338        compile_rule.append(
2339            '  command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags')
2340        compile_rule.append('  depfile = $out.d')
2341        compile_rule.append('  deps = gcc')
2342
2343    if with_cuda:
2344        cuda_compile_rule = ['rule cuda_compile']
2345        nvcc_gendeps = ''
        # --generate-dependencies-with-compile is not supported by ROCm.
        # It is also not supported by sccache (which may increase build time),
        # so it can be disabled via TORCH_EXTENSION_SKIP_NVCC_GEN_DEPENDENCIES=1.
2348        if torch.version.cuda is not None and os.getenv('TORCH_EXTENSION_SKIP_NVCC_GEN_DEPENDENCIES', '0') != '1':
2349            cuda_compile_rule.append('  depfile = $out.d')
2350            cuda_compile_rule.append('  deps = gcc')
2351            # Note: non-system deps with nvcc are only supported
2352            # on Linux so use --generate-dependencies-with-compile
2353            # to make this work on Windows too.
2354            nvcc_gendeps = '--generate-dependencies-with-compile --dependency-output $out.d'
2355        cuda_compile_rule.append(
2356            f'  command = $nvcc {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags')
2357
2358    # Emit one build rule per source to enable incremental build.
2359    build = []
2360    for source_file, object_file in zip(sources, objects):
2361        is_cuda_source = _is_cuda_file(source_file) and with_cuda
2362        rule = 'cuda_compile' if is_cuda_source else 'compile'
2363        if IS_WINDOWS:
2364            source_file = source_file.replace(':', '$:')
2365            object_file = object_file.replace(':', '$:')
2366        source_file = source_file.replace(" ", "$ ")
2367        object_file = object_file.replace(" ", "$ ")
2368        build.append(f'build {object_file}: {rule} {source_file}')
2369
2370    if cuda_dlink_post_cflags:
2371        devlink_out = os.path.join(os.path.dirname(objects[0]), 'dlink.o')
2372        devlink_rule = ['rule cuda_devlink']
2373        devlink_rule.append('  command = $nvcc $in -o $out $cuda_dlink_post_cflags')
2374        devlink = [f'build {devlink_out}: cuda_devlink {" ".join(objects)}']
2375        objects += [devlink_out]
2376    else:
2377        devlink_rule, devlink = [], []
2378
2379    if library_target is not None:
2380        link_rule = ['rule link']
2381        if IS_WINDOWS:
2382            cl_paths = subprocess.check_output(['where',
2383                                                'cl']).decode(*SUBPROCESS_DECODE_ARGS).split('\r\n')
2384            if len(cl_paths) >= 1:
2385                cl_path = os.path.dirname(cl_paths[0]).replace(':', '$:')
2386            else:
2387                raise RuntimeError("MSVC is required to load C++ extensions")
2388            link_rule.append(f'  command = "{cl_path}/link.exe" $in /nologo $ldflags /out:$out')
2389        else:
2390            link_rule.append('  command = $cxx $in $ldflags -o $out')
2391
2392        link = [f'build {library_target}: link {" ".join(objects)}']
2393
2394        default = [f'default {library_target}']
2395    else:
2396        link_rule, link, default = [], [], []
2397
2398    # 'Blocks' should be separated by newlines, for visual benefit.
2399    blocks = [config, flags, compile_rule]
2400    if with_cuda:
2401        blocks.append(cuda_compile_rule)  # type: ignore[possibly-undefined]
2402    blocks += [devlink_rule, link_rule, build, devlink, link, default]
2403    content = "\n\n".join("\n".join(b) for b in blocks)
    # Ninja requires a newline at the end of the .ninja file
2405    content += "\n"
2406    _maybe_write(path, content)
2407
2408def _join_cuda_home(*paths) -> str:
2409    """
    Join paths with CUDA_HOME, or raise an error if CUDA_HOME is not set.
2411
2412    This is basically a lazy way of raising an error for missing $CUDA_HOME
2413    only once we need to get any CUDA-specific path.
2414    """
2415    if CUDA_HOME is None:
2416        raise OSError('CUDA_HOME environment variable is not set. '
2417                      'Please set it to your CUDA install root.')
2418    return os.path.join(CUDA_HOME, *paths)
2419
2420
2421def _is_cuda_file(path: str) -> bool:
2422    valid_ext = ['.cu', '.cuh']
2423    if IS_HIP_EXTENSION:
2424        valid_ext.append('.hip')
2425    return os.path.splitext(path)[1] in valid_ext
2426