# mypy: allow-untyped-defs
r"""This package adds support for device memory management implemented in CUDA."""

import collections
import contextlib
import ctypes
import pickle
import sys
import warnings
from inspect import signature
from typing import Any, Dict, Optional, Tuple, Union
from typing_extensions import deprecated

import torch
from torch import _C
from torch._utils import _dummy_type
from torch.types import Device

from . import (
    _get_amdsmi_device_index,
    _get_device_index,
    _get_nvml_device_index,
    _lazy_init,
    is_initialized,
)
from ._memory_viz import memory as _memory, segments as _segments


__all__ = [
    "caching_allocator_alloc",
    "caching_allocator_delete",
    "set_per_process_memory_fraction",
    "empty_cache",
    "memory_stats",
    "memory_stats_as_nested_dict",
    "reset_accumulated_memory_stats",
    "reset_peak_memory_stats",
    "reset_max_memory_allocated",
    "reset_max_memory_cached",
    "memory_allocated",
    "max_memory_allocated",
    "memory_reserved",
    "max_memory_reserved",
    "memory_cached",
    "max_memory_cached",
    "memory_snapshot",
    "memory_summary",
    "list_gpu_processes",
    "mem_get_info",
    "get_allocator_backend",
    "CUDAPluggableAllocator",
    "change_current_allocator",
    "MemPool",
    "MemPoolContext",
    "use_mem_pool",
]


if not hasattr(torch._C, "_cuda_CUDAAllocator"):
    # Define dummy base classes
    torch._C.__dict__["_cuda_CUDAAllocator"] = _dummy_type("_cuda_CUDAAllocator")


if not hasattr(torch._C, "_MemPool"):
    # Define dummy base classes
    torch._C.__dict__["_MemPool"] = _dummy_type("_MemPool")
    torch._C.__dict__["_MemPoolContext"] = _dummy_type("_MemPoolContext")
    torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type(
        "_cuda_beginAllocateToPool"
    )
    torch._C.__dict__["_cuda_endAllocateCurrentStreamToPool"] = _dummy_type(
        "_cuda_endAllocateCurrentStreamToPool"
    )

from torch._C import (  # noqa: F401
    _cuda_beginAllocateToPool,
    _cuda_CUDAAllocator,
    _cuda_endAllocateCurrentStreamToPool,
    _MemPool,
    _MemPoolContext,
)


def _host_allocator():
    _lazy_init()
    return torch._C._cuda_cudaHostAllocator()


@contextlib.contextmanager
def _free_mutex():
    torch._C._cuda_lock_mutex()
    try:
        yield
    finally:
        torch._C._cuda_unlock_mutex()


def caching_allocator_alloc(size, device: Union[Device, int] = None, stream=None):
    r"""Perform a memory allocation using the CUDA memory allocator.

    Memory is allocated for a given device and a stream; this
    function is intended to be used for interoperability with other
    frameworks. Allocated memory is released through
    :func:`~torch.cuda.caching_allocator_delete`.

    Args:
        size (int): number of bytes to be allocated.
        device (torch.device or int, optional): selected device. If it is
            ``None`` the default CUDA device is used.
        stream (torch.cuda.Stream or int, optional): selected stream. If it is
            ``None`` then the default stream for the selected device is used.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
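
    Example (an illustrative sketch; assumes a CUDA device is available and that
    the raw pointer would normally be handed to another framework before being
    freed)::

        size = 1024 * 1024  # 1 MiB
        ptr = torch.cuda.caching_allocator_alloc(size)
        # ... share ``ptr`` with an external library ...
        torch.cuda.caching_allocator_delete(ptr)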
    """
    if device is None:
        device = torch.cuda.current_device()
    device = _get_device_index(device)
    if stream is None:
        stream = torch.cuda.current_stream(device)
    if isinstance(stream, torch.cuda.streams.Stream):
        stream = stream.cuda_stream
    if not isinstance(stream, int):
        raise TypeError(
            "Invalid type for stream argument, must be "
            "`torch.cuda.Stream` or `int` representing a pointer "
            "to an existing stream"
        )
    with torch.cuda.device(device):
        return torch._C._cuda_cudaCachingAllocator_raw_alloc(size, stream)


def caching_allocator_delete(mem_ptr):
    r"""Delete memory allocated using the CUDA memory allocator.

    Memory allocated with :func:`~torch.cuda.caching_allocator_alloc`
    is freed here. The associated device and stream are tracked inside
    the allocator.

    Args:
        mem_ptr (int): memory address to be freed by the allocator.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    torch._C._cuda_cudaCachingAllocator_raw_delete(mem_ptr)


def set_per_process_memory_fraction(
    fraction, device: Union[Device, int] = None
) -> None:
    r"""Set memory fraction for a process.

    The fraction is used to limit the caching allocator to an allowed amount of
    memory on a CUDA device. The allowed value equals the total visible memory
    multiplied by the fraction. If a process tries to allocate more than the
    allowed value, an out-of-memory error is raised by the allocator.

    Args:
        fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction.
        device (torch.device or int, optional): selected device. If it is
            ``None`` the default CUDA device is used.

    .. note::
        In general, the total available free memory is less than the total capacity.
    """
    _lazy_init()
    if device is None:
        device = torch.cuda.current_device()
    device = _get_device_index(device)
    if not isinstance(fraction, float):
        raise TypeError("Invalid type for fraction argument, must be `float`")
    if fraction < 0 or fraction > 1:
        raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~1")

    torch._C._cuda_setMemoryFraction(fraction, device)


def empty_cache() -> None:
    r"""Release all unoccupied cached memory currently held by the caching
    allocator so that it can be used by other GPU applications and is visible
    in `nvidia-smi`.

    .. note::
        :func:`~torch.cuda.empty_cache` doesn't increase the amount of GPU
        memory available for PyTorch. However, it may help reduce fragmentation
        of GPU memory in certain cases. See :ref:`cuda-memory-management` for
        more details about GPU memory management.
    """
    if is_initialized():
        torch._C._cuda_emptyCache()


def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]:
    r"""Return a dictionary of CUDA memory allocator statistics for a given device.

    The return value of this function is a dictionary of statistics, each of
    which is a non-negative integer.

    Core statistics:

    - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      number of allocation requests received by the memory allocator.
    - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of allocated memory.
    - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      number of reserved segments from ``cudaMalloc()``.
    - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of reserved memory.
    - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      number of active memory blocks.
    - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of active memory.
    - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      number of inactive, non-releasable memory blocks.
    - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of inactive, non-releasable memory.

    For these core statistics, values are broken down as follows.

    Pool type:

    - ``all``: combined statistics across all memory pools.
    - ``large_pool``: statistics for the large allocation pool
      (as of October 2019, for size >= 1MB allocations).
    - ``small_pool``: statistics for the small allocation pool
      (as of October 2019, for size < 1MB allocations).

    Metric type:

    - ``current``: current value of this metric.
    - ``peak``: maximum value of this metric.
    - ``allocated``: historical total increase in this metric.
    - ``freed``: historical total decrease in this metric.

    In addition to the core statistics, we also provide some simple event
    counters:

    - ``"num_alloc_retries"``: number of failed ``cudaMalloc`` calls that
      result in a cache flush and retry.
    - ``"num_ooms"``: number of out-of-memory errors thrown.
    - ``"num_sync_all_streams"``: number of ``synchronize_and_free_events`` calls.
    - ``"num_device_alloc"``: number of CUDA allocation calls. This includes both
      cuMemMap and cudaMalloc.
    - ``"num_device_free"``: number of CUDA free calls. This includes both cuMemUnmap
      and cudaFree.

    The caching allocator can be configured via ENV to not split blocks larger than a
    defined size (see Memory Management section of the Cuda Semantics documentation).
    This helps avoid memory fragmentation but may have a performance
    penalty. Additional outputs to assist with tuning and evaluating impact:

    - ``"max_split_size"``: blocks above this size will not be split.
    - ``"oversize_allocations.{current,peak,allocated,freed}"``:
      number of over-size allocation requests received by the memory allocator.
    - ``"oversize_segments.{current,peak,allocated,freed}"``:
      number of over-size reserved segments from ``cudaMalloc()``.

    The caching allocator can be configured via ENV to round memory allocations in order
    to reduce fragmentation. Sometimes the overhead from rounding can be higher than
    the fragmentation it helps reduce. The following stat can be used to check if
    rounding adds too much overhead:

    - ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      memory requested by client code; compare this with allocated_bytes to check if
      allocation rounding adds too much overhead.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistics for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.

    .. note::
        With :ref:`backend:cudaMallocAsync<cuda-memory-envvars>`, some stats are not
        meaningful, and are always reported as zero.
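
    As a small illustration (a sketch, not an exhaustive listing), individual
    counters can be read directly from the returned dictionary::

        stats = torch.cuda.memory_stats()
        allocated = stats["allocated_bytes.all.current"]
        reserved = stats["reserved_bytes.all.current"]
        retries = stats.get("num_alloc_retries", 0)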
    """
    result = []

    def _recurse_add_to_result(prefix, obj):
        if isinstance(obj, dict):
            if len(prefix) > 0:
                prefix += "."
            for k, v in obj.items():
                _recurse_add_to_result(prefix + k, v)
        else:
            result.append((prefix, obj))

    stats = memory_stats_as_nested_dict(device=device)
    _recurse_add_to_result("", stats)
    result.sort()

    return collections.OrderedDict(result)


def memory_stats_as_nested_dict(device: Union[Device, int] = None) -> Dict[str, Any]:
    r"""Return the result of :func:`~torch.cuda.memory_stats` as a nested dictionary."""
    if not is_initialized():
        return {}
    device = _get_device_index(device, optional=True)
    return torch._C._cuda_memoryStats(device)


def reset_accumulated_memory_stats(device: Union[Device, int] = None) -> None:
    r"""Reset the "accumulated" (historical) stats tracked by the CUDA memory allocator.

    See :func:`~torch.cuda.memory_stats` for details. Accumulated stats correspond to
    the `"allocated"` and `"freed"` keys in each individual stat dict, as well as
    `"num_alloc_retries"` and `"num_ooms"`.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    device = _get_device_index(device, optional=True)
    return torch._C._cuda_resetAccumulatedMemoryStats(device)


def reset_peak_memory_stats(device: Union[Device, int] = None) -> None:
    r"""Reset the "peak" stats tracked by the CUDA memory allocator.

    See :func:`~torch.cuda.memory_stats` for details. Peak stats correspond to the
    `"peak"` key in each individual stat dict.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    device = _get_device_index(device, optional=True)
    return torch._C._cuda_resetPeakMemoryStats(device)


def reset_max_memory_allocated(device: Union[Device, int] = None) -> None:
    r"""Reset the starting point in tracking maximum GPU memory occupied by tensors for a given device.

    See :func:`~torch.cuda.max_memory_allocated` for details.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. warning::
        This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
        /all/ peak memory stats.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    warnings.warn(
        "torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, "
        "which resets /all/ peak memory stats.",
        FutureWarning,
    )
    return reset_peak_memory_stats(device=device)


def reset_max_memory_cached(device: Union[Device, int] = None) -> None:
    r"""Reset the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.

    See :func:`~torch.cuda.max_memory_cached` for details.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. warning::
        This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
        /all/ peak memory stats.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    warnings.warn(
        "torch.cuda.reset_max_memory_cached now calls torch.cuda.reset_peak_memory_stats, "
        "which resets /all/ peak memory stats.",
        FutureWarning,
    )
    return reset_peak_memory_stats(device=device)


def memory_allocated(device: Union[Device, int] = None) -> int:
    r"""Return the current GPU memory occupied by tensors in bytes for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        This is likely less than the amount shown in `nvidia-smi` since some
        unused memory can be held by the caching allocator and some context
        needs to be created on GPU. See :ref:`cuda-memory-management` for more
        details about GPU memory management.
    """
    return memory_stats(device=device).get("allocated_bytes.all.current", 0)


def max_memory_allocated(device: Union[Device, int] = None) -> int:
    r"""Return the maximum GPU memory occupied by tensors in bytes for a given device.

    By default, this returns the peak allocated memory since the beginning of
    this program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to
    reset the starting point in tracking this metric. For example, these two
    functions can measure the peak allocated memory usage of each iteration in a
    training loop.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    return memory_stats(device=device).get("allocated_bytes.all.peak", 0)


def memory_reserved(device: Union[Device, int] = None) -> int:
    r"""Return the current GPU memory managed by the caching allocator in bytes for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    return memory_stats(device=device).get("reserved_bytes.all.current", 0)


def max_memory_reserved(device: Union[Device, int] = None) -> int:
    r"""Return the maximum GPU memory managed by the caching allocator in bytes for a given device.

    By default, this returns the peak cached memory since the beginning of this
    program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to reset
    the starting point in tracking this metric. For example, these two functions
    can measure the peak cached memory amount of each iteration in a training
    loop.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
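
    A rough sketch of the per-iteration measurement described above (``loader``
    and ``train_step`` are hypothetical stand-ins for user code)::

        for batch in loader:
            torch.cuda.reset_peak_memory_stats()
            train_step(batch)
            peak_allocated = torch.cuda.max_memory_allocated()
            peak_reserved = torch.cuda.max_memory_reserved()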
    """
    return memory_stats(device=device).get("reserved_bytes.all.peak", 0)


@deprecated(
    "`torch.cuda.memory_cached` has been renamed to `torch.cuda.memory_reserved`",
    category=FutureWarning,
)
def memory_cached(device: Union[Device, int] = None) -> int:
    r"""Deprecated; see :func:`~torch.cuda.memory_reserved`."""
    return memory_reserved(device=device)


@deprecated(
    "`torch.cuda.max_memory_cached` has been renamed to `torch.cuda.max_memory_reserved`",
    category=FutureWarning,
)
def max_memory_cached(device: Union[Device, int] = None) -> int:
    r"""Deprecated; see :func:`~torch.cuda.max_memory_reserved`."""
    return max_memory_reserved(device=device)


def memory_snapshot():
    r"""Return a snapshot of the CUDA memory allocator state across all devices.

    Interpreting the output of this function requires familiarity with the
    memory allocator internals.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    return torch._C._cuda_memorySnapshot()["segments"]


def memory_summary(device: Union[Device, int] = None, abbreviated: bool = False) -> str:
    r"""Return a human-readable printout of the current memory allocator statistics for a given device.

    This can be useful to display periodically during training, or when
    handling out-of-memory exceptions.

    Args:
        device (torch.device or int, optional): selected device. Returns
            printout for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).
        abbreviated (bool, optional): whether to return an abbreviated summary
            (default: False).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
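
    A minimal usage sketch (assumes CUDA is available; the summary is an ordinary
    string, so it can be printed or written to a log)::

        print(torch.cuda.memory_summary(device=0, abbreviated=True))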
    """
    device = _get_device_index(device, optional=True)
    stats = memory_stats(device=device)

    def _format_size(sz, pref_sz):
        prefixes = ["B  ", "KiB", "MiB", "GiB", "TiB", "PiB"]
        prefix = prefixes[0]
        for new_prefix in prefixes[1:]:
            if pref_sz < 768 * 1024:
                break
            prefix = new_prefix
            sz //= 1024
            pref_sz /= 1024
        return f"{sz:6d} {prefix}"

    def _format_count(cnt, pref_cnt):
        prefixes = [" ", "K", "M"]
        prefix = prefixes[0]
        for new_prefix in prefixes[1:]:
            if pref_cnt < 750 * 1000:
                break
            prefix = new_prefix
            cnt //= 1000
            pref_cnt /= 1000
        return f"{cnt:7d} {prefix} "

    metrics_to_display = [
        ("allocated_bytes", "Allocated memory", _format_size),
        ("active_bytes", "Active memory", _format_size),
        ("requested_bytes", "Requested memory", _format_size),
        ("reserved_bytes", "GPU reserved memory", _format_size),
        ("inactive_split_bytes", "Non-releasable memory", _format_size),
        ("allocation", "Allocations", _format_count),
        ("active", "Active allocs", _format_count),
        ("segment", "GPU reserved segments", _format_count),
        ("inactive_split", "Non-releasable allocs", _format_count),
    ]

    lines = []
    lines.append("=" * 75)
    lines.append(" {_:16} PyTorch CUDA memory summary, device ID {device:<17d} ")
    lines.append("-" * 75)
    lines.append(
        " {_:9} CUDA OOMs: {num_ooms:<12d} | {_:6} cudaMalloc retries: {num_alloc_retries:<8d} "
    )
    lines.append("=" * 75)
    lines.append(
        "        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  "
    )

    for metric_key, metric_name, formatter in metrics_to_display:
        lines.append("-" * 75)
        submetrics = [("all", metric_name)]
        if not abbreviated:
            submetrics.append(("large_pool", "      from large pool"))
            submetrics.append(("small_pool", "      from small pool"))

        current_prefval, peak_prefval, allocated_prefval, freed_prefval = (
            None,
            None,
            None,
            None,
        )

        for submetric_key, submetric_name in submetrics:
            prefix = metric_key + "." + submetric_key + "."

            current = stats[prefix + "current"]
            peak = stats[prefix + "peak"]
            allocated = stats[prefix + "allocated"]
            freed = stats[prefix + "freed"]

            if current_prefval is None:
                current_prefval = current
                peak_prefval = peak
                allocated_prefval = allocated
                freed_prefval = freed

            lines.append(
                f" {submetric_name:<21} | {formatter(current, current_prefval)} | {formatter(peak, peak_prefval)} | "
                f"{formatter(allocated, allocated_prefval)} | {formatter(freed, freed_prefval)} ",
            )

    metrics_to_display = [
        ("oversize_allocations", "Oversize allocations", _format_count),
        ("oversize_segments", "Oversize GPU segments", _format_count),
    ]

    for metric_key, metric_name, formatter in metrics_to_display:
        lines.append("-" * 75)

        prefix = metric_key + "."
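
        # These counters are flat (no per-pool breakdown), so each value also
        # serves as its own reference magnitude in the formatter calls below.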
        current = stats[prefix + "current"]
        peak = stats[prefix + "peak"]
        allocated = stats[prefix + "allocated"]
        freed = stats[prefix + "freed"]

        lines.append(
            f" {metric_name:<21} | {formatter(current, current)} | {formatter(peak, peak)} | "
            f"{formatter(allocated, allocated)} | {formatter(freed, freed)} ",
        )

    lines.append("=" * 75)

    fmt_dict = {"_": "", "device": device}
    for k, v in stats.items():
        fmt_dict[k.replace(".", "-")] = v
    return "|" + "|\n|".join(lines).format(**fmt_dict) + "|\n"


def list_gpu_processes(device: Union[Device, int] = None) -> str:
    r"""Return a human-readable printout of the running processes and their GPU memory use for a given device.

    This can be useful to display periodically during training, or when
    handling out-of-memory exceptions.

    Args:
        device (torch.device or int, optional): selected device. Returns
            printout for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    if not torch.version.hip:
        try:
            import pynvml  # type: ignore[import]
        except ModuleNotFoundError:
            return "pynvml module not found, please install pynvml"
        from pynvml import NVMLError_DriverNotLoaded

        try:
            pynvml.nvmlInit()
        except NVMLError_DriverNotLoaded:
            return "cuda driver can't be loaded, is cuda enabled?"

        device = _get_nvml_device_index(device)
        handle = pynvml.nvmlDeviceGetHandleByIndex(device)
        procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
    else:
        try:
            import amdsmi  # type: ignore[import]
        except ModuleNotFoundError:
            return "amdsmi module not found, please install amdsmi"
        try:
            amdsmi.amdsmi_init()  # type: ignore[attr-defined]
        except amdsmi.AmdSmiException:  # type: ignore[attr-defined]
            return "amdsmi driver can't be loaded, is ROCm installed?"

        device = _get_amdsmi_device_index(device)

        try:
            handle = amdsmi.amdsmi_get_processor_handles()[device]  # type: ignore[attr-defined]
            procs = amdsmi.amdsmi_get_gpu_process_list(handle)  # type: ignore[attr-defined]
        except amdsmi.AmdSmiException:  # type: ignore[attr-defined]
            return "amdsmi cannot list processes from other users"

    lines = []
    lines.append(f"GPU:{device}")
    if len(procs) == 0:
        lines.append("no processes are running")
    for p in procs:
        if not torch.version.hip:
            mem = p.usedGpuMemory / (1024 * 1024)
            pid = p.pid
        else:
            try:
                proc_info = amdsmi.amdsmi_get_gpu_process_info(handle, p)  # type: ignore[possibly-undefined]
            except AttributeError:
                # https://github.com/ROCm/amdsmi/commit/c551c3caedbd903ba828e7fdffa5b56d475a15e7
                # is a BC-breaking change that removes amdsmi_get_gpu_process_info API from amdsmi
                proc_info = p
            mem = proc_info["memory_usage"]["vram_mem"] / (1024 * 1024)
            pid = proc_info["pid"]
        lines.append(f"process {pid:>10d} uses {mem:>12.3f} MB GPU memory")
    return "\n".join(lines)


def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]:
    r"""Return the global free and total GPU memory for a given device using cudaMemGetInfo.

    Args:
        device (torch.device or int or str, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default) or if the device index is not specified.

    .. note::
        See :ref:`cuda-memory-management` for more
        details about GPU memory management.
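
    A small illustration (assumes at least one visible CUDA device)::

        free_bytes, total_bytes = torch.cuda.mem_get_info()
        print(f"{free_bytes / 1024**3:.2f} GiB free of {total_bytes / 1024**3:.2f} GiB")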
    """
    if device is None:
        device = torch.cuda.current_device()
    # optional=True allows `device = torch.device('cuda')` for which device.index is None
    device = _get_device_index(device, optional=True)
    return torch.cuda.cudart().cudaMemGetInfo(device)


def _record_memory_history_legacy(
    enabled: bool,
    record_context=True,
    trace_alloc_max_entries=1,
    trace_alloc_record_context=False,
    device: Union[Device, int] = None,
    record_context_cpp=False,
):
    _C._cuda_record_memory_history_legacy(
        enabled,
        record_context,
        trace_alloc_max_entries,
        trace_alloc_record_context,
        record_context_cpp,
    )


def _record_memory_history(enabled="all", *args, **kwargs):
    """Enable recording of stack traces associated with memory
    allocations, so you can tell what allocated any piece of memory in
    :func:`torch.cuda.memory._snapshot()`.

    In addition to keeping stack traces with each current allocation and free,
    this will also enable recording of a history of all alloc/free events.

    Use :func:`torch.cuda.memory._snapshot()` to retrieve this information,
    and the tools in `_memory_viz.py` to visualize snapshots.

    The Python trace collection is fast (2us per trace), so you may consider
    enabling this on production jobs if you anticipate ever having to debug
    memory issues.

    C++ trace collection is also fast (~50ns/frame), which for many typical programs
    works out to ~2us per trace, but can vary depending on stack depth.

    Args:
        enabled (Literal[None, "state", "all"], optional):
            `None`, disable recording memory history.
            `"state"`, keep information for currently allocated memory.
            `"all"`, additionally keep a history of all alloc/free calls.
            Defaults to "all".
        context (Literal[None, "state", "alloc", "all"], optional):
            `None`, Do not record any tracebacks.
            `"state"`, Record tracebacks for currently allocated memory.
            `"alloc"`, additionally keep tracebacks for alloc calls.
            `"all"`, additionally keep tracebacks for free calls.
            Defaults to "all".
        stacks (Literal["python", "all"], optional):
            `"python"`, include Python, TorchScript, and inductor frames in tracebacks
            `"all"`, additionally include C++ frames
            Defaults to "all".
        max_entries (int, optional): Keep a maximum of `max_entries`
            alloc/free events in the recorded history.
    """
    if isinstance(enabled, bool):
        return _record_memory_history_legacy(enabled, *args, **kwargs)
    else:
        return _record_memory_history_impl(enabled, *args, **kwargs)


def _record_memory_history_impl(
    enabled: Optional[str] = "all",
    context: Optional[str] = "all",
    stacks: str = "all",
    max_entries: int = sys.maxsize,
    device: Union[Device, int] = None,
):
    _C._cuda_record_memory_history(enabled, context, stacks, max_entries)


_record_memory_history.__signature__ = signature(_record_memory_history_impl)  # type: ignore[attr-defined]


def _snapshot(device: Union[Device, int] = None):
    """Save a snapshot of CUDA memory state at the time it was called.

    The state is represented as a dictionary with the following structure.

    .. code-block:: python

        class Snapshot(TypedDict):
            segments : List[Segment]
            device_traces: List[List[TraceEntry]]

        class Segment(TypedDict):
            # Segments are memory returned from a cudaMalloc call.
            # The size of reserved memory is the sum of all Segments.
            # Segments are cached and reused for future allocations.
            # If the reuse is smaller than the segment, the segment
            # is split into more than one Block.
            # empty_cache() frees Segments that are entirely inactive.
            address: int
            total_size: int  # cudaMalloc'd size of segment
            stream: int
            segment_type: Literal['small', 'large']  # 'large' (>1MB)
            allocated_size: int  # size of memory in use
            active_size: int  # size of memory in use or in active_awaiting_free state
            blocks: List[Block]

        class Block(TypedDict):
            # A piece of memory returned from the allocator, or
            # currently cached but inactive.
            size: int
            requested_size: int  # size requested during malloc, may be smaller than
                                 # size due to rounding
            address: int
            state: Literal['active_allocated',  # used by a tensor
                           'active_awaiting_free',  # waiting for another stream to finish
                                                    # using this, then it will become free
                           'inactive']  # free for reuse
            frames: List[Frame]  # stack trace from where the allocation occurred

        class Frame(TypedDict):
            filename: str
            line: int
            name: str

        class TraceEntry(TypedDict):
            # When `torch.cuda.memory._record_memory_history()` is enabled,
            # the snapshot will contain TraceEntry objects that record each
            # action the allocator took.
            action: Literal[
                'alloc',  # memory allocated
                'free_requested',  # the allocator received a call to free memory
                'free_completed',  # the memory that was requested to be freed is now
                                   # able to be used in future allocation calls
                'segment_alloc',  # the caching allocator asked cudaMalloc for more memory
                                  # and added it as a segment in its cache
                'segment_free',  # the caching allocator called cudaFree to return memory
                                 # to cuda, possibly to free up memory for
                                 # more segments or because empty_cache was called
                'oom',  # the allocator threw an OOM exception. 'size' is
                        # the requested number of bytes that did not succeed
                'snapshot',  # the allocator generated a memory snapshot
                             # useful to correlate a previously taken
                             # snapshot with this trace
            ]
            addr: int  # not present for OOM
            frames: List[Frame]
            size: int
            stream: int
            device_free: int  # only present for OOM, the amount of
                              # memory cuda still reports to be free

    Returns:
        The Snapshot dictionary object
    """
    return _C._cuda_memorySnapshot()


def _dump_snapshot(filename="dump_snapshot.pickle"):
    """
    Save a pickled version of the `torch.memory._snapshot()` dictionary to a file.

    This file can be opened by the interactive snapshot viewer at pytorch.org/memory_viz

    Args:
        filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle".
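
    A sketch of the typical capture workflow (``run_workload`` is a hypothetical
    function that performs CUDA allocations; the file name is arbitrary)::

        torch.cuda.memory._record_memory_history(max_entries=100000)
        run_workload()
        torch.cuda.memory._dump_snapshot("snapshot.pickle")
        torch.cuda.memory._record_memory_history(enabled=None)  # stop recording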
    """
    s = _snapshot()
    with open(filename, "wb") as f:
        pickle.dump(s, f)


def _save_segment_usage(filename="output.svg", snapshot=None):
    if snapshot is None:
        snapshot = _snapshot()
    with open(filename, "w") as f:
        f.write(_segments(snapshot))


def _save_memory_usage(filename="output.svg", snapshot=None):
    if snapshot is None:
        snapshot = _snapshot()
    with open(filename, "w") as f:
        f.write(_memory(snapshot))


def _set_allocator_settings(env: str):
    return torch._C._cuda_cudaCachingAllocator_set_allocator_settings(env)


def get_allocator_backend() -> str:
    r"""Return a string describing the active allocator backend as set by
    ``PYTORCH_CUDA_ALLOC_CONF``. Currently available backends are
    ``native`` (PyTorch's native caching allocator) and ``cudaMallocAsync``
    (CUDA's built-in asynchronous allocator).

    .. note::
        See :ref:`cuda-memory-management` for details on choosing the allocator backend.
    """
    return torch._C._cuda_getAllocatorBackend()


class _CUDAAllocator:
    r"""Wrapper over internal CUDA memory allocators."""

    def __init__(self, allocator: torch._C._cuda_CUDAAllocator):
        self._allocator = allocator

    def allocator(self):
        return self._allocator


class CUDAPluggableAllocator(_CUDAAllocator):
    r"""CUDA memory allocator loaded from a `.so` file."""

    def __init__(self, path_to_so_file: str, alloc_fn_name: str, free_fn_name: str):
        r"""Memory allocators are compiled in `.so` files and loaded dynamically using ctypes.

        To change the active allocator use the :func:`torch.cuda.memory.change_current_allocator` function.

        Args:
            path_to_so_file(str): Path in the filesystem to the `.so` file containing
                the allocator functions
            alloc_fn_name(str): Name of the function to perform the memory allocation
                in the so file. The signature must be:
                void* alloc_fn_name(ssize_t size, int device, cudaStream_t stream);
            free_fn_name(str): Name of the function to perform the memory release
                in the so file. The signature must be:
                void free_fn_name(void* ptr, size_t size, cudaStream_t stream);

        .. warning::
            This is currently supported only on UNIX-like operating systems.

        .. note::
            See :ref:`cuda-memory-management` for details on creating and using a custom allocator
        """
        allocator = ctypes.CDLL(path_to_so_file)
        alloc_fn = ctypes.cast(getattr(allocator, alloc_fn_name), ctypes.c_void_p).value
        free_fn = ctypes.cast(getattr(allocator, free_fn_name), ctypes.c_void_p).value
        assert alloc_fn is not None
        assert free_fn is not None
        self._allocator = torch._C._cuda_customAllocator(alloc_fn, free_fn)


def change_current_allocator(allocator: _CUDAAllocator) -> None:
    r"""Change the currently used memory allocator to be the one provided.

    If the current allocator has already been used/initialized, this function will error.

    Args:
        allocator (torch.cuda.memory._CUDAAllocator): allocator to be set as the active one.

    .. note::
        See :ref:`cuda-memory-management` for details on creating and using a custom allocator
    """
    torch._C._cuda_changeCurrentAllocator(allocator.allocator())


def _get_current_allocator() -> _CUDAAllocator:
    r"""Return the allocator being currently used.

    .. note::
        See :ref:`cuda-memory-management` for details on creating and using a custom allocator
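
    A rough sketch of swapping in a pluggable allocator and then inspecting the
    active one (``"my_alloc.so"``, ``"my_malloc"`` and ``"my_free"`` are
    hypothetical names for a user-compiled library)::

        new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
            "my_alloc.so", "my_malloc", "my_free"
        )
        torch.cuda.memory.change_current_allocator(new_alloc)
        current = torch.cuda.memory._get_current_allocator()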
    """
    return _CUDAAllocator(torch._C._cuda_getAllocator())


class MemPool(_MemPool):
    r"""MemPool represents a pool of memory in a caching allocator. Currently,
    it's just the ID of the pool object maintained in the CUDACachingAllocator.

    Args:
        allocator(torch._C._cuda_CUDAAllocator, optional): a
            torch._C._cuda_CUDAAllocator object that can be used to
            define how memory gets allocated in the pool. If :attr:`allocator`
            is ``None`` (default), memory allocation follows the default/
            current configuration of the CUDACachingAllocator.

    """

    def __init__(self, allocator: Optional[_cuda_CUDAAllocator] = None):
        super().__init__(allocator, True)

    @property
    def id(self) -> Tuple[int, int]:
        r"""Returns the ID of this pool as a tuple of two ints."""
        return super().id

    @property
    def allocator(self) -> Optional[_cuda_CUDAAllocator]:
        r"""Returns the allocator this MemPool routes allocations to."""
        return super().allocator


class MemPoolContext(_MemPoolContext):
    r"""MemPoolContext holds the currently active pool and stashes the previous
    pool. On deletion it makes the previous pool active.

    Args:
        pool(torch.cuda.MemPool): a MemPool object to be made active so that
            allocations route to this pool.

    """

    def __init__(self, pool: MemPool):
        super().__init__(pool)

    @staticmethod
    def active_pool() -> Optional[_MemPool]:
        r"""Returns the active MemPool."""
        return _MemPoolContext.active_pool()


@contextlib.contextmanager
def use_mem_pool(pool: MemPool, device: Union[Device, int] = None):
    r"""A context manager that routes allocations to a given pool.

    Args:
        pool(torch.cuda.MemPool): a MemPool object to be made active so that
            allocations route to this pool.
        device (torch.device or int, optional): selected device. Uses MemPool on
            the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    """
    ctx = MemPoolContext(pool)
    device_index = (
        torch.cuda.current_device() if device is None else _get_device_index(device)
    )
    _cuda_beginAllocateToPool(device_index, pool.id)
    try:
        yield
    finally:
        _cuda_endAllocateCurrentStreamToPool(device_index, pool.id)
        del ctx
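

# Illustrative sketch (kept as a comment so nothing runs at import time): routing
# the allocations made inside a region of code to a private memory pool.
#
#     pool = torch.cuda.MemPool()
#     with torch.cuda.use_mem_pool(pool):
#         x = torch.randn(1024, device="cuda")  # served from ``pool``
#     y = torch.randn(1024, device="cuda")  # served from the default pool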