xref: /aosp_15_r20/external/pytorch/torch/cuda/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2r"""
3This package adds support for CUDA tensor types.
4
5It implements the same functions as CPU tensors, but they utilize
6GPUs for computation.
7
8It is lazily initialized, so you can always import it, and use
9:func:`is_available()` to determine if your system supports CUDA.
10
11:ref:`cuda-semantics` has more details about working with CUDA.
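
Example (a minimal, illustrative sketch; the GPU branch assumes at least one
visible CUDA device)::

    >>> import torch
    >>> if torch.cuda.is_available():
    ...     x = torch.ones(3, device="cuda")  # first CUDA use initializes the driver lazily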
12"""
13
14import importlib
15import os
16import threading
17import traceback
18import warnings
19from functools import lru_cache
20from typing import Any, Callable, cast, List, Optional, Tuple, Union
21
22import torch
23import torch._C
24from torch import device as _device
25from torch._utils import _dummy_type, _LazySeedTracker, classproperty
26from torch.types import Device
27
28from . import gds
29from ._utils import _get_device_index
30from .graphs import (
31    CUDAGraph,
32    graph,
33    graph_pool_handle,
34    is_current_stream_capturing,
35    make_graphed_callables,
36)
37from .streams import Event, ExternalStream, Stream
38
39
40try:
41    from torch._C import _cudart  # type: ignore[attr-defined]
42except ImportError:
43    _cudart = None
44
45_initialized = False
46_tls = threading.local()
47_initialization_lock = threading.Lock()
48_queued_calls: List[
49    Tuple[Callable[[], None], List[str]]
50] = []  # don't invoke these until initialization occurs
51_is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
52_device_t = Union[_device, str, int, None]
53
54_HAS_PYNVML = False
55_PYNVML_ERR = None
56try:
57    from torch import version as _version
58
59    try:
60        if not _version.hip:
61            import pynvml  # type: ignore[import]
62        else:
63            import amdsmi  # type: ignore[import]
64
65        _HAS_PYNVML = True
66    except ModuleNotFoundError:
67        pass
68    finally:
69        del _version
70except ImportError as err:
71    _PYNVML_ERR = err  # sometimes a lib is installed but the import fails for some other reason, so we log the error for later
72
73_lazy_seed_tracker = _LazySeedTracker()
74
75# Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA
76if hasattr(torch._C, "_CudaDeviceProperties"):
77    _CudaDeviceProperties = torch._C._CudaDeviceProperties
78else:
79    _CudaDeviceProperties = _dummy_type("_CudaDeviceProperties")  # type: ignore[assignment, misc]
80
81if hasattr(torch._C, "_cuda_exchangeDevice"):
82    _exchange_device = torch._C._cuda_exchangeDevice
83else:
84
85    def _exchange_device(device: int) -> int:
86        if device < 0:
87            return -1
88        raise RuntimeError("PyTorch was compiled without CUDA support")
89
90
91if hasattr(torch._C, "_cuda_maybeExchangeDevice"):
92    _maybe_exchange_device = torch._C._cuda_maybeExchangeDevice
93else:
94
95    def _maybe_exchange_device(device: int) -> int:
96        if device < 0:
97            return -1
98        raise RuntimeError("PyTorch was compiled without CUDA support")
99
100
101has_half: bool = True
102has_magma: bool = torch._C._has_magma
103
104default_generators: Tuple[torch._C.Generator] = ()  # type: ignore[assignment]
105
106
107def _is_compiled() -> bool:
108    r"""Return true if compile with CUDA support."""
109    return hasattr(torch._C, "_cuda_getDeviceCount")
110
111
112def _nvml_based_avail() -> bool:
113    return os.getenv("PYTORCH_NVML_BASED_CUDA_CHECK") == "1"
114
115
116def is_available() -> bool:
117    r"""Return a bool indicating if CUDA is currently available."""
118    if not _is_compiled():
119        return False
120    if _nvml_based_avail():
121        # The user has set an env variable to request this availability check that attempts to avoid fork poisoning by
122        # using NVML at the cost of a weaker CUDA availability assessment. Note that if NVML discovery/initialization
123        # fails, this assessment falls back to the default CUDA Runtime API assessment (`cudaGetDeviceCount`)
124        return device_count() > 0
125    else:
126        # The default availability inspection never throws and returns 0 if the driver is missing or can't
127        # be initialized. This uses the CUDA Runtime API `cudaGetDeviceCount` which in turn initializes the CUDA Driver
128        # API via `cuInit`
129        return torch._C._cuda_getDeviceCount() > 0
130
131
132def is_bf16_supported(including_emulation: bool = True):
133    r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
134    # On ROCm, bfloat16 is supported on all AMD GPU archs, so return True
135    # without any ROCM_VERSION check.
136    if torch.version.hip:
137        return True
138
139    # If CUDA is not available, then it does not support bf16 either
140    if not is_available():
141        return False
142
143    device = torch.cuda.current_device()
144
145    # Check for CUDA version and device compute capability.
146    # This is a fast way to check for it.
147    cuda_version = torch.version.cuda
148    if (
149        cuda_version is not None
150        and int(cuda_version.split(".")[0]) >= 11
151        and torch.cuda.get_device_properties(device).major >= 8
152    ):
153        return True
154
155    if not including_emulation:
156        return False
157
158    # Finally try to create a bfloat16 device.
159    return _check_bf16_tensor_supported(device)
160
161
162@lru_cache(maxsize=16)
163def _check_bf16_tensor_supported(device: _device_t):
164    try:
165        torch.tensor([1.0], dtype=torch.bfloat16, device=device)
166        return True
167    except Exception:
168        return False
169
170
171def _sleep(cycles):
172    torch._C._cuda_sleep(cycles)
173
174
175def _extract_arch_version(arch_string: str):
176    """Extracts the architecture string from a CUDA version"""
177    base = arch_string.split("_")[1]
178    if base.endswith("a"):
179        base = base[:-1]
180    return int(base)
181
182
183def _check_capability():
184    incorrect_binary_warn = """
185    Found GPU%d %s which requires CUDA_VERSION >= %d to
186     work properly, but your PyTorch was compiled
187     with CUDA_VERSION %d. Please install the correct PyTorch binary
188     using instructions from https://pytorch.org
189    """
190
191    old_gpu_warn = """
192    Found GPU%d %s which is of cuda capability %d.%d.
193    PyTorch no longer supports this GPU because it is too old.
194    The minimum cuda capability supported by this library is %d.%d.
195    """
196
197    if torch.version.cuda is not None:  # on ROCm we don't want this check
198        CUDA_VERSION = torch._C._cuda_getCompiledVersion()
199        for d in range(device_count()):
200            capability = get_device_capability(d)
201            major = capability[0]
202            minor = capability[1]
203            name = get_device_name(d)
204            current_arch = major * 10 + minor
205            min_arch = min(
206                (_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()),
207                default=35,
208            )
209            if current_arch < min_arch:
210                warnings.warn(
211                    old_gpu_warn
212                    % (d, name, major, minor, min_arch // 10, min_arch % 10)
213                )
214
215
216def _check_cubins():
217    incompatible_device_warn = """
218{} with CUDA capability sm_{} is not compatible with the current PyTorch installation.
219The current PyTorch install supports CUDA capabilities {}.
220If you want to use the {} GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/
221"""
222    if torch.version.cuda is None:  # on ROCm we don't want this check
223        return
224    arch_list = get_arch_list()
225    if len(arch_list) == 0:
226        return
227    supported_sm = [_extract_arch_version(arch) for arch in arch_list if "sm_" in arch]
228    for idx in range(device_count()):
229        cap_major, cap_minor = get_device_capability(idx)
230        # NVIDIA GPU compute architectures are backward compatible within major version
231        supported = any(sm // 10 == cap_major for sm in supported_sm)
232        if not supported:
233            device_name = get_device_name(idx)
234            capability = cap_major * 10 + cap_minor
235            warnings.warn(
236                incompatible_device_warn.format(
237                    device_name, capability, " ".join(arch_list), device_name
238                )
239            )
240
241
242def is_initialized():
243    r"""Return whether PyTorch's CUDA state has been initialized."""
244    return _initialized and not _is_in_bad_fork()
245
246
247def _lazy_call(callable, **kwargs):
248    if is_initialized():
249        callable()
250    else:
251        # TODO(torch_deploy): this accesses linecache, which attempts to read the
252        # file system to get traceback info. Patch linecache or do something
253        # else here if this ends up being important.
254        global _lazy_seed_tracker
255        if kwargs.get("seed_all", False):
256            _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
257        elif kwargs.get("seed", False):
258            _lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
259        else:
260            # Don't store the actual traceback to avoid memory cycle
261            _queued_calls.append((callable, traceback.format_stack()))
262
263
264_lazy_call(_check_capability)
265_lazy_call(_check_cubins)
266
267
268class DeferredCudaCallError(Exception):
269    pass
270
271
272OutOfMemoryError = torch._C.OutOfMemoryError
273
274
275def init():
276    r"""Initialize PyTorch's CUDA state.
277
278    You may need to call this explicitly if you are interacting with
279    PyTorch via its C API, as Python bindings for CUDA functionality
280    will not be available until this initialization takes place.
281    Ordinary users should not need this, as all of PyTorch's CUDA methods
282    automatically initialize CUDA state on-demand.
283
284    Does nothing if the CUDA state is already initialized.
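
    Example (illustrative; the call is idempotent, so repeating it is harmless)::

        >>> torch.cuda.init()
        >>> ready = torch.cuda.is_initialized()  # True on a CUDA-enabled build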
285    """
286    _lazy_init()
287
288
289def _lazy_init():
290    global _initialized, _queued_calls
291    if is_initialized() or hasattr(_tls, "is_initializing"):
292        return
293    with _initialization_lock:
294        # We be double-checked locking, boys!  This is OK because
295        # the above test was GIL protected anyway.  The inner test
296        # is for when a thread blocked on some other thread which was
297        # doing the initialization; when they get the lock, they will
298        # find there is nothing left to do.
299        if is_initialized():
300            return
301        # It is important to prevent other threads from entering _lazy_init
302        # immediately, while we are still guaranteed to have the GIL, because some
303        # of the C calls we make below will release the GIL
304        if _is_in_bad_fork():
305            raise RuntimeError(
306                "Cannot re-initialize CUDA in forked subprocess. To use CUDA with "
307                "multiprocessing, you must use the 'spawn' start method"
308            )
309        if not hasattr(torch._C, "_cuda_getDeviceCount"):
310            raise AssertionError("Torch not compiled with CUDA enabled")
311        if _cudart is None:
312            raise AssertionError(
313                "libcudart functions unavailable. It looks like you have a broken build?"
314            )
315        # This function throws if there's a driver initialization error, no GPUs
316        # are found or any other error occurs
317        if "CUDA_MODULE_LOADING" not in os.environ:
318            os.environ["CUDA_MODULE_LOADING"] = "LAZY"
319        torch._C._cuda_init()
320        # Some of the queued calls may reentrantly call _lazy_init();
321        # we need to just return without initializing in that case.
322        # However, we must not let any *other* threads in!
323        _tls.is_initializing = True
324
325        for calls in _lazy_seed_tracker.get_calls():
326            if calls:
327                _queued_calls.append(calls)
328
329        try:
330            for queued_call, orig_traceback in _queued_calls:
331                try:
332                    queued_call()
333                except Exception as e:
334                    msg = (
335                        f"CUDA call failed lazily at initialization with error: {str(e)}\n\n"
336                        f"CUDA call was originally invoked at:\n\n{''.join(orig_traceback)}"
337                    )
338                    raise DeferredCudaCallError(msg) from e
339        finally:
340            delattr(_tls, "is_initializing")
341        _initialized = True
342
343
344def cudart():
345    r"""Retrieves the CUDA runtime API module.
346
347
348    This function initializes the CUDA runtime environment if it is not already
349    initialized and returns the CUDA runtime API module (_cudart). The CUDA
350    runtime API module provides access to various CUDA runtime functions.
351
352    Args:
353        ``None``
354
355    Returns:
356        module: The CUDA runtime API module (_cudart).
357
358    Raises:
359        RuntimeError: If CUDA cannot be re-initialized in a forked subprocess.
360        AssertionError: If PyTorch is not compiled with CUDA support or if libcudart functions are unavailable.
361
362    Example of CUDA operations with profiling:
363        >>> import torch
364        >>> from torch.cuda import cudart, check_error
365        >>> import os
366        >>>
367        >>> os.environ['CUDA_PROFILE'] = '1'
368        >>>
369        >>> def perform_cuda_operations_with_streams():
370        >>>     stream = torch.cuda.Stream()
371        >>>     with torch.cuda.stream(stream):
372        >>>         x = torch.randn(100, 100, device='cuda')
373        >>>         y = torch.randn(100, 100, device='cuda')
374        >>>         z = torch.mul(x, y)
375        >>>     return z
376        >>>
377        >>> torch.cuda.synchronize()
378        >>> print("====== Start nsys profiling ======")
379        >>> check_error(cudart().cudaProfilerStart())
380        >>> with torch.autograd.profiler.emit_nvtx():
381        >>>     result = perform_cuda_operations_with_streams()
382        >>>     print("CUDA operations completed.")
383        >>> check_error(torch.cuda.cudart().cudaProfilerStop())
384        >>> print("====== End nsys profiling ======")
385
386    To run this example and save the profiling information, execute:
387        $ nvprof --profile-from-start off --csv --print-summary -o trace_name.prof -f -- python cudart_test.py
388
389    This command profiles the CUDA operations in the provided script and saves
390    the profiling information to a file named `trace_name.prof`.
391    The `--profile-from-start off` option ensures that profiling starts only
392    after the `cudaProfilerStart` call in the script.
393    The `--csv` and `--print-summary` options format the profiling output as a
394    CSV file and print a summary, respectively.
395    The `-o` option specifies the output file name, and the `-f` option forces the
396    overwrite of the output file if it already exists.
397    """
398    _lazy_init()
399    return _cudart
400
401
402class cudaStatus:
403    SUCCESS: int = 0
404    ERROR_NOT_READY: int = 34
405
406
407class CudaError(RuntimeError):
408    def __init__(self, code: int) -> None:
409        msg = _cudart.cudaGetErrorString(_cudart.cudaError(code))
410        super().__init__(f"{msg} ({code})")
411
412
413def check_error(res: int) -> None:
414    if res != _cudart.cudaError.success:
415        raise CudaError(res)
416
417
418class _DeviceGuard:
419    def __init__(self, index: int):
420        self.idx = index
421        self.prev_idx = -1
422
423    def __enter__(self):
424        self.prev_idx = torch.cuda._exchange_device(self.idx)
425
426    def __exit__(self, type: Any, value: Any, traceback: Any):
427        self.idx = torch.cuda._maybe_exchange_device(self.prev_idx)
428        return False
429
430
431class device:
432    r"""Context-manager that changes the selected device.
433
434    Args:
435        device (torch.device or int): device index to select. It's a no-op if
436            this argument is a negative integer or ``None``.
437    """
438
439    def __init__(self, device: Any):
440        self.idx = _get_device_index(device, optional=True)
441        self.prev_idx = -1
442
443    def __enter__(self):
444        self.prev_idx = torch.cuda._exchange_device(self.idx)
445
446    def __exit__(self, type: Any, value: Any, traceback: Any):
447        self.idx = torch.cuda._maybe_exchange_device(self.prev_idx)
448        return False
449
450
451class device_of(device):
452    r"""Context-manager that changes the current device to that of given object.
453
454    You can use both tensors and storages as arguments. If a given object is
455    not allocated on a GPU, this is a no-op.
456
457    Args:
458        obj (Tensor or Storage): object allocated on the selected device.
459    """
460
461    def __init__(self, obj):
462        idx = obj.get_device() if obj.is_cuda else -1
463        super().__init__(idx)
464
465
466def set_device(device: _device_t) -> None:
467    r"""Set the current device.
468
469    Usage of this function is discouraged in favor of :any:`device`. In most
470    cases it's better to use ``CUDA_VISIBLE_DEVICES`` environmental variable.
471
472    Args:
473        device (torch.device or int): selected device. This function is a no-op
474            if this argument is negative.
475    """
476    device = _get_device_index(device)
477    if device >= 0:
478        torch._C._cuda_setDevice(device)
479
480
481def get_device_name(device: Optional[_device_t] = None) -> str:
482    r"""Get the name of a device.
483
484    Args:
485        device (torch.device or int or str, optional): device for which to return the
486            name. This function is a no-op if this argument is a negative
487            integer. It uses the current device, given by :func:`~torch.cuda.current_device`,
488            if :attr:`device` is ``None`` (default).
489
490    Returns:
491        str: the name of the device
492    """
493    return get_device_properties(device).name
494
495
496def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]:
497    r"""Get the cuda capability of a device.
498
499    Args:
500        device (torch.device or int or str, optional): device for which to return the
501            device capability. This function is a no-op if this argument is
502            a negative integer. It uses the current device, given by
503            :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
504            (default).
505
506    Returns:
507        tuple(int, int): the major and minor cuda capability of the device
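
    Example (illustrative; the exact values depend on the GPU, e.g. ``(8, 0)`` on an A100)::

        >>> major, minor = torch.cuda.get_device_capability()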
508    """
509    prop = get_device_properties(device)
510    return prop.major, prop.minor
511
512
513def get_device_properties(device: _device_t) -> _CudaDeviceProperties:
514    r"""Get the properties of a device.
515
516    Args:
517        device (torch.device or int or str): device for which to return the
518            properties of the device.
519
520    Returns:
521        _CudaDeviceProperties: the properties of the device
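
    Example (illustrative; attribute values depend on the installed GPU)::

        >>> props = torch.cuda.get_device_properties(0)
        >>> name, mem = props.name, props.total_memory  # device name and total memory in bytes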
522    """
523    _lazy_init()  # will define _get_device_properties
524    device = _get_device_index(device, optional=True)
525    if device < 0 or device >= device_count():
526        raise AssertionError("Invalid device id")
527    return _get_device_properties(device)  # type: ignore[name-defined]
528
529
530def can_device_access_peer(device: _device_t, peer_device: _device_t) -> bool:
531    r"""Check if peer access between two devices is possible."""
532    _lazy_init()
533    device = _get_device_index(device, optional=True)
534    peer_device = _get_device_index(peer_device)
535    if device < 0 or device >= device_count():
536        raise AssertionError("Invalid device id")
537    if peer_device < 0 or peer_device >= device_count():
538        raise AssertionError("Invalid peer device id")
539    return torch._C._cuda_canDeviceAccessPeer(device, peer_device)
540
541
542class StreamContext:
543    r"""Context-manager that selects a given stream.
544
545    All CUDA kernels queued within its context will be enqueued on a selected
546    stream.
547
548    Args:
549        stream (Stream): selected stream. This manager is a no-op if it's
550            ``None``.
551    .. note:: Streams are per-device.
552    """
553    cur_stream: Optional["torch.cuda.Stream"]
554
555    def __init__(self, stream: Optional["torch.cuda.Stream"]):
556        self.stream = stream
557        self.idx = _get_device_index(None, True)
558        if not torch.jit.is_scripting():
559            if self.idx is None:
560                self.idx = -1
561
562        self.src_prev_stream = (
563            None if not torch.jit.is_scripting() else torch.cuda.default_stream(None)
564        )
565        self.dst_prev_stream = (
566            None if not torch.jit.is_scripting() else torch.cuda.default_stream(None)
567        )
568
569    def __enter__(self):
570        # Local cur_stream variable for type refinement
571        cur_stream = self.stream
572        # Return if stream is None or CUDA device not available
573        if cur_stream is None or self.idx == -1:
574            return
575        self.src_prev_stream = torch.cuda.current_stream(None)
576
577        # If the stream is not on the current device, then
578        # set the current stream on the device
579        if self.src_prev_stream.device != cur_stream.device:
580            with device(cur_stream.device):
581                self.dst_prev_stream = torch.cuda.current_stream(cur_stream.device)
582        torch.cuda.set_stream(cur_stream)
583
584    def __exit__(self, type: Any, value: Any, traceback: Any):
585        # Local cur_stream variable for type refinement
586        cur_stream = self.stream
587        # If stream is None or no CUDA device available, return
588        if cur_stream is None or self.idx == -1:
589            return
590
591        # Reset the stream on the original device
592        # and destination device
593        if self.src_prev_stream.device != cur_stream.device:  # type: ignore[union-attr]
594            torch.cuda.set_stream(self.dst_prev_stream)  # type: ignore[arg-type]
595        torch.cuda.set_stream(self.src_prev_stream)  # type: ignore[arg-type]
596
597
598def stream(stream: Optional["torch.cuda.Stream"]) -> StreamContext:
599    r"""Wrap around the Context-manager StreamContext that selects a given stream.
600
601    Arguments:
602        stream (Stream): selected stream. This manager is a no-op if it's
603            ``None``.
604    .. note:: In eager mode ``stream`` is a :class:`Stream` object, while in JIT it is
605        an object of the custom class ``torch.classes.cuda.Stream``.
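
    Example (illustrative sketch; assumes ``a`` and ``b`` are CUDA tensors)::

        >>> s = torch.cuda.Stream()
        >>> with torch.cuda.stream(s):
        ...     c = torch.mm(a, b)  # enqueued on ``s`` instead of the current stream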
606    """
607    return StreamContext(stream)
608
609
610def _set_stream_by_id(stream_id, device_index, device_type):
611    r"""set stream specified by the stream id, device index and
612        device type
613
614    Args: stream_id (int): stream id in stream pool
615          device_index (int): device index in topo
616          device_type (int): enum device type
617    """
618    torch._C._cuda_setStream(
619        stream_id=stream_id,
620        device_index=device_index,
621        device_type=device_type,
622    )
623
624
625def set_stream(stream: Stream):
626    r"""Set the current stream.This is a wrapper API to set the stream.
627        Usage of this function is discouraged in favor of the ``stream``
628        context manager.
629
630    Args:
631        stream (Stream): selected stream. This function is a no-op
632            if this argument is ``None``.
633    """
634    if stream is None:
635        return
636    _set_stream_by_id(
637        stream_id=stream.stream_id,
638        device_index=stream.device_index,
639        device_type=stream.device_type,
640    )
641
642
643def _parse_visible_devices() -> Union[List[int], List[str]]:
644    r"""Parse CUDA_VISIBLE_DEVICES environment variable."""
645    var = os.getenv("CUDA_VISIBLE_DEVICES")
646
647    if torch.version.hip:
648        hip_devices = os.getenv("HIP_VISIBLE_DEVICES")
649        if hip_devices is not None:
650            var = hip_devices
651
652    if var is None:
653        return list(range(64))
654
655    def _strtoul(s: str) -> int:
656        """Return -1 or positive integer sequence string starts with."""
657        if not s:
658            return -1
659        for idx, c in enumerate(s):
660            if not (c.isdigit() or (idx == 0 and c in "+-")):
661                break
662            if idx + 1 == len(s):
663                idx += 1
664        return int(s[:idx]) if idx > 0 else -1
665
666    def parse_list_with_prefix(lst: str, prefix: str) -> List[str]:
667        rcs: List[str] = []
668        for elem in lst.split(","):
669            # Repeated id results in empty set
670            if elem in rcs:
671                return cast(List[str], [])
672            # Anything that does not start with the prefix ends the parsing
673            if not elem.startswith(prefix):
674                break
675            rcs.append(elem)
676        return rcs
677
678    if var.startswith("GPU-"):
679        return parse_list_with_prefix(var, "GPU-")
680    if var.startswith("MIG-"):
681        return parse_list_with_prefix(var, "MIG-")
682    # CUDA_VISIBLE_DEVICES uses something like strtoul
683    # which makes `1gpu2,2ampere` equivalent to `1,2`
684    rc: List[int] = []
685    for elem in var.split(","):
686        x = _strtoul(elem.strip())
687        # Repeated ordinal results in empty set
688        if x in rc:
689            return cast(List[int], [])
690        # Negative value aborts the sequence
691        if x < 0:
692            break
693        rc.append(x)
694    return rc
695
696
697def _raw_device_count_amdsmi() -> int:
698    if not _HAS_PYNVML:  # If amdsmi is not available
699        return -1
700    try:
701        amdsmi.amdsmi_init()
702    except amdsmi.AmdSmiException as e:
703        warnings.warn(f"Can't initialize amdsmi - Error code: {e.err_code}")
704        return -1
705    socket_handles = amdsmi.amdsmi_get_processor_handles()
706    return len(socket_handles)
707
708
709def _raw_device_count_nvml() -> int:
710    r"""Return number of devices as reported by NVML or negative value if NVML discovery/initialization failed."""
711    from ctypes import byref, c_int, CDLL
712
713    nvml_h = CDLL("libnvidia-ml.so.1")
714    rc = nvml_h.nvmlInit()
715    if rc != 0:
716        warnings.warn("Can't initialize NVML")
717        return -1
718    dev_count = c_int(-1)
719    rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
720    if rc != 0:
721        warnings.warn("Can't get nvml device count")
722        return -1
723    del nvml_h
724    return dev_count.value
725
726
727def _raw_device_uuid_amdsmi() -> Optional[List[str]]:
728    from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer
729
730    if not _HAS_PYNVML:  # If amdsmi is not available
731        return None
732    try:
733        amdsmi.amdsmi_init()
734    except amdsmi.AmdSmiException:
735        warnings.warn("Can't initialize amdsmi")
736        return None
737    try:
738        socket_handles = amdsmi.amdsmi_get_processor_handles()
739        dev_count = len(socket_handles)
740    except amdsmi.AmdSmiException:
741        warnings.warn("Can't get amdsmi device count")
742        return None
743    uuids: List[str] = []
744    for idx in range(dev_count):
745        try:
746            handler = amdsmi.amdsmi_get_processor_handles()[idx]
747        except amdsmi.AmdSmiException:
748            warnings.warn("Cannot get amd device handler")
749            return None
750        try:
751            uuid = amdsmi.amdsmi_get_gpu_device_uuid(handler)
752        except amdsmi.AmdSmiException:
753            warnings.warn("Cannot get uuid for amd device")
754            return None
755        uuids.append(str(uuid))
756    return uuids
757
758
759def _raw_device_uuid_nvml() -> Optional[List[str]]:
760    r"""Return list of device UUID as reported by NVML or None if NVM discovery/initialization failed."""
761    from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer
762
763    nvml_h = CDLL("libnvidia-ml.so.1")
764    rc = nvml_h.nvmlInit()
765    if rc != 0:
766        warnings.warn("Can't initialize NVML")
767        return None
768    dev_count = c_int(-1)
769    rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
770    if rc != 0:
771        warnings.warn("Can't get nvml device count")
772        return None
773    uuids: List[str] = []
774    for idx in range(dev_count.value):
775        dev_id = c_void_p()
776        rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
777        if rc != 0:
778            warnings.warn("Can't get device handle")
779            return None
780        buf_len = 96
781        buf = create_string_buffer(buf_len)
782        rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len)
783        if rc != 0:
784            warnings.warn("Can't get device UUID")
785            return None
786        uuids.append(buf.raw.decode("ascii").strip("\0"))
787    del nvml_h
788    return uuids
789
790
791def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List[int]:
792    r"""Given the set of partial uuids and list of known uuids builds a set of ordinals excluding ambiguous partials IDs."""
793
794    def uuid_to_ordinal(candidate: str, uuids: List[str]) -> int:
795        best_match = -1
796        for idx, uuid in enumerate(uuids):
797            if not uuid.startswith(candidate):
798                continue
799            # Ambiguous candidate
800            if best_match != -1:
801                return -1
802            best_match = idx
803        return best_match
804
805    rc: List[int] = []
806    for candidate in candidates:
807        idx = uuid_to_ordinal(candidate, uuids)
808        # First invalid ordinal stops parsing
809        if idx < 0:
810            break
811        # Duplicates result in empty set
812        if idx in rc:
813            return cast(List[int], [])
814        rc.append(idx)
815    return rc
816
817
818def _device_count_amdsmi() -> int:
819    visible_devices = _parse_visible_devices()
820    if not visible_devices:
821        return 0
822    try:
823        if type(visible_devices[0]) is str:
824            return -1
825        else:
826            raw_cnt = _raw_device_count_amdsmi()
827            if raw_cnt <= 0:
828                return raw_cnt
829            # Trim the list up to the maximum available device
830            for idx, val in enumerate(visible_devices):
831                if cast(int, val) >= raw_cnt:
832                    return idx
833    except OSError:
834        return -1
835    except AttributeError:
836        return -1
837    return len(visible_devices)
838
839
840def _device_count_nvml() -> int:
841    r"""Return number of devices as reported by NVML taking CUDA_VISIBLE_DEVICES into account.
842
843    Negative value is returned if NVML discovery or initialization has failed.
844    """
845    visible_devices = _parse_visible_devices()
846    if not visible_devices:
847        return 0
848    try:
849        if type(visible_devices[0]) is str:
850            # Skip MIG parsing
851            if visible_devices[0].startswith("MIG-"):
852                return -1
853            uuids = _raw_device_uuid_nvml()
854            if uuids is None:
855                return -1
856            visible_devices = _transform_uuid_to_ordinals(
857                cast(List[str], visible_devices), uuids
858            )
859        else:
860            raw_cnt = _raw_device_count_nvml()
861            if raw_cnt <= 0:
862                return raw_cnt
863            # Trim the list up to the maximum available device
864            for idx, val in enumerate(visible_devices):
865                if cast(int, val) >= raw_cnt:
866                    return idx
867    except OSError:
868        return -1
869    except AttributeError:
870        return -1
871    return len(visible_devices)
872
873
874def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int:
875    r"""Return the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account."""
876    idx = _get_device_index(device, optional=True)
877    visible_devices = _parse_visible_devices()
878    if type(visible_devices[0]) is str:
879        uuids = _raw_device_uuid_nvml()
880        if uuids is None:
881            raise RuntimeError("Can't get device UUIDs")
882        visible_devices = _transform_uuid_to_ordinals(
883            cast(List[str], visible_devices), uuids
884        )
885    visible_devices = cast(List[int], visible_devices)
886    if idx < 0 or idx >= len(visible_devices):
887        raise RuntimeError(
888            f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})"
889        )
890    return visible_devices[idx]
891
892
893_cached_device_count: Optional[int] = None
894
895
896def device_count() -> int:
897    r"""Return the number of GPUs available."""
898    global _cached_device_count
899    if not _is_compiled():
900        return 0
901    if _cached_device_count is not None:
902        return _cached_device_count
903    # bypass _device_count_nvml() if rocm (not supported)
904    nvml_count = _device_count_amdsmi() if torch.version.hip else _device_count_nvml()
905    r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count
906    # NB: Do not cache the device count prior to CUDA initialization, because
907    # the number of devices can change due to changes to CUDA_VISIBLE_DEVICES
908    # setting prior to CUDA initialization.
909    if _initialized:
910        _cached_device_count = r
911    return r
912
913
914def get_arch_list() -> List[str]:
915    r"""Return list CUDA architectures this library was compiled for."""
916    if not is_available():
917        return []
918    arch_flags = torch._C._cuda_getArchFlags()
919    if arch_flags is None:
920        return []
921    return arch_flags.split()
922
923
924def get_gencode_flags() -> str:
925    r"""Return NVCC gencode flags this library was compiled with."""
926    arch_list = get_arch_list()
927    if len(arch_list) == 0:
928        return ""
929    arch_list_ = [arch.split("_") for arch in arch_list]
930    return " ".join(
931        [
932            f"-gencode compute=compute_{arch},code={kind}_{arch}"
933            for (kind, arch) in arch_list_
934        ]
935    )
936
937
938def current_device() -> int:
939    r"""Return the index of a currently selected device."""
940    _lazy_init()
941    return torch._C._cuda_getDevice()
942
943
944def synchronize(device: _device_t = None) -> None:
945    r"""Wait for all kernels in all streams on a CUDA device to complete.
946
947    Args:
948        device (torch.device or int, optional): device for which to synchronize.
949            It uses the current device, given by :func:`~torch.cuda.current_device`,
950            if :attr:`device` is ``None`` (default).
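
    Example (illustrative; assumes ``x`` is a CUDA tensor)::

        >>> y = x @ x                 # queued asynchronously on the current stream
        >>> torch.cuda.synchronize()  # block the host until the kernel has finished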
951    """
952    _lazy_init()
953    with torch.cuda.device(device):
954        return torch._C._cuda_synchronize()
955
956
957def ipc_collect():
958    r"""Force collects GPU memory after it has been released by CUDA IPC.
959
960    .. note::
961        Checks if any sent CUDA tensors could be cleaned from memory. Force-closes
962        the shared memory file used for reference counting if there are no active
963        counters. Useful when the producer process has stopped actively sending
964        tensors and wants to release unused memory.
965    """
966    _lazy_init()
967    return torch._C._cuda_ipc_collect()
968
969
970def current_stream(device: Optional[_device_t] = None) -> Stream:
971    r"""Return the currently selected :class:`Stream` for a given device.
972
973    Args:
974        device (torch.device or int, optional): selected device. Returns
975            the currently selected :class:`Stream` for the current device, given
976            by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
977            (default).
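
    Example (illustrative)::

        >>> s = torch.cuda.current_stream()
        >>> on_default = s == torch.cuda.default_stream()  # True unless another stream was selected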
978    """
979    _lazy_init()
980    streamdata = torch._C._cuda_getCurrentStream(
981        _get_device_index(device, optional=True)
982    )
983    return Stream(
984        stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
985    )
986
987
988def default_stream(device: Optional[_device_t] = None) -> Stream:
989    r"""Return the default :class:`Stream` for a given device.
990
991    Args:
992        device (torch.device or int, optional): selected device. Returns
993            the default :class:`Stream` for the current device, given by
994            :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
995            (default).
996    """
997    _lazy_init()
998    streamdata = torch._C._cuda_getDefaultStream(
999        _get_device_index(device, optional=True)
1000    )
1001    return Stream(
1002        stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
1003    )
1004
1005
1006def current_blas_handle():
1007    r"""Return cublasHandle_t pointer to current cuBLAS handle"""
1008    _lazy_init()
1009    return torch._C._cuda_getCurrentBlasHandle()
1010
1011
1012def set_sync_debug_mode(debug_mode: Union[int, str]) -> None:
1013    r"""Set the debug mode for cuda synchronizing operations.
1014
1015    Args:
1016        debug_mode (str or int): if "default" or 0, don't warn or error on synchronizing operations;
1017            if "warn" or 1, warn on synchronizing operations; if "error" or 2, error out on synchronizing operations.
1018
1019    Warning:
1020        This is an experimental feature, and not all synchronizing operations will trigger a warning or an error. In
1021        particular, operations in torch.distributed and torch.sparse namespaces are not covered yet.
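
    Example (illustrative; ``item()`` synchronizes with the device, so it should now warn)::

        >>> torch.cuda.set_sync_debug_mode("warn")
        >>> value = torch.ones(1, device="cuda").item()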
1022    """
1023    _lazy_init()
1024    if isinstance(debug_mode, str):
1025        if debug_mode == "default":
1026            debug_mode = 0
1027        elif debug_mode == "warn":
1028            debug_mode = 1
1029        elif debug_mode == "error":
1030            debug_mode = 2
1031        else:
1032            raise RuntimeError(
1033                "invalid value of debug_mode, expected one of `default`, `warn`, `error`"
1034            )
1035
1036    torch._C._cuda_set_sync_debug_mode(debug_mode)
1037
1038
1039def get_sync_debug_mode() -> int:
1040    r"""Return current value of debug mode for cuda synchronizing operations."""
1041    _lazy_init()
1042    return torch._C._cuda_get_sync_debug_mode()
1043
1044
1045def _get_pynvml_handler(device: Optional[Union[Device, int]] = None):
1046    if not _HAS_PYNVML:
1047        raise ModuleNotFoundError(
1048            "pynvml does not seem to be installed or it can't be imported."
1049        ) from _PYNVML_ERR
1050    from pynvml import NVMLError_DriverNotLoaded
1051
1052    try:
1053        pynvml.nvmlInit()
1054    except NVMLError_DriverNotLoaded as e:
1055        raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e
1056
1057    device = _get_nvml_device_index(device)
1058    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
1059    return handle
1060
1061
1062def _get_amdsmi_handler(device: Optional[Union[Device, int]] = None):
1063    if not _HAS_PYNVML:
1064        raise ModuleNotFoundError(
1065            "amdsmi does not seem to be installed or it can't be imported."
1066        ) from _PYNVML_ERR
1067    try:
1068        amdsmi.amdsmi_init()
1069    except amdsmi.AmdSmiException as e:
1070        raise RuntimeError(
1071            "amdsmi driver can't be loaded, requires >=ROCm5.6 installation"
1072        ) from e
1073    device = _get_amdsmi_device_index(device)
1074    handle = amdsmi.amdsmi_get_processor_handles()[device]
1075    return handle
1076
1077
1078def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int:
1079    r"""Return the amdsmi index of the device, taking visible_devices into account."""
1080    idx = _get_device_index(device, optional=True)
1081    visible_devices = _parse_visible_devices()
1082    if type(visible_devices[0]) is str:
1083        raise RuntimeError("HIP_VISIBLE_DEVICES should be indices and not strings")
1084    idx_map = dict(enumerate(cast(List[int], visible_devices)))
1085    if idx not in idx_map:
1086        raise RuntimeError(
1087            f"device {idx} is not visible (HIP_VISIBLE_DEVICES={visible_devices})"
1088        )
1089    return idx_map[idx]
1090
1091
1092def _get_amdsmi_memory_usage(device: Optional[Union[Device, int]] = None) -> int:
1093    handle = _get_amdsmi_handler()
1094    device = _get_amdsmi_device_index(device)
1095    return amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"]
1096
1097
1098def _get_amdsmi_utilization(device: Optional[Union[Device, int]] = None) -> int:
1099    handle = _get_amdsmi_handler()
1100    device = _get_amdsmi_device_index(device)
1101    handle = amdsmi.amdsmi_get_processor_handles()[device]
1102    return amdsmi.amdsmi_get_gpu_activity(handle)["gfx_activity"]
1103
1104
1105def _get_amdsmi_temperature(device: Optional[Union[Device, int]] = None) -> int:
1106    handle = _get_amdsmi_handler(device)
1107    return amdsmi.amdsmi_get_temp_metric(
1108        handle,
1109        amdsmi.AmdSmiTemperatureType.JUNCTION,
1110        amdsmi.AmdSmiTemperatureMetric.CURRENT,
1111    )
1112
1113
1114def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int:
1115    handle = _get_amdsmi_handler(device)
1116    socket_power = amdsmi.amdsmi_get_power_info(handle)["average_socket_power"]
1117    if socket_power != "N/A":
1118        return socket_power
1119    else:
1120        return amdsmi.amdsmi_get_power_info(handle)["current_socket_power"]
1121
1122
1123def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int:
1124    handle = _get_amdsmi_handler(device)
1125    clock_info = amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX)
1126    if "cur_clk" in clock_info:  # ROCm 6.2 deprecation
1127        return clock_info["cur_clk"]
1128    else:
1129        return clock_info["clk"]
1130
1131
1132def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
1133    r"""Return the percent of time over the past sample period during which global (device)
1134    memory was being read or written as given by `nvidia-smi`.
1135
1136    Args:
1137        device (torch.device or int, optional): selected device. Returns
1138            statistic for the current device, given by :func:`~torch.cuda.current_device`,
1139            if :attr:`device` is ``None`` (default).
1140
1141    Warning: Each sample period may be between 1 second and 1/6 second,
1142    depending on the product being queried.
1143    """
1144    if not torch.version.hip:
1145        handle = _get_pynvml_handler()
1146        device = _get_nvml_device_index(device)
1147        handle = pynvml.nvmlDeviceGetHandleByIndex(device)
1148        return pynvml.nvmlDeviceGetUtilizationRates(handle).memory
1149    else:
1150        return _get_amdsmi_memory_usage(device)
1151
1152
1153def utilization(device: Optional[Union[Device, int]] = None) -> int:
1154    r"""Return the percent of time over the past sample period during which one or
1155    more kernels were executing on the GPU as given by `nvidia-smi`.
1156
1157    Args:
1158        device (torch.device or int, optional): selected device. Returns
1159            statistic for the current device, given by :func:`~torch.cuda.current_device`,
1160            if :attr:`device` is ``None`` (default).
1161
1162    Warning: Each sample period may be between 1 second and 1/6 second,
1163    depending on the product being queried.
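
    Example (illustrative; requires ``pynvml`` on NVIDIA systems or ``amdsmi`` on ROCm)::

        >>> gpu_busy_percent = torch.cuda.utilization()  # integer percentage in [0, 100]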
1164    """
1165    if not torch.version.hip:
1166        handle = _get_pynvml_handler(device)
1167        device = _get_nvml_device_index(device)
1168        handle = pynvml.nvmlDeviceGetHandleByIndex(device)
1169        return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu
1170    else:
1171        return _get_amdsmi_utilization(device)
1172
1173
1174def temperature(device: Optional[Union[Device, int]] = None) -> int:
1175    r"""Return the average temperature of the GPU sensor in Degrees C (Centigrades).
1176
1177    The average temperature is computed based on past sample period as given by `nvidia-smi`.
1178
1179    Args:
1180        device (torch.device or int, optional): selected device. Returns
1181            statistic for the current device, given by :func:`~torch.cuda.current_device`,
1182            if :attr:`device` is ``None`` (default).
1183
1184    Warning: Each sample period may be between 1 second and 1/6 second,
1185    depending on the product being queried.
1186    """
1187    if not torch.version.hip:
1188        handle = _get_pynvml_handler(device)
1189        # 0 refers to the temperature sensor for the GPU die.
1190        return pynvml.nvmlDeviceGetTemperature(handle, 0)
1191    else:
1192        return _get_amdsmi_temperature(device)
1193
1194
1195def power_draw(device: Optional[Union[Device, int]] = None) -> int:
1196    r"""Return the average power draw of the GPU sensor in mW (MilliWatts)
1197        over the past sample period as given by `nvidia-smi` for Fermi or newer fully supported devices.
1198
1199    Args:
1200        device (torch.device or int, optional): selected device. Returns
1201            statistic for the current device, given by :func:`~torch.cuda.current_device`,
1202            if :attr:`device` is ``None`` (default).
1203
1204    Warning: Each sample period may be between 1 second and 1/6 second,
1205    depending on the product being queried.
1206    """
1207    if not torch.version.hip:
1208        handle = _get_pynvml_handler(device)
1209        return pynvml.nvmlDeviceGetPowerUsage(handle)
1210    else:
1211        return _get_amdsmi_power_draw(device)
1212
1213
1214def clock_rate(device: Optional[Union[Device, int]] = None) -> int:
1215    r"""Return the clock speed of the GPU SM in Hz Hertz over the past sample period as given by `nvidia-smi`.
1216
1217    Args:
1218        device (torch.device or int, optional): selected device. Returns
1219            statistic for the current device, given by :func:`~torch.cuda.current_device`,
1220            if :attr:`device` is ``None`` (default).
1221
1222    Warning: Each sample period may be between 1 second and 1/6 second,
1223    depending on the product being queried.
1224    """
1225    if not torch.version.hip:
1226        handle = _get_pynvml_handler(device)
1227        return pynvml.nvmlDeviceGetClockInfo(handle, 1)
1228    else:
1229        return _get_amdsmi_clock_rate(device)
1230
1231
1232def _get_device(device: Union[int, str, torch.device]) -> torch.device:
1233    r"""Return the torch.device type object from the passed in device.
1234
1235    Args:
1236        device (torch.device or int): selected device.
1237    """
1238    if isinstance(device, str):
1239        device = torch.device(device)
1240    elif isinstance(device, int):
1241        device = torch.device("cuda", device)
1242    return device
1243
1244
1245def _get_generator(device: torch.device) -> torch._C.Generator:
1246    r"""Return the CUDA Generator object for the given device.
1247
1248    Args:
1249        device (torch.device): selected device.
1250    """
1251    idx = device.index
1252    if idx is None:
1253        idx = current_device()
1254    return torch.cuda.default_generators[idx]
1255
1256
1257def _set_rng_state_offset(
1258    offset: int, device: Union[int, str, torch.device] = "cuda"
1259) -> None:
1260    r"""Set the random number generator state offset of the specified GPU.
1261
1262    Args:
1263        offset (int): The desired offset
1264        device (torch.device or int, optional): The device whose RNG state offset is to be set.
1265            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
1266    """
1267    final_device = _get_device(device)
1268
1269    def cb():
1270        default_generator = _get_generator(final_device)
1271        default_generator.set_offset(offset)
1272
1273    _lazy_call(cb)
1274
1275
1276def _get_rng_state_offset(device: Union[int, str, torch.device] = "cuda") -> int:
1277    r"""Return the random number generator state offset of the specified GPU.
1278
1279    Args:
1280        device (torch.device or int, optional): The device to return the RNG state offset of.
1281            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
1282
1283    .. warning::
1284        This function eagerly initializes CUDA.
1285    """
1286    _lazy_init()
1287    final_device = _get_device(device)
1288    default_generator = _get_generator(final_device)
1289    return default_generator.get_offset()
1290
1291
1292from .memory import *  # noqa: F403
1293from .random import *  # noqa: F403
1294
1295
1296################################################################################
1297# Define Storage and Tensor classes
1298################################################################################
1299
1300
1301@staticmethod  # type: ignore[misc]
1302def _lazy_new(cls, *args, **kwargs):
1303    _lazy_init()
1304    # We may need to call lazy init again if we are a forked child
1305    # del _CudaBase.__new__
1306    return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
1307
1308
1309class _CudaBase:
1310    is_cuda = True
1311    is_sparse = False
1312
1313    def type(self, *args, **kwargs):
1314        # We could use a Protocol here to tell mypy that self has `get_device` method
1315        # but it is only available in the typing module on Python >= 3.8
1316        # or on typing_extensions module on Python >= 3.6
1317        with device(self.get_device()):  # type: ignore[attr-defined]
1318            return super().type(*args, **kwargs)  # type: ignore[misc]
1319
1320    __new__ = _lazy_new
1321
1322
1323from torch.storage import _LegacyStorage, _warn_typed_storage_removal
1324
1325
1326class _CudaLegacyStorage(_LegacyStorage):
1327    @classmethod
1328    def from_buffer(cls, *args, **kwargs):
1329        _warn_typed_storage_removal()
1330        raise RuntimeError("from_buffer: Not available for CUDA storage")
1331
1332    @classmethod
1333    def _new_with_weak_ptr(cls, *args, **kwargs):
1334        raise RuntimeError("_new_with_weak_ptr: Not available for CUDA storage")
1335
1336    @classmethod
1337    def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None):
1338        raise RuntimeError("_new_shared_filename: Not available for CUDA storage")
1339
1340
1341class ByteStorage(_CudaLegacyStorage):
1342    @classproperty
1343    def dtype(self):
1344        _warn_typed_storage_removal()
1345        return self._dtype
1346
1347    @classproperty
1348    def _dtype(self):
1349        return torch.uint8
1350
1351
1352class DoubleStorage(_CudaLegacyStorage):
1353    @classproperty
1354    def dtype(self):
1355        _warn_typed_storage_removal()
1356        return self._dtype
1357
1358    @classproperty
1359    def _dtype(self):
1360        return torch.double
1361
1362
1363class FloatStorage(_CudaLegacyStorage):
1364    @classproperty
1365    def dtype(self):
1366        _warn_typed_storage_removal()
1367        return self._dtype
1368
1369    @classproperty
1370    def _dtype(self):
1371        return torch.float
1372
1373
1374class HalfStorage(_CudaLegacyStorage):
1375    @classproperty
1376    def dtype(self):
1377        _warn_typed_storage_removal()
1378        return self._dtype
1379
1380    @classproperty
1381    def _dtype(self):
1382        return torch.half
1383
1384
1385class LongStorage(_CudaLegacyStorage):
1386    @classproperty
1387    def dtype(self):
1388        _warn_typed_storage_removal()
1389        return self._dtype
1390
1391    @classproperty
1392    def _dtype(self):
1393        return torch.long
1394
1395
1396class IntStorage(_CudaLegacyStorage):
1397    @classproperty
1398    def dtype(self):
1399        _warn_typed_storage_removal()
1400        return self._dtype
1401
1402    @classproperty
1403    def _dtype(self):
1404        return torch.int
1405
1406
1407class ShortStorage(_CudaLegacyStorage):
1408    @classproperty
1409    def dtype(self):
1410        _warn_typed_storage_removal()
1411        return self._dtype
1412
1413    @classproperty
1414    def _dtype(self):
1415        return torch.short
1416
1417
1418class CharStorage(_CudaLegacyStorage):
1419    @classproperty
1420    def dtype(self):
1421        _warn_typed_storage_removal()
1422        return self._dtype
1423
1424    @classproperty
1425    def _dtype(self):
1426        return torch.int8
1427
1428
1429class BoolStorage(_CudaLegacyStorage):
1430    @classproperty
1431    def dtype(self):
1432        _warn_typed_storage_removal()
1433        return self._dtype
1434
1435    @classproperty
1436    def _dtype(self):
1437        return torch.bool
1438
1439
1440class BFloat16Storage(_CudaLegacyStorage):
1441    @classproperty
1442    def dtype(self):
1443        _warn_typed_storage_removal()
1444        return self._dtype
1445
1446    @classproperty
1447    def _dtype(self):
1448        return torch.bfloat16
1449
1450
1451class ComplexDoubleStorage(_CudaLegacyStorage):
1452    @classproperty
1453    def dtype(self):
1454        _warn_typed_storage_removal()
1455        return self._dtype
1456
1457    @classproperty
1458    def _dtype(self):
1459        return torch.cdouble
1460
1461
1462class ComplexFloatStorage(_CudaLegacyStorage):
1463    @classproperty
1464    def dtype(self):
1465        _warn_typed_storage_removal()
1466        return self._dtype
1467
1468    @classproperty
1469    def _dtype(self):
1470        return torch.cfloat
1471
1472
1473del _LegacyStorage
1474del _CudaLegacyStorage
1475
1476torch._storage_classes.add(DoubleStorage)
1477torch._storage_classes.add(FloatStorage)
1478torch._storage_classes.add(LongStorage)
1479torch._storage_classes.add(IntStorage)
1480torch._storage_classes.add(ShortStorage)
1481torch._storage_classes.add(CharStorage)
1482torch._storage_classes.add(ByteStorage)
1483torch._storage_classes.add(HalfStorage)
1484torch._storage_classes.add(BoolStorage)
1485torch._storage_classes.add(BFloat16Storage)
1486torch._storage_classes.add(ComplexDoubleStorage)
1487torch._storage_classes.add(ComplexFloatStorage)
1488
1489
1490class _WrappedTritonKernel:
1491    """Just a simple wrapper to store some metadata for testing purposes."""
1492
1493    def __init__(self, kernel):
1494        self.kernel = kernel
1495        self.kernel_invoked = False
1496
1497    def __call__(self, *args, **kwargs):
1498        res = self.kernel(*args, **kwargs)
1499        self.kernel_invoked = True
1500        return res
1501
1502
1503def _register_triton_kernels():
1504    if torch._running_with_deploy():
1505        return
1506
1507    @_WrappedTritonKernel
1508    def kernel_impl(*args, **kwargs):
1509        from torch.sparse._triton_ops import bsr_dense_mm
1510
1511        return bsr_dense_mm(*args, skip_checks=True, **kwargs)
1512
1513    @_WrappedTritonKernel
1514    def addmm_kernel_impl(*args, **kwargs):
1515        from torch.sparse._triton_ops import bsr_dense_addmm
1516
1517        return bsr_dense_addmm(*args, skip_checks=True, **kwargs)
1518
1519    has_triton = importlib.util.find_spec("triton") is not None
1520    if has_triton:
1521        torch._TritonLibrary.registerOp(
1522            "_triton_bsr_dense_mm_out",
1523            "_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)",
1524            kernel_impl,
1525            "SparseCsrCUDA",
1526        )
1527
1528        torch._TritonLibrary.registerOp(
1529            "_triton_bsr_dense_addmm_out",
1530            (
1531                "_triton_bsr_dense_addmm_out(Tensor input, Tensor bsr, Tensor dense,"
1532                " *, Scalar beta, Scalar alpha, Tensor(a!) out) -> Tensor(a!)"
1533            ),
1534            addmm_kernel_impl,
1535            "SparseCsrCUDA",
1536        )
1537
1538
1539_lazy_call(_register_triton_kernels)
1540
1541
1542from . import amp, jiterator, nvtx, profiler, sparse, tunable
1543
1544
1545__all__ = [
1546    # Typed storage and tensors
1547    "BFloat16Storage",
1548    "BFloat16Tensor",
1549    "BoolStorage",
1550    "BoolTensor",
1551    "ByteStorage",
1552    "ByteTensor",
1553    "CharStorage",
1554    "CharTensor",
1555    "ComplexDoubleStorage",
1556    "ComplexFloatStorage",
1557    "DoubleStorage",
1558    "DoubleTensor",
1559    "FloatStorage",
1560    "FloatTensor",
1561    "HalfStorage",
1562    "HalfTensor",
1563    "IntStorage",
1564    "IntTensor",
1565    "LongStorage",
1566    "LongTensor",
1567    "ShortStorage",
1568    "ShortTensor",
1569    "CUDAGraph",
1570    "CudaError",
1571    "DeferredCudaCallError",
1572    "Event",
1573    "ExternalStream",
1574    "Stream",
1575    "StreamContext",
1576    "amp",
1577    "caching_allocator_alloc",
1578    "caching_allocator_delete",
1579    "can_device_access_peer",
1580    "check_error",
1581    "cudaStatus",
1582    "cudart",
1583    "current_blas_handle",
1584    "current_device",
1585    "current_stream",
1586    "default_generators",
1587    "default_stream",
1588    "device",
1589    "device_count",
1590    "device_of",
1591    "empty_cache",
1592    "get_allocator_backend",
1593    "CUDAPluggableAllocator",
1594    "change_current_allocator",
1595    "get_arch_list",
1596    "get_device_capability",
1597    "get_device_name",
1598    "get_device_properties",
1599    "get_gencode_flags",
1600    "get_rng_state",
1601    "get_rng_state_all",
1602    "get_sync_debug_mode",
1603    "graph",
1604    "graph_pool_handle",
1605    "graphs",
1606    "has_half",
1607    "has_magma",
1608    "init",
1609    "initial_seed",
1610    "ipc_collect",
1611    "is_available",
1612    "is_bf16_supported",
1613    "is_current_stream_capturing",
1614    "is_initialized",
1615    "jiterator",
1616    "list_gpu_processes",
1617    "make_graphed_callables",
1618    "manual_seed",
1619    "manual_seed_all",
1620    "max_memory_allocated",
1621    "max_memory_cached",
1622    "max_memory_reserved",
1623    "mem_get_info",
1624    "memory",
1625    "memory_allocated",
1626    "memory_cached",
1627    "memory_reserved",
1628    "memory_snapshot",
1629    "memory_stats",
1630    "memory_stats_as_nested_dict",
1631    "memory_summary",
1632    "memory_usage",
1633    "MemPool",
1634    "MemPoolContext",
1635    "use_mem_pool",
1636    "temperature",
1637    "power_draw",
1638    "clock_rate",
1639    "nccl",
1640    "nvtx",
1641    "profiler",
1642    "random",
1643    "reset_accumulated_memory_stats",
1644    "reset_max_memory_allocated",
1645    "reset_max_memory_cached",
1646    "reset_peak_memory_stats",
1647    "seed",
1648    "seed_all",
1649    "set_device",
1650    "set_per_process_memory_fraction",
1651    "set_rng_state",
1652    "set_rng_state_all",
1653    "set_stream",
1654    "set_sync_debug_mode",
1655    "sparse",
1656    "stream",
1657    "streams",
1658    "synchronize",
1659    "tunable",
1660    "utilization",
1661]
1662