1# mypy: allow-untyped-defs 2r""" 3This package adds support for CUDA tensor types. 4 5It implements the same function as CPU tensors, but they utilize 6GPUs for computation. 7 8It is lazily initialized, so you can always import it, and use 9:func:`is_available()` to determine if your system supports CUDA. 10 11:ref:`cuda-semantics` has more details about working with CUDA. 12""" 13 14import importlib 15import os 16import threading 17import traceback 18import warnings 19from functools import lru_cache 20from typing import Any, Callable, cast, List, Optional, Tuple, Union 21 22import torch 23import torch._C 24from torch import device as _device 25from torch._utils import _dummy_type, _LazySeedTracker, classproperty 26from torch.types import Device 27 28from . import gds 29from ._utils import _get_device_index 30from .graphs import ( 31 CUDAGraph, 32 graph, 33 graph_pool_handle, 34 is_current_stream_capturing, 35 make_graphed_callables, 36) 37from .streams import Event, ExternalStream, Stream 38 39 40try: 41 from torch._C import _cudart # type: ignore[attr-defined] 42except ImportError: 43 _cudart = None 44 45_initialized = False 46_tls = threading.local() 47_initialization_lock = threading.Lock() 48_queued_calls: List[ 49 Tuple[Callable[[], None], List[str]] 50] = [] # don't invoke these until initialization occurs 51_is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False) 52_device_t = Union[_device, str, int, None] 53 54_HAS_PYNVML = False 55_PYNVML_ERR = None 56try: 57 from torch import version as _version 58 59 try: 60 if not _version.hip: 61 import pynvml # type: ignore[import] 62 else: 63 import amdsmi # type: ignore[import] 64 65 _HAS_PYNVML = True 66 except ModuleNotFoundError: 67 pass 68 finally: 69 del _version 70except ImportError as err: 71 _PYNVML_ERR = err # sometimes a lib is installed but the import fails for some other reason, so we log the error for later 72 73_lazy_seed_tracker = _LazySeedTracker() 74 75# Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA 76if hasattr(torch._C, "_CudaDeviceProperties"): 77 _CudaDeviceProperties = torch._C._CudaDeviceProperties 78else: 79 _CudaDeviceProperties = _dummy_type("_CudaDeviceProperties") # type: ignore[assignment, misc] 80 81if hasattr(torch._C, "_cuda_exchangeDevice"): 82 _exchange_device = torch._C._cuda_exchangeDevice 83else: 84 85 def _exchange_device(device: int) -> int: 86 if device < 0: 87 return -1 88 raise RuntimeError("PyTorch was compiled without CUDA support") 89 90 91if hasattr(torch._C, "_cuda_maybeExchangeDevice"): 92 _maybe_exchange_device = torch._C._cuda_maybeExchangeDevice 93else: 94 95 def _maybe_exchange_device(device: int) -> int: 96 if device < 0: 97 return -1 98 raise RuntimeError("PyTorch was compiled without CUDA support") 99 100 101has_half: bool = True 102has_magma: bool = torch._C._has_magma 103 104default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment] 105 106 107def _is_compiled() -> bool: 108 r"""Return true if compile with CUDA support.""" 109 return hasattr(torch._C, "_cuda_getDeviceCount") 110 111 112def _nvml_based_avail() -> bool: 113 return os.getenv("PYTORCH_NVML_BASED_CUDA_CHECK") == "1" 114 115 116def is_available() -> bool: 117 r"""Return a bool indicating if CUDA is currently available.""" 118 if not _is_compiled(): 119 return False 120 if _nvml_based_avail(): 121 # The user has set an env variable to request this availability check that attempts to avoid fork poisoning by 122 # using NVML at the cost of a weaker CUDA availability assessment. Note that if NVML discovery/initialization 123 # fails, this assessment falls back to the default CUDA Runtime API assessment (`cudaGetDeviceCount`) 124 return device_count() > 0 125 else: 126 # The default availability inspection never throws and returns 0 if the driver is missing or can't 127 # be initialized. This uses the CUDA Runtime API `cudaGetDeviceCount` which in turn initializes the CUDA Driver 128 # API via `cuInit` 129 return torch._C._cuda_getDeviceCount() > 0 130 131 132def is_bf16_supported(including_emulation: bool = True): 133 r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16.""" 134 # Check for ROCm, if true return true, no ROCM_VERSION check required, 135 # since it is supported on AMD GPU archs. 136 if torch.version.hip: 137 return True 138 139 # If CUDA is not available, than it does not support bf16 either 140 if not is_available(): 141 return False 142 143 device = torch.cuda.current_device() 144 145 # Check for CUDA version and device compute capability. 146 # This is a fast way to check for it. 147 cuda_version = torch.version.cuda 148 if ( 149 cuda_version is not None 150 and int(cuda_version.split(".")[0]) >= 11 151 and torch.cuda.get_device_properties(device).major >= 8 152 ): 153 return True 154 155 if not including_emulation: 156 return False 157 158 # Finally try to create a bfloat16 device. 159 return _check_bf16_tensor_supported(device) 160 161 162@lru_cache(maxsize=16) 163def _check_bf16_tensor_supported(device: _device_t): 164 try: 165 torch.tensor([1.0], dtype=torch.bfloat16, device=device) 166 return True 167 except Exception: 168 return False 169 170 171def _sleep(cycles): 172 torch._C._cuda_sleep(cycles) 173 174 175def _extract_arch_version(arch_string: str): 176 """Extracts the architecture string from a CUDA version""" 177 base = arch_string.split("_")[1] 178 if base.endswith("a"): 179 base = base[:-1] 180 return int(base) 181 182 183def _check_capability(): 184 incorrect_binary_warn = """ 185 Found GPU%d %s which requires CUDA_VERSION >= %d to 186 work properly, but your PyTorch was compiled 187 with CUDA_VERSION %d. Please install the correct PyTorch binary 188 using instructions from https://pytorch.org 189 """ 190 191 old_gpu_warn = """ 192 Found GPU%d %s which is of cuda capability %d.%d. 193 PyTorch no longer supports this GPU because it is too old. 194 The minimum cuda capability supported by this library is %d.%d. 195 """ 196 197 if torch.version.cuda is not None: # on ROCm we don't want this check 198 CUDA_VERSION = torch._C._cuda_getCompiledVersion() 199 for d in range(device_count()): 200 capability = get_device_capability(d) 201 major = capability[0] 202 minor = capability[1] 203 name = get_device_name(d) 204 current_arch = major * 10 + minor 205 min_arch = min( 206 (_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()), 207 default=35, 208 ) 209 if current_arch < min_arch: 210 warnings.warn( 211 old_gpu_warn 212 % (d, name, major, minor, min_arch // 10, min_arch % 10) 213 ) 214 215 216def _check_cubins(): 217 incompatible_device_warn = """ 218{} with CUDA capability sm_{} is not compatible with the current PyTorch installation. 219The current PyTorch install supports CUDA capabilities {}. 220If you want to use the {} GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/ 221""" 222 if torch.version.cuda is None: # on ROCm we don't want this check 223 return 224 arch_list = get_arch_list() 225 if len(arch_list) == 0: 226 return 227 supported_sm = [_extract_arch_version(arch) for arch in arch_list if "sm_" in arch] 228 for idx in range(device_count()): 229 cap_major, cap_minor = get_device_capability(idx) 230 # NVIDIA GPU compute architectures are backward compatible within major version 231 supported = any(sm // 10 == cap_major for sm in supported_sm) 232 if not supported: 233 device_name = get_device_name(idx) 234 capability = cap_major * 10 + cap_minor 235 warnings.warn( 236 incompatible_device_warn.format( 237 device_name, capability, " ".join(arch_list), device_name 238 ) 239 ) 240 241 242def is_initialized(): 243 r"""Return whether PyTorch's CUDA state has been initialized.""" 244 return _initialized and not _is_in_bad_fork() 245 246 247def _lazy_call(callable, **kwargs): 248 if is_initialized(): 249 callable() 250 else: 251 # TODO(torch_deploy): this accesses linecache, which attempts to read the 252 # file system to get traceback info. Patch linecache or do something 253 # else here if this ends up being important. 254 global _lazy_seed_tracker 255 if kwargs.get("seed_all", False): 256 _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack()) 257 elif kwargs.get("seed", False): 258 _lazy_seed_tracker.queue_seed(callable, traceback.format_stack()) 259 else: 260 # Don't store the actual traceback to avoid memory cycle 261 _queued_calls.append((callable, traceback.format_stack())) 262 263 264_lazy_call(_check_capability) 265_lazy_call(_check_cubins) 266 267 268class DeferredCudaCallError(Exception): 269 pass 270 271 272OutOfMemoryError = torch._C.OutOfMemoryError 273 274 275def init(): 276 r"""Initialize PyTorch's CUDA state. 277 278 You may need to call this explicitly if you are interacting with 279 PyTorch via its C API, as Python bindings for CUDA functionality 280 will not be available until this initialization takes place. 281 Ordinary users should not need this, as all of PyTorch's CUDA methods 282 automatically initialize CUDA state on-demand. 283 284 Does nothing if the CUDA state is already initialized. 285 """ 286 _lazy_init() 287 288 289def _lazy_init(): 290 global _initialized, _queued_calls 291 if is_initialized() or hasattr(_tls, "is_initializing"): 292 return 293 with _initialization_lock: 294 # We be double-checked locking, boys! This is OK because 295 # the above test was GIL protected anyway. The inner test 296 # is for when a thread blocked on some other thread which was 297 # doing the initialization; when they get the lock, they will 298 # find there is nothing left to do. 299 if is_initialized(): 300 return 301 # It is important to prevent other threads from entering _lazy_init 302 # immediately, while we are still guaranteed to have the GIL, because some 303 # of the C calls we make below will release the GIL 304 if _is_in_bad_fork(): 305 raise RuntimeError( 306 "Cannot re-initialize CUDA in forked subprocess. To use CUDA with " 307 "multiprocessing, you must use the 'spawn' start method" 308 ) 309 if not hasattr(torch._C, "_cuda_getDeviceCount"): 310 raise AssertionError("Torch not compiled with CUDA enabled") 311 if _cudart is None: 312 raise AssertionError( 313 "libcudart functions unavailable. It looks like you have a broken build?" 314 ) 315 # This function throws if there's a driver initialization error, no GPUs 316 # are found or any other error occurs 317 if "CUDA_MODULE_LOADING" not in os.environ: 318 os.environ["CUDA_MODULE_LOADING"] = "LAZY" 319 torch._C._cuda_init() 320 # Some of the queued calls may reentrantly call _lazy_init(); 321 # we need to just return without initializing in that case. 322 # However, we must not let any *other* threads in! 323 _tls.is_initializing = True 324 325 for calls in _lazy_seed_tracker.get_calls(): 326 if calls: 327 _queued_calls.append(calls) 328 329 try: 330 for queued_call, orig_traceback in _queued_calls: 331 try: 332 queued_call() 333 except Exception as e: 334 msg = ( 335 f"CUDA call failed lazily at initialization with error: {str(e)}\n\n" 336 f"CUDA call was originally invoked at:\n\n{''.join(orig_traceback)}" 337 ) 338 raise DeferredCudaCallError(msg) from e 339 finally: 340 delattr(_tls, "is_initializing") 341 _initialized = True 342 343 344def cudart(): 345 r"""Retrieves the CUDA runtime API module. 346 347 348 This function initializes the CUDA runtime environment if it is not already 349 initialized and returns the CUDA runtime API module (_cudart). The CUDA 350 runtime API module provides access to various CUDA runtime functions. 351 352 Args: 353 ``None`` 354 355 Returns: 356 module: The CUDA runtime API module (_cudart). 357 358 Raises: 359 RuntimeError: If CUDA cannot be re-initialized in a forked subprocess. 360 AssertionError: If PyTorch is not compiled with CUDA support or if libcudart functions are unavailable. 361 362 Example of CUDA operations with profiling: 363 >>> import torch 364 >>> from torch.cuda import cudart, check_error 365 >>> import os 366 >>> 367 >>> os.environ['CUDA_PROFILE'] = '1' 368 >>> 369 >>> def perform_cuda_operations_with_streams(): 370 >>> stream = torch.cuda.Stream() 371 >>> with torch.cuda.stream(stream): 372 >>> x = torch.randn(100, 100, device='cuda') 373 >>> y = torch.randn(100, 100, device='cuda') 374 >>> z = torch.mul(x, y) 375 >>> return z 376 >>> 377 >>> torch.cuda.synchronize() 378 >>> print("====== Start nsys profiling ======") 379 >>> check_error(cudart().cudaProfilerStart()) 380 >>> with torch.autograd.profiler.emit_nvtx(): 381 >>> result = perform_cuda_operations_with_streams() 382 >>> print("CUDA operations completed.") 383 >>> check_error(torch.cuda.cudart().cudaProfilerStop()) 384 >>> print("====== End nsys profiling ======") 385 386 To run this example and save the profiling information, execute: 387 >>> $ nvprof --profile-from-start off --csv --print-summary -o trace_name.prof -f -- python cudart_test.py 388 389 This command profiles the CUDA operations in the provided script and saves 390 the profiling information to a file named `trace_name.prof`. 391 The `--profile-from-start off` option ensures that profiling starts only 392 after the `cudaProfilerStart` call in the script. 393 The `--csv` and `--print-summary` options format the profiling output as a 394 CSV file and print a summary, respectively. 395 The `-o` option specifies the output file name, and the `-f` option forces the 396 overwrite of the output file if it already exists. 397 """ 398 _lazy_init() 399 return _cudart 400 401 402class cudaStatus: 403 SUCCESS: int = 0 404 ERROR_NOT_READY: int = 34 405 406 407class CudaError(RuntimeError): 408 def __init__(self, code: int) -> None: 409 msg = _cudart.cudaGetErrorString(_cudart.cudaError(code)) 410 super().__init__(f"{msg} ({code})") 411 412 413def check_error(res: int) -> None: 414 if res != _cudart.cudaError.success: 415 raise CudaError(res) 416 417 418class _DeviceGuard: 419 def __init__(self, index: int): 420 self.idx = index 421 self.prev_idx = -1 422 423 def __enter__(self): 424 self.prev_idx = torch.cuda._exchange_device(self.idx) 425 426 def __exit__(self, type: Any, value: Any, traceback: Any): 427 self.idx = torch.cuda._maybe_exchange_device(self.prev_idx) 428 return False 429 430 431class device: 432 r"""Context-manager that changes the selected device. 433 434 Args: 435 device (torch.device or int): device index to select. It's a no-op if 436 this argument is a negative integer or ``None``. 437 """ 438 439 def __init__(self, device: Any): 440 self.idx = _get_device_index(device, optional=True) 441 self.prev_idx = -1 442 443 def __enter__(self): 444 self.prev_idx = torch.cuda._exchange_device(self.idx) 445 446 def __exit__(self, type: Any, value: Any, traceback: Any): 447 self.idx = torch.cuda._maybe_exchange_device(self.prev_idx) 448 return False 449 450 451class device_of(device): 452 r"""Context-manager that changes the current device to that of given object. 453 454 You can use both tensors and storages as arguments. If a given object is 455 not allocated on a GPU, this is a no-op. 456 457 Args: 458 obj (Tensor or Storage): object allocated on the selected device. 459 """ 460 461 def __init__(self, obj): 462 idx = obj.get_device() if obj.is_cuda else -1 463 super().__init__(idx) 464 465 466def set_device(device: _device_t) -> None: 467 r"""Set the current device. 468 469 Usage of this function is discouraged in favor of :any:`device`. In most 470 cases it's better to use ``CUDA_VISIBLE_DEVICES`` environmental variable. 471 472 Args: 473 device (torch.device or int): selected device. This function is a no-op 474 if this argument is negative. 475 """ 476 device = _get_device_index(device) 477 if device >= 0: 478 torch._C._cuda_setDevice(device) 479 480 481def get_device_name(device: Optional[_device_t] = None) -> str: 482 r"""Get the name of a device. 483 484 Args: 485 device (torch.device or int or str, optional): device for which to return the 486 name. This function is a no-op if this argument is a negative 487 integer. It uses the current device, given by :func:`~torch.cuda.current_device`, 488 if :attr:`device` is ``None`` (default). 489 490 Returns: 491 str: the name of the device 492 """ 493 return get_device_properties(device).name 494 495 496def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]: 497 r"""Get the cuda capability of a device. 498 499 Args: 500 device (torch.device or int or str, optional): device for which to return the 501 device capability. This function is a no-op if this argument is 502 a negative integer. It uses the current device, given by 503 :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` 504 (default). 505 506 Returns: 507 tuple(int, int): the major and minor cuda capability of the device 508 """ 509 prop = get_device_properties(device) 510 return prop.major, prop.minor 511 512 513def get_device_properties(device: _device_t) -> _CudaDeviceProperties: 514 r"""Get the properties of a device. 515 516 Args: 517 device (torch.device or int or str): device for which to return the 518 properties of the device. 519 520 Returns: 521 _CudaDeviceProperties: the properties of the device 522 """ 523 _lazy_init() # will define _get_device_properties 524 device = _get_device_index(device, optional=True) 525 if device < 0 or device >= device_count(): 526 raise AssertionError("Invalid device id") 527 return _get_device_properties(device) # type: ignore[name-defined] 528 529 530def can_device_access_peer(device: _device_t, peer_device: _device_t) -> bool: 531 r"""Check if peer access between two devices is possible.""" 532 _lazy_init() 533 device = _get_device_index(device, optional=True) 534 peer_device = _get_device_index(peer_device) 535 if device < 0 or device >= device_count(): 536 raise AssertionError("Invalid device id") 537 if peer_device < 0 or peer_device >= device_count(): 538 raise AssertionError("Invalid peer device id") 539 return torch._C._cuda_canDeviceAccessPeer(device, peer_device) 540 541 542class StreamContext: 543 r"""Context-manager that selects a given stream. 544 545 All CUDA kernels queued within its context will be enqueued on a selected 546 stream. 547 548 Args: 549 Stream (Stream): selected stream. This manager is a no-op if it's 550 ``None``. 551 .. note:: Streams are per-device. 552 """ 553 cur_stream: Optional["torch.cuda.Stream"] 554 555 def __init__(self, stream: Optional["torch.cuda.Stream"]): 556 self.stream = stream 557 self.idx = _get_device_index(None, True) 558 if not torch.jit.is_scripting(): 559 if self.idx is None: 560 self.idx = -1 561 562 self.src_prev_stream = ( 563 None if not torch.jit.is_scripting() else torch.cuda.default_stream(None) 564 ) 565 self.dst_prev_stream = ( 566 None if not torch.jit.is_scripting() else torch.cuda.default_stream(None) 567 ) 568 569 def __enter__(self): 570 # Local cur_stream variable for type refinement 571 cur_stream = self.stream 572 # Return if stream is None or CUDA device not available 573 if cur_stream is None or self.idx == -1: 574 return 575 self.src_prev_stream = torch.cuda.current_stream(None) 576 577 # If the stream is not on the current device, then 578 # set the current stream on the device 579 if self.src_prev_stream.device != cur_stream.device: 580 with device(cur_stream.device): 581 self.dst_prev_stream = torch.cuda.current_stream(cur_stream.device) 582 torch.cuda.set_stream(cur_stream) 583 584 def __exit__(self, type: Any, value: Any, traceback: Any): 585 # Local cur_stream variable for type refinement 586 cur_stream = self.stream 587 # If stream is None or no CUDA device available, return 588 if cur_stream is None or self.idx == -1: 589 return 590 591 # Reset the stream on the original device 592 # and destination device 593 if self.src_prev_stream.device != cur_stream.device: # type: ignore[union-attr] 594 torch.cuda.set_stream(self.dst_prev_stream) # type: ignore[arg-type] 595 torch.cuda.set_stream(self.src_prev_stream) # type: ignore[arg-type] 596 597 598def stream(stream: Optional["torch.cuda.Stream"]) -> StreamContext: 599 r"""Wrap around the Context-manager StreamContext that selects a given stream. 600 601 Arguments: 602 stream (Stream): selected stream. This manager is a no-op if it's 603 ``None``. 604 ..Note:: In eager mode stream is of type Stream class while in JIT it is 605 an object of the custom class ``torch.classes.cuda.Stream``. 606 """ 607 return StreamContext(stream) 608 609 610def _set_stream_by_id(stream_id, device_index, device_type): 611 r"""set stream specified by the stream id, device index and 612 device type 613 614 Args: stream_id (int): stream id in stream pool 615 device_index (int): device index in topo 616 device_type (int): enum device type 617 """ 618 torch._C._cuda_setStream( 619 stream_id=stream_id, 620 device_index=device_index, 621 device_type=device_type, 622 ) 623 624 625def set_stream(stream: Stream): 626 r"""Set the current stream.This is a wrapper API to set the stream. 627 Usage of this function is discouraged in favor of the ``stream`` 628 context manager. 629 630 Args: 631 stream (Stream): selected stream. This function is a no-op 632 if this argument is ``None``. 633 """ 634 if stream is None: 635 return 636 _set_stream_by_id( 637 stream_id=stream.stream_id, 638 device_index=stream.device_index, 639 device_type=stream.device_type, 640 ) 641 642 643def _parse_visible_devices() -> Union[List[int], List[str]]: 644 r"""Parse CUDA_VISIBLE_DEVICES environment variable.""" 645 var = os.getenv("CUDA_VISIBLE_DEVICES") 646 647 if torch.version.hip: 648 hip_devices = os.getenv("HIP_VISIBLE_DEVICES") 649 if hip_devices is not None: 650 var = hip_devices 651 652 if var is None: 653 return list(range(64)) 654 655 def _strtoul(s: str) -> int: 656 """Return -1 or positive integer sequence string starts with.""" 657 if not s: 658 return -1 659 for idx, c in enumerate(s): 660 if not (c.isdigit() or (idx == 0 and c in "+-")): 661 break 662 if idx + 1 == len(s): 663 idx += 1 664 return int(s[:idx]) if idx > 0 else -1 665 666 def parse_list_with_prefix(lst: str, prefix: str) -> List[str]: 667 rcs: List[str] = [] 668 for elem in lst.split(","): 669 # Repeated id results in empty set 670 if elem in rcs: 671 return cast(List[str], []) 672 # Anything other but prefix is ignored 673 if not elem.startswith(prefix): 674 break 675 rcs.append(elem) 676 return rcs 677 678 if var.startswith("GPU-"): 679 return parse_list_with_prefix(var, "GPU-") 680 if var.startswith("MIG-"): 681 return parse_list_with_prefix(var, "MIG-") 682 # CUDA_VISIBLE_DEVICES uses something like strtoul 683 # which makes `1gpu2,2ampere` is equivalent to `1,2` 684 rc: List[int] = [] 685 for elem in var.split(","): 686 x = _strtoul(elem.strip()) 687 # Repeated ordinal results in empty set 688 if x in rc: 689 return cast(List[int], []) 690 # Negative value aborts the sequence 691 if x < 0: 692 break 693 rc.append(x) 694 return rc 695 696 697def _raw_device_count_amdsmi() -> int: 698 if not _HAS_PYNVML: # If amdsmi is not available 699 return -1 700 try: 701 amdsmi.amdsmi_init() 702 except amdsmi.AmdSmiException as e: 703 warnings.warn(f"Can't initialize amdsmi - Error code: {e.err_code}") 704 return -1 705 socket_handles = amdsmi.amdsmi_get_processor_handles() 706 return len(socket_handles) 707 708 709def _raw_device_count_nvml() -> int: 710 r"""Return number of devices as reported by NVML or negative value if NVML discovery/initialization failed.""" 711 from ctypes import byref, c_int, CDLL 712 713 nvml_h = CDLL("libnvidia-ml.so.1") 714 rc = nvml_h.nvmlInit() 715 if rc != 0: 716 warnings.warn("Can't initialize NVML") 717 return -1 718 dev_count = c_int(-1) 719 rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count)) 720 if rc != 0: 721 warnings.warn("Can't get nvml device count") 722 return -1 723 del nvml_h 724 return dev_count.value 725 726 727def _raw_device_uuid_amdsmi() -> Optional[List[str]]: 728 from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer 729 730 if not _HAS_PYNVML: # If amdsmi is not available 731 return None 732 try: 733 amdsmi.amdsmi_init() 734 except amdsmi.AmdSmiException: 735 warnings.warn("Can't initialize amdsmi") 736 return None 737 try: 738 socket_handles = amdsmi.amdsmi_get_processor_handles() 739 dev_count = len(socket_handles) 740 except amdsmi.AmdSmiException: 741 warnings.warn("Can't get amdsmi device count") 742 return None 743 uuids: List[str] = [] 744 for idx in range(dev_count): 745 try: 746 handler = amdsmi.amdsmi_get_processor_handles()[idx] 747 except amdsmi.AmdSmiException: 748 warnings.warn("Cannot get amd device handler") 749 return None 750 try: 751 uuid = amdsmi.amdsmi_get_gpu_device_uuid(handler) 752 except amdsmi.AmdSmiException: 753 warnings.warn("Cannot get uuid for amd device") 754 return None 755 uuids.append(str(uuid)) 756 return uuids 757 758 759def _raw_device_uuid_nvml() -> Optional[List[str]]: 760 r"""Return list of device UUID as reported by NVML or None if NVM discovery/initialization failed.""" 761 from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer 762 763 nvml_h = CDLL("libnvidia-ml.so.1") 764 rc = nvml_h.nvmlInit() 765 if rc != 0: 766 warnings.warn("Can't initialize NVML") 767 return None 768 dev_count = c_int(-1) 769 rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count)) 770 if rc != 0: 771 warnings.warn("Can't get nvml device count") 772 return None 773 uuids: List[str] = [] 774 for idx in range(dev_count.value): 775 dev_id = c_void_p() 776 rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id)) 777 if rc != 0: 778 warnings.warn("Can't get device handle") 779 return None 780 buf_len = 96 781 buf = create_string_buffer(buf_len) 782 rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len) 783 if rc != 0: 784 warnings.warn("Can't get device UUID") 785 return None 786 uuids.append(buf.raw.decode("ascii").strip("\0")) 787 del nvml_h 788 return uuids 789 790 791def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List[int]: 792 r"""Given the set of partial uuids and list of known uuids builds a set of ordinals excluding ambiguous partials IDs.""" 793 794 def uuid_to_orinal(candidate: str, uuids: List[str]) -> int: 795 best_match = -1 796 for idx, uuid in enumerate(uuids): 797 if not uuid.startswith(candidate): 798 continue 799 # Ambiguous candidate 800 if best_match != -1: 801 return -1 802 best_match = idx 803 return best_match 804 805 rc: List[int] = [] 806 for candidate in candidates: 807 idx = uuid_to_orinal(candidate, uuids) 808 # First invalid ordinal stops parsing 809 if idx < 0: 810 break 811 # Duplicates result in empty set 812 if idx in rc: 813 return cast(List[int], []) 814 rc.append(idx) 815 return rc 816 817 818def _device_count_amdsmi() -> int: 819 visible_devices = _parse_visible_devices() 820 if not visible_devices: 821 return 0 822 try: 823 if type(visible_devices[0]) is str: 824 return -1 825 else: 826 raw_cnt = _raw_device_count_amdsmi() 827 if raw_cnt <= 0: 828 return raw_cnt 829 # Trim the list up to a maximum available device 830 for idx, val in enumerate(visible_devices): 831 if cast(int, val) >= raw_cnt: 832 return idx 833 except OSError: 834 return -1 835 except AttributeError: 836 return -1 837 return len(visible_devices) 838 839 840def _device_count_nvml() -> int: 841 r"""Return number of devices as reported by NVML taking CUDA_VISIBLE_DEVICES into account. 842 843 Negative value is returned if NVML discovery or initialization has failed. 844 """ 845 visible_devices = _parse_visible_devices() 846 if not visible_devices: 847 return 0 848 try: 849 if type(visible_devices[0]) is str: 850 # Skip MIG parsing 851 if visible_devices[0].startswith("MIG-"): 852 return -1 853 uuids = _raw_device_uuid_nvml() 854 if uuids is None: 855 return -1 856 visible_devices = _transform_uuid_to_ordinals( 857 cast(List[str], visible_devices), uuids 858 ) 859 else: 860 raw_cnt = _raw_device_count_nvml() 861 if raw_cnt <= 0: 862 return raw_cnt 863 # Trim the list up to a maximum available device 864 for idx, val in enumerate(visible_devices): 865 if cast(int, val) >= raw_cnt: 866 return idx 867 except OSError: 868 return -1 869 except AttributeError: 870 return -1 871 return len(visible_devices) 872 873 874def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int: 875 r"""Return the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account.""" 876 idx = _get_device_index(device, optional=True) 877 visible_devices = _parse_visible_devices() 878 if type(visible_devices[0]) is str: 879 uuids = _raw_device_uuid_nvml() 880 if uuids is None: 881 raise RuntimeError("Can't get device UUIDs") 882 visible_devices = _transform_uuid_to_ordinals( 883 cast(List[str], visible_devices), uuids 884 ) 885 visible_devices = cast(List[int], visible_devices) 886 if idx < 0 or idx >= len(visible_devices): 887 raise RuntimeError( 888 f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})" 889 ) 890 return visible_devices[idx] 891 892 893_cached_device_count: Optional[int] = None 894 895 896def device_count() -> int: 897 r"""Return the number of GPUs available.""" 898 global _cached_device_count 899 if not _is_compiled(): 900 return 0 901 if _cached_device_count is not None: 902 return _cached_device_count 903 # bypass _device_count_nvml() if rocm (not supported) 904 nvml_count = _device_count_amdsmi() if torch.version.hip else _device_count_nvml() 905 r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count 906 # NB: Do not cache the device count prior to CUDA initialization, because 907 # the number of devices can change due to changes to CUDA_VISIBLE_DEVICES 908 # setting prior to CUDA initialization. 909 if _initialized: 910 _cached_device_count = r 911 return r 912 913 914def get_arch_list() -> List[str]: 915 r"""Return list CUDA architectures this library was compiled for.""" 916 if not is_available(): 917 return [] 918 arch_flags = torch._C._cuda_getArchFlags() 919 if arch_flags is None: 920 return [] 921 return arch_flags.split() 922 923 924def get_gencode_flags() -> str: 925 r"""Return NVCC gencode flags this library was compiled with.""" 926 arch_list = get_arch_list() 927 if len(arch_list) == 0: 928 return "" 929 arch_list_ = [arch.split("_") for arch in arch_list] 930 return " ".join( 931 [ 932 f"-gencode compute=compute_{arch},code={kind}_{arch}" 933 for (kind, arch) in arch_list_ 934 ] 935 ) 936 937 938def current_device() -> int: 939 r"""Return the index of a currently selected device.""" 940 _lazy_init() 941 return torch._C._cuda_getDevice() 942 943 944def synchronize(device: _device_t = None) -> None: 945 r"""Wait for all kernels in all streams on a CUDA device to complete. 946 947 Args: 948 device (torch.device or int, optional): device for which to synchronize. 949 It uses the current device, given by :func:`~torch.cuda.current_device`, 950 if :attr:`device` is ``None`` (default). 951 """ 952 _lazy_init() 953 with torch.cuda.device(device): 954 return torch._C._cuda_synchronize() 955 956 957def ipc_collect(): 958 r"""Force collects GPU memory after it has been released by CUDA IPC. 959 960 .. note:: 961 Checks if any sent CUDA tensors could be cleaned from the memory. Force 962 closes shared memory file used for reference counting if there is no 963 active counters. Useful when the producer process stopped actively sending 964 tensors and want to release unused memory. 965 """ 966 _lazy_init() 967 return torch._C._cuda_ipc_collect() 968 969 970def current_stream(device: Optional[_device_t] = None) -> Stream: 971 r"""Return the currently selected :class:`Stream` for a given device. 972 973 Args: 974 device (torch.device or int, optional): selected device. Returns 975 the currently selected :class:`Stream` for the current device, given 976 by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` 977 (default). 978 """ 979 _lazy_init() 980 streamdata = torch._C._cuda_getCurrentStream( 981 _get_device_index(device, optional=True) 982 ) 983 return Stream( 984 stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2] 985 ) 986 987 988def default_stream(device: Optional[_device_t] = None) -> Stream: 989 r"""Return the default :class:`Stream` for a given device. 990 991 Args: 992 device (torch.device or int, optional): selected device. Returns 993 the default :class:`Stream` for the current device, given by 994 :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` 995 (default). 996 """ 997 _lazy_init() 998 streamdata = torch._C._cuda_getDefaultStream( 999 _get_device_index(device, optional=True) 1000 ) 1001 return Stream( 1002 stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2] 1003 ) 1004 1005 1006def current_blas_handle(): 1007 r"""Return cublasHandle_t pointer to current cuBLAS handle""" 1008 _lazy_init() 1009 return torch._C._cuda_getCurrentBlasHandle() 1010 1011 1012def set_sync_debug_mode(debug_mode: Union[int, str]) -> None: 1013 r"""Set the debug mode for cuda synchronizing operations. 1014 1015 Args: 1016 debug_mode(str or int): if "default" or 0, don't error or warn on synchronizing operations, 1017 if "warn" or 1, warn on synchronizing operations, if "error" or 2, error out synchronizing operations. 1018 1019 Warning: 1020 This is an experimental feature, and not all synchronizing operations will trigger warning or error. In 1021 particular, operations in torch.distributed and torch.sparse namespaces are not covered yet. 1022 """ 1023 _lazy_init() 1024 if isinstance(debug_mode, str): 1025 if debug_mode == "default": 1026 debug_mode = 0 1027 elif debug_mode == "warn": 1028 debug_mode = 1 1029 elif debug_mode == "error": 1030 debug_mode = 2 1031 else: 1032 raise RuntimeError( 1033 "invalid value of debug_mode, expected one of `default`, `warn`, `error`" 1034 ) 1035 1036 torch._C._cuda_set_sync_debug_mode(debug_mode) 1037 1038 1039def get_sync_debug_mode() -> int: 1040 r"""Return current value of debug mode for cuda synchronizing operations.""" 1041 _lazy_init() 1042 return torch._C._cuda_get_sync_debug_mode() 1043 1044 1045def _get_pynvml_handler(device: Optional[Union[Device, int]] = None): 1046 if not _HAS_PYNVML: 1047 raise ModuleNotFoundError( 1048 "pynvml does not seem to be installed or it can't be imported." 1049 ) from _PYNVML_ERR 1050 from pynvml import NVMLError_DriverNotLoaded 1051 1052 try: 1053 pynvml.nvmlInit() 1054 except NVMLError_DriverNotLoaded as e: 1055 raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e 1056 1057 device = _get_nvml_device_index(device) 1058 handle = pynvml.nvmlDeviceGetHandleByIndex(device) 1059 return handle 1060 1061 1062def _get_amdsmi_handler(device: Optional[Union[Device, int]] = None): 1063 if not _HAS_PYNVML: 1064 raise ModuleNotFoundError( 1065 "amdsmi does not seem to be installed or it can't be imported." 1066 ) from _PYNVML_ERR 1067 try: 1068 amdsmi.amdsmi_init() 1069 except amdsmi.AmdSmiException as e: 1070 raise RuntimeError( 1071 "amdsmi driver can't be loaded, requires >=ROCm5.6 installation" 1072 ) from e 1073 device = _get_amdsmi_device_index(device) 1074 handle = amdsmi.amdsmi_get_processor_handles()[device] 1075 return handle 1076 1077 1078def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int: 1079 r"""Return the amdsmi index of the device, taking visible_devices into account.""" 1080 idx = _get_device_index(device, optional=True) 1081 visible_devices = _parse_visible_devices() 1082 if type(visible_devices[0]) is str: 1083 raise RuntimeError("HIP_VISIBLE_DEVICES should be indices and not strings") 1084 idx_map = dict(enumerate(cast(List[int], visible_devices))) 1085 if idx not in idx_map: 1086 raise RuntimeError( 1087 f"device {idx} is not visible (HIP_VISIBLE_DEVICES={visible_devices})" 1088 ) 1089 return idx_map[idx] 1090 1091 1092def _get_amdsmi_memory_usage(device: Optional[Union[Device, int]] = None) -> int: 1093 handle = _get_amdsmi_handler() 1094 device = _get_amdsmi_device_index(device) 1095 return amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"] 1096 1097 1098def _get_amdsmi_utilization(device: Optional[Union[Device, int]] = None) -> int: 1099 handle = _get_amdsmi_handler() 1100 device = _get_amdsmi_device_index(device) 1101 handle = amdsmi.amdsmi_get_processor_handles()[device] 1102 return amdsmi.amdsmi_get_gpu_activity(handle)["gfx_activity"] 1103 1104 1105def _get_amdsmi_temperature(device: Optional[Union[Device, int]] = None) -> int: 1106 handle = _get_amdsmi_handler(device) 1107 return amdsmi.amdsmi_get_temp_metric( 1108 handle, 1109 amdsmi.AmdSmiTemperatureType.JUNCTION, 1110 amdsmi.AmdSmiTemperatureMetric.CURRENT, 1111 ) 1112 1113 1114def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int: 1115 handle = _get_amdsmi_handler(device) 1116 socket_power = amdsmi.amdsmi_get_power_info(handle)["average_socket_power"] 1117 if socket_power != "N/A": 1118 return socket_power 1119 else: 1120 return amdsmi.amdsmi_get_power_info(handle)["current_socket_power"] 1121 1122 1123def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int: 1124 handle = _get_amdsmi_handler(device) 1125 clock_info = amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX) 1126 if "cur_clk" in clock_info: # ROCm 6.2 deprecation 1127 return clock_info["cur_clk"] 1128 else: 1129 return clock_info["clk"] 1130 1131 1132def memory_usage(device: Optional[Union[Device, int]] = None) -> int: 1133 r"""Return the percent of time over the past sample period during which global (device) 1134 memory was being read or written as given by `nvidia-smi`. 1135 1136 Args: 1137 device (torch.device or int, optional): selected device. Returns 1138 statistic for the current device, given by :func:`~torch.cuda.current_device`, 1139 if :attr:`device` is ``None`` (default). 1140 1141 Warning: Each sample period may be between 1 second and 1/6 second, 1142 depending on the product being queried. 1143 """ 1144 if not torch.version.hip: 1145 handle = _get_pynvml_handler() 1146 device = _get_nvml_device_index(device) 1147 handle = pynvml.nvmlDeviceGetHandleByIndex(device) 1148 return pynvml.nvmlDeviceGetUtilizationRates(handle).memory 1149 else: 1150 return _get_amdsmi_memory_usage(device) 1151 1152 1153def utilization(device: Optional[Union[Device, int]] = None) -> int: 1154 r"""Return the percent of time over the past sample period during which one or 1155 more kernels was executing on the GPU as given by `nvidia-smi`. 1156 1157 Args: 1158 device (torch.device or int, optional): selected device. Returns 1159 statistic for the current device, given by :func:`~torch.cuda.current_device`, 1160 if :attr:`device` is ``None`` (default). 1161 1162 Warning: Each sample period may be between 1 second and 1/6 second, 1163 depending on the product being queried. 1164 """ 1165 if not torch.version.hip: 1166 handle = _get_pynvml_handler(device) 1167 device = _get_nvml_device_index(device) 1168 handle = pynvml.nvmlDeviceGetHandleByIndex(device) 1169 return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu 1170 else: 1171 return _get_amdsmi_utilization(device) 1172 1173 1174def temperature(device: Optional[Union[Device, int]] = None) -> int: 1175 r"""Return the average temperature of the GPU sensor in Degrees C (Centigrades). 1176 1177 The average temperature is computed based on past sample period as given by `nvidia-smi`. 1178 1179 Args: 1180 device (torch.device or int, optional): selected device. Returns 1181 statistic for the current device, given by :func:`~torch.cuda.current_device`, 1182 if :attr:`device` is ``None`` (default). 1183 1184 Warning: Each sample period may be between 1 second and 1/6 second, 1185 depending on the product being queried. 1186 """ 1187 if not torch.version.hip: 1188 handle = _get_pynvml_handler(device) 1189 # 0 refers to the temperature sensor for the GPU die. 1190 return pynvml.nvmlDeviceGetTemperature(handle, 0) 1191 else: 1192 return _get_amdsmi_temperature(device) 1193 1194 1195def power_draw(device: Optional[Union[Device, int]] = None) -> int: 1196 r"""Return the average power draw of the GPU sensor in mW (MilliWatts) 1197 over the past sample period as given by `nvidia-smi` for Fermi or newer fully supported devices. 1198 1199 Args: 1200 device (torch.device or int, optional): selected device. Returns 1201 statistic for the current device, given by :func:`~torch.cuda.current_device`, 1202 if :attr:`device` is ``None`` (default). 1203 1204 Warning: Each sample period may be between 1 second and 1/6 second, 1205 depending on the product being queried. 1206 """ 1207 if not torch.version.hip: 1208 handle = _get_pynvml_handler(device) 1209 return pynvml.nvmlDeviceGetPowerUsage(handle) 1210 else: 1211 return _get_amdsmi_power_draw(device) 1212 1213 1214def clock_rate(device: Optional[Union[Device, int]] = None) -> int: 1215 r"""Return the clock speed of the GPU SM in Hz Hertz over the past sample period as given by `nvidia-smi`. 1216 1217 Args: 1218 device (torch.device or int, optional): selected device. Returns 1219 statistic for the current device, given by :func:`~torch.cuda.current_device`, 1220 if :attr:`device` is ``None`` (default). 1221 1222 Warning: Each sample period may be between 1 second and 1/6 second, 1223 depending on the product being queried. 1224 """ 1225 if not torch.version.hip: 1226 handle = _get_pynvml_handler(device) 1227 return pynvml.nvmlDeviceGetClockInfo(handle, 1) 1228 else: 1229 return _get_amdsmi_clock_rate(device) 1230 1231 1232def _get_device(device: Union[int, str, torch.device]) -> torch.device: 1233 r"""Return the torch.device type object from the passed in device. 1234 1235 Args: 1236 device (torch.device or int): selected device. 1237 """ 1238 if isinstance(device, str): 1239 device = torch.device(device) 1240 elif isinstance(device, int): 1241 device = torch.device("cuda", device) 1242 return device 1243 1244 1245def _get_generator(device: torch.device) -> torch._C.Generator: 1246 r"""Return the CUDA Generator object for the given device. 1247 1248 Args: 1249 device (torch.device): selected device. 1250 """ 1251 idx = device.index 1252 if idx is None: 1253 idx = current_device() 1254 return torch.cuda.default_generators[idx] 1255 1256 1257def _set_rng_state_offset( 1258 offset: int, device: Union[int, str, torch.device] = "cuda" 1259) -> None: 1260 r"""Set the random number generator state offset of the specified GPU. 1261 1262 Args: 1263 offset (int): The desired offset 1264 device (torch.device or int, optional): The device to set the RNG state. 1265 Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device). 1266 """ 1267 final_device = _get_device(device) 1268 1269 def cb(): 1270 default_generator = _get_generator(final_device) 1271 default_generator.set_offset(offset) 1272 1273 _lazy_call(cb) 1274 1275 1276def _get_rng_state_offset(device: Union[int, str, torch.device] = "cuda") -> int: 1277 r"""Return the random number generator state offset of the specified GPU. 1278 1279 Args: 1280 device (torch.device or int, optional): The device to return the RNG state offset of. 1281 Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device). 1282 1283 .. warning:: 1284 This function eagerly initializes CUDA. 1285 """ 1286 _lazy_init() 1287 final_device = _get_device(device) 1288 default_generator = _get_generator(final_device) 1289 return default_generator.get_offset() 1290 1291 1292from .memory import * # noqa: F403 1293from .random import * # noqa: F403 1294 1295 1296################################################################################ 1297# Define Storage and Tensor classes 1298################################################################################ 1299 1300 1301@staticmethod # type: ignore[misc] 1302def _lazy_new(cls, *args, **kwargs): 1303 _lazy_init() 1304 # We may need to call lazy init again if we are a forked child 1305 # del _CudaBase.__new__ 1306 return super(_CudaBase, cls).__new__(cls, *args, **kwargs) 1307 1308 1309class _CudaBase: 1310 is_cuda = True 1311 is_sparse = False 1312 1313 def type(self, *args, **kwargs): 1314 # We could use a Protocol here to tell mypy that self has `get_device` method 1315 # but it is only available in the typing module on Python >= 3.8 1316 # or on typing_extensions module on Python >= 3.6 1317 with device(self.get_device()): # type: ignore[attr-defined] 1318 return super().type(*args, **kwargs) # type: ignore[misc] 1319 1320 __new__ = _lazy_new 1321 1322 1323from torch.storage import _LegacyStorage, _warn_typed_storage_removal 1324 1325 1326class _CudaLegacyStorage(_LegacyStorage): 1327 @classmethod 1328 def from_buffer(cls, *args, **kwargs): 1329 _warn_typed_storage_removal() 1330 raise RuntimeError("from_buffer: Not available for CUDA storage") 1331 1332 @classmethod 1333 def _new_with_weak_ptr(cls, *args, **kwargs): 1334 raise RuntimeError("_new_with_weak_ptr: Not available for CUDA storage") 1335 1336 @classmethod 1337 def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None): 1338 raise RuntimeError("_new_shared_filename: Not available for CUDA storage") 1339 1340 1341class ByteStorage(_CudaLegacyStorage): 1342 @classproperty 1343 def dtype(self): 1344 _warn_typed_storage_removal() 1345 return self._dtype 1346 1347 @classproperty 1348 def _dtype(self): 1349 return torch.uint8 1350 1351 1352class DoubleStorage(_CudaLegacyStorage): 1353 @classproperty 1354 def dtype(self): 1355 _warn_typed_storage_removal() 1356 return self._dtype 1357 1358 @classproperty 1359 def _dtype(self): 1360 return torch.double 1361 1362 1363class FloatStorage(_CudaLegacyStorage): 1364 @classproperty 1365 def dtype(self): 1366 _warn_typed_storage_removal() 1367 return self._dtype 1368 1369 @classproperty 1370 def _dtype(self): 1371 return torch.float 1372 1373 1374class HalfStorage(_CudaLegacyStorage): 1375 @classproperty 1376 def dtype(self): 1377 _warn_typed_storage_removal() 1378 return self._dtype 1379 1380 @classproperty 1381 def _dtype(self): 1382 return torch.half 1383 1384 1385class LongStorage(_CudaLegacyStorage): 1386 @classproperty 1387 def dtype(self): 1388 _warn_typed_storage_removal() 1389 return self._dtype 1390 1391 @classproperty 1392 def _dtype(self): 1393 return torch.long 1394 1395 1396class IntStorage(_CudaLegacyStorage): 1397 @classproperty 1398 def dtype(self): 1399 _warn_typed_storage_removal() 1400 return self._dtype 1401 1402 @classproperty 1403 def _dtype(self): 1404 return torch.int 1405 1406 1407class ShortStorage(_CudaLegacyStorage): 1408 @classproperty 1409 def dtype(self): 1410 _warn_typed_storage_removal() 1411 return self._dtype 1412 1413 @classproperty 1414 def _dtype(self): 1415 return torch.short 1416 1417 1418class CharStorage(_CudaLegacyStorage): 1419 @classproperty 1420 def dtype(self): 1421 _warn_typed_storage_removal() 1422 return self._dtype 1423 1424 @classproperty 1425 def _dtype(self): 1426 return torch.int8 1427 1428 1429class BoolStorage(_CudaLegacyStorage): 1430 @classproperty 1431 def dtype(self): 1432 _warn_typed_storage_removal() 1433 return self._dtype 1434 1435 @classproperty 1436 def _dtype(self): 1437 return torch.bool 1438 1439 1440class BFloat16Storage(_CudaLegacyStorage): 1441 @classproperty 1442 def dtype(self): 1443 _warn_typed_storage_removal() 1444 return self._dtype 1445 1446 @classproperty 1447 def _dtype(self): 1448 return torch.bfloat16 1449 1450 1451class ComplexDoubleStorage(_CudaLegacyStorage): 1452 @classproperty 1453 def dtype(self): 1454 _warn_typed_storage_removal() 1455 return self._dtype 1456 1457 @classproperty 1458 def _dtype(self): 1459 return torch.cdouble 1460 1461 1462class ComplexFloatStorage(_CudaLegacyStorage): 1463 @classproperty 1464 def dtype(self): 1465 _warn_typed_storage_removal() 1466 return self._dtype 1467 1468 @classproperty 1469 def _dtype(self): 1470 return torch.cfloat 1471 1472 1473del _LegacyStorage 1474del _CudaLegacyStorage 1475 1476torch._storage_classes.add(DoubleStorage) 1477torch._storage_classes.add(FloatStorage) 1478torch._storage_classes.add(LongStorage) 1479torch._storage_classes.add(IntStorage) 1480torch._storage_classes.add(ShortStorage) 1481torch._storage_classes.add(CharStorage) 1482torch._storage_classes.add(ByteStorage) 1483torch._storage_classes.add(HalfStorage) 1484torch._storage_classes.add(BoolStorage) 1485torch._storage_classes.add(BFloat16Storage) 1486torch._storage_classes.add(ComplexDoubleStorage) 1487torch._storage_classes.add(ComplexFloatStorage) 1488 1489 1490class _WrappedTritonKernel: 1491 """Just a simple wrapper to store some metadata for testing purposes.""" 1492 1493 def __init__(self, kernel): 1494 self.kernel = kernel 1495 self.kernel_invoked = False 1496 1497 def __call__(self, *args, **kwargs): 1498 res = self.kernel(*args, **kwargs) 1499 self.kernel_invoked = True 1500 return res 1501 1502 1503def _register_triton_kernels(): 1504 if torch._running_with_deploy(): 1505 return 1506 1507 @_WrappedTritonKernel 1508 def kernel_impl(*args, **kwargs): 1509 from torch.sparse._triton_ops import bsr_dense_mm 1510 1511 return bsr_dense_mm(*args, skip_checks=True, **kwargs) 1512 1513 @_WrappedTritonKernel 1514 def addmm_kernel_impl(*args, **kwargs): 1515 from torch.sparse._triton_ops import bsr_dense_addmm 1516 1517 return bsr_dense_addmm(*args, skip_checks=True, **kwargs) 1518 1519 has_triton = importlib.util.find_spec("triton") is not None 1520 if has_triton: 1521 torch._TritonLibrary.registerOp( 1522 "_triton_bsr_dense_mm_out", 1523 "_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)", 1524 kernel_impl, 1525 "SparseCsrCUDA", 1526 ) 1527 1528 torch._TritonLibrary.registerOp( 1529 "_triton_bsr_dense_addmm_out", 1530 ( 1531 "_triton_bsr_dense_addmm_out(Tensor input, Tensor bsr, Tensor dense," 1532 " *, Scalar beta, Scalar alpha, Tensor(a!) out) -> Tensor(a!)" 1533 ), 1534 addmm_kernel_impl, 1535 "SparseCsrCUDA", 1536 ) 1537 1538 1539_lazy_call(_register_triton_kernels) 1540 1541 1542from . import amp, jiterator, nvtx, profiler, sparse, tunable 1543 1544 1545__all__ = [ 1546 # Typed storage and tensors 1547 "BFloat16Storage", 1548 "BFloat16Tensor", 1549 "BoolStorage", 1550 "BoolTensor", 1551 "ByteStorage", 1552 "ByteTensor", 1553 "CharStorage", 1554 "CharTensor", 1555 "ComplexDoubleStorage", 1556 "ComplexFloatStorage", 1557 "DoubleStorage", 1558 "DoubleTensor", 1559 "FloatStorage", 1560 "FloatTensor", 1561 "HalfStorage", 1562 "HalfTensor", 1563 "IntStorage", 1564 "IntTensor", 1565 "LongStorage", 1566 "LongTensor", 1567 "ShortStorage", 1568 "ShortTensor", 1569 "CUDAGraph", 1570 "CudaError", 1571 "DeferredCudaCallError", 1572 "Event", 1573 "ExternalStream", 1574 "Stream", 1575 "StreamContext", 1576 "amp", 1577 "caching_allocator_alloc", 1578 "caching_allocator_delete", 1579 "can_device_access_peer", 1580 "check_error", 1581 "cudaStatus", 1582 "cudart", 1583 "current_blas_handle", 1584 "current_device", 1585 "current_stream", 1586 "default_generators", 1587 "default_stream", 1588 "device", 1589 "device_count", 1590 "device_of", 1591 "empty_cache", 1592 "get_allocator_backend", 1593 "CUDAPluggableAllocator", 1594 "change_current_allocator", 1595 "get_arch_list", 1596 "get_device_capability", 1597 "get_device_name", 1598 "get_device_properties", 1599 "get_gencode_flags", 1600 "get_rng_state", 1601 "get_rng_state_all", 1602 "get_sync_debug_mode", 1603 "graph", 1604 "graph_pool_handle", 1605 "graphs", 1606 "has_half", 1607 "has_magma", 1608 "init", 1609 "initial_seed", 1610 "ipc_collect", 1611 "is_available", 1612 "is_bf16_supported", 1613 "is_current_stream_capturing", 1614 "is_initialized", 1615 "jiterator", 1616 "list_gpu_processes", 1617 "make_graphed_callables", 1618 "manual_seed", 1619 "manual_seed_all", 1620 "max_memory_allocated", 1621 "max_memory_cached", 1622 "max_memory_reserved", 1623 "mem_get_info", 1624 "memory", 1625 "memory_allocated", 1626 "memory_cached", 1627 "memory_reserved", 1628 "memory_snapshot", 1629 "memory_stats", 1630 "memory_stats_as_nested_dict", 1631 "memory_summary", 1632 "memory_usage", 1633 "MemPool", 1634 "MemPoolContext", 1635 "use_mem_pool", 1636 "temperature", 1637 "power_draw", 1638 "clock_rate", 1639 "nccl", 1640 "nvtx", 1641 "profiler", 1642 "random", 1643 "reset_accumulated_memory_stats", 1644 "reset_max_memory_allocated", 1645 "reset_max_memory_cached", 1646 "reset_peak_memory_stats", 1647 "seed", 1648 "seed_all", 1649 "set_device", 1650 "set_per_process_memory_fraction", 1651 "set_rng_state", 1652 "set_rng_state_all", 1653 "set_stream", 1654 "set_sync_debug_mode", 1655 "sparse", 1656 "stream", 1657 "streams", 1658 "synchronize", 1659 "tunable", 1660 "utilization", 1661] 1662