# mypy: allow-untyped-defs
r"""This package adds support for device memory management implemented in CUDA."""

import collections
import contextlib
import ctypes
import pickle
import sys
import warnings
from inspect import signature
from typing import Any, Dict, Optional, Tuple, Union
from typing_extensions import deprecated

import torch
from torch import _C
from torch._utils import _dummy_type
from torch.types import Device

from . import (
    _get_amdsmi_device_index,
    _get_device_index,
    _get_nvml_device_index,
    _lazy_init,
    is_initialized,
)
from ._memory_viz import memory as _memory, segments as _segments


__all__ = [
    "caching_allocator_alloc",
    "caching_allocator_delete",
    "set_per_process_memory_fraction",
    "empty_cache",
    "memory_stats",
    "memory_stats_as_nested_dict",
    "reset_accumulated_memory_stats",
    "reset_peak_memory_stats",
    "reset_max_memory_allocated",
    "reset_max_memory_cached",
    "memory_allocated",
    "max_memory_allocated",
    "memory_reserved",
    "max_memory_reserved",
    "memory_cached",
    "max_memory_cached",
    "memory_snapshot",
    "memory_summary",
    "list_gpu_processes",
    "mem_get_info",
    "get_allocator_backend",
    "CUDAPluggableAllocator",
    "change_current_allocator",
    "MemPool",
    "MemPoolContext",
    "use_mem_pool",
]


if not hasattr(torch._C, "_cuda_CUDAAllocator"):
    # Define dummy base classes
    torch._C.__dict__["_cuda_CUDAAllocator"] = _dummy_type("_cuda_CUDAAllocator")


if not hasattr(torch._C, "_MemPool"):
    # Define dummy base classes
    torch._C.__dict__["_MemPool"] = _dummy_type("_MemPool")
    torch._C.__dict__["_MemPoolContext"] = _dummy_type("_MemPoolContext")
    torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type(
        "_cuda_beginAllocateToPool"
    )
    torch._C.__dict__["_cuda_endAllocateCurrentStreamToPool"] = _dummy_type(
        "_cuda_endAllocateCurrentStreamToPool"
    )

from torch._C import (  # noqa: F401
    _cuda_beginAllocateToPool,
    _cuda_CUDAAllocator,
    _cuda_endAllocateCurrentStreamToPool,
    _MemPool,
    _MemPoolContext,
)


def _host_allocator():
    _lazy_init()
    return torch._C._cuda_cudaHostAllocator()


@contextlib.contextmanager
def _free_mutex():
    torch._C._cuda_lock_mutex()
    try:
        yield
    finally:
        torch._C._cuda_unlock_mutex()


def caching_allocator_alloc(size, device: Union[Device, int] = None, stream=None):
    r"""Perform a memory allocation using the CUDA memory allocator.

    Memory is allocated for a given device and stream; this
    function is intended to be used for interoperability with other
    frameworks. Allocated memory is released through
    :func:`~torch.cuda.caching_allocator_delete`.

    Args:
        size (int): number of bytes to be allocated.
        device (torch.device or int, optional): selected device. If it is
            ``None`` the default CUDA device is used.
        stream (torch.cuda.Stream or int, optional): selected stream. If it is ``None``
            then the default stream for the selected device is used.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
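
    Example (illustrative sketch; assumes a CUDA device is available)::

        >>> # allocate 1 KiB through the caching allocator on the current stream
        >>> ptr = torch.cuda.caching_allocator_alloc(1024)
        >>> # ... hand ``ptr`` to another framework ...
        >>> torch.cuda.caching_allocator_delete(ptr)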
    """
    if device is None:
        device = torch.cuda.current_device()
    device = _get_device_index(device)
    if stream is None:
        stream = torch.cuda.current_stream(device)
    if isinstance(stream, torch.cuda.streams.Stream):
        stream = stream.cuda_stream
    if not isinstance(stream, int):
        raise TypeError(
            "Invalid type for stream argument, must be "
            "`torch.cuda.Stream` or `int` representing a pointer "
            "to an existing stream"
        )
    with torch.cuda.device(device):
        return torch._C._cuda_cudaCachingAllocator_raw_alloc(size, stream)


def caching_allocator_delete(mem_ptr):
    r"""Delete memory allocated using the CUDA memory allocator.

    Memory allocated with :func:`~torch.cuda.caching_allocator_alloc`
    is freed here. The associated device and stream are tracked inside
    the allocator.

    Args:
        mem_ptr (int): memory address to be freed by the allocator.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    torch._C._cuda_cudaCachingAllocator_raw_delete(mem_ptr)


def set_per_process_memory_fraction(
    fraction, device: Union[Device, int] = None
) -> None:
    r"""Set the memory fraction for a process.

    The fraction is used to limit the memory of the caching allocator on a CUDA device.
    The allowed value equals the total visible memory multiplied by the fraction.
    If a process tries to allocate more than the allowed value, an out-of-memory error
    is raised by the allocator.

    Args:
        fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction.
        device (torch.device or int, optional): selected device. If it is
            ``None`` the default CUDA device is used.

    .. note::
        In general, the total available free memory is less than the total capacity.
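
    Example (illustrative sketch; assumes at least one CUDA device)::

        >>> # cap this process at half of device 0's total memory
        >>> torch.cuda.set_per_process_memory_fraction(0.5, device=0)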
    """
    _lazy_init()
    if device is None:
        device = torch.cuda.current_device()
    device = _get_device_index(device)
    if not isinstance(fraction, float):
        raise TypeError("Invalid type for fraction argument, must be `float`")
    if fraction < 0 or fraction > 1:
        raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~1")

    torch._C._cuda_setMemoryFraction(fraction, device)


def empty_cache() -> None:
    r"""Release all unoccupied cached memory currently held by the caching
    allocator so that it can be used by other GPU applications and is visible in
    `nvidia-smi`.

    .. note::
        :func:`~torch.cuda.empty_cache` doesn't increase the amount of GPU
        memory available for PyTorch. However, it may help reduce fragmentation
        of GPU memory in certain cases. See :ref:`cuda-memory-management` for
        more details about GPU memory management.
    """
    if is_initialized():
        torch._C._cuda_emptyCache()


def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]:
    r"""Return a dictionary of CUDA memory allocator statistics for a given device.

    The return value of this function is a dictionary of statistics, each of
    which is a non-negative integer.

    Core statistics:

    - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      number of allocation requests received by the memory allocator.
    - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of allocated memory.
    - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      number of reserved segments from ``cudaMalloc()``.
    - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of reserved memory.
    - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      number of active memory blocks.
    - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of active memory.
    - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      number of inactive, non-releasable memory blocks.
    - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of inactive, non-releasable memory.

    For these core statistics, values are broken down as follows.

    Pool type:

    - ``all``: combined statistics across all memory pools.
    - ``large_pool``: statistics for the large allocation pool
      (as of October 2019, for size >= 1MB allocations).
    - ``small_pool``: statistics for the small allocation pool
      (as of October 2019, for size < 1MB allocations).

    Metric type:

    - ``current``: current value of this metric.
    - ``peak``: maximum value of this metric.
    - ``allocated``: historical total increase in this metric.
    - ``freed``: historical total decrease in this metric.

    In addition to the core statistics, we also provide some simple event
    counters:

    - ``"num_alloc_retries"``: number of failed ``cudaMalloc`` calls that
      result in a cache flush and retry.
    - ``"num_ooms"``: number of out-of-memory errors thrown.
    - ``"num_sync_all_streams"``: number of ``synchronize_and_free_events`` calls.
    - ``"num_device_alloc"``: number of CUDA allocation calls. This includes both
      cuMemMap and cudaMalloc.
    - ``"num_device_free"``: number of CUDA free calls. This includes both cuMemUnmap
      and cudaFree.

    The caching allocator can be configured via the environment to not split blocks
    larger than a defined size (see the Memory Management section of the CUDA semantics
    documentation). This helps avoid memory fragmentation but may have a performance
    penalty. Additional outputs to assist with tuning and evaluating impact:

    - ``"max_split_size"``: blocks above this size will not be split.
    - ``"oversize_allocations.{current,peak,allocated,freed}"``:
      number of over-size allocation requests received by the memory allocator.
    - ``"oversize_segments.{current,peak,allocated,freed}"``:
      number of over-size reserved segments from ``cudaMalloc()``.

    The caching allocator can be configured via the environment to round memory
    allocations in order to reduce fragmentation. Sometimes the overhead from rounding
    can be higher than the fragmentation it helps reduce. The following stat can be
    used to check if rounding adds too much overhead:

    - ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      memory requested by client code; compare this with allocated_bytes to check if
      allocation rounding adds too much overhead.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistics for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.

    .. note::
        With :ref:`backend:cudaMallocAsync<cuda-memory-envvars>`, some stats are not
        meaningful, and are always reported as zero.
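
    Example (illustrative sketch)::

        >>> stats = torch.cuda.memory_stats()
        >>> # peak number of bytes held by tensors since the start of the program
        >>> peak_allocated = stats["allocated_bytes.all.peak"]
        >>> # how many cudaMalloc calls had to be retried after a cache flush
        >>> retries = stats["num_alloc_retries"]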
    """
    result = []

    def _recurse_add_to_result(prefix, obj):
        if isinstance(obj, dict):
            if len(prefix) > 0:
                prefix += "."
            for k, v in obj.items():
                _recurse_add_to_result(prefix + k, v)
        else:
            result.append((prefix, obj))

    stats = memory_stats_as_nested_dict(device=device)
    _recurse_add_to_result("", stats)
    result.sort()

    return collections.OrderedDict(result)


def memory_stats_as_nested_dict(device: Union[Device, int] = None) -> Dict[str, Any]:
    r"""Return the result of :func:`~torch.cuda.memory_stats` as a nested dictionary."""
    if not is_initialized():
        return {}
    device = _get_device_index(device, optional=True)
    return torch._C._cuda_memoryStats(device)


def reset_accumulated_memory_stats(device: Union[Device, int] = None) -> None:
    r"""Reset the "accumulated" (historical) stats tracked by the CUDA memory allocator.

    See :func:`~torch.cuda.memory_stats` for details. Accumulated stats correspond to
    the `"allocated"` and `"freed"` keys in each individual stat dict, as well as
    `"num_alloc_retries"` and `"num_ooms"`.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    device = _get_device_index(device, optional=True)
    return torch._C._cuda_resetAccumulatedMemoryStats(device)


def reset_peak_memory_stats(device: Union[Device, int] = None) -> None:
    r"""Reset the "peak" stats tracked by the CUDA memory allocator.

    See :func:`~torch.cuda.memory_stats` for details. Peak stats correspond to the
    `"peak"` key in each individual stat dict.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    device = _get_device_index(device, optional=True)
    return torch._C._cuda_resetPeakMemoryStats(device)


def reset_max_memory_allocated(device: Union[Device, int] = None) -> None:
    r"""Reset the starting point in tracking maximum GPU memory occupied by tensors for a given device.

    See :func:`~torch.cuda.max_memory_allocated` for details.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. warning::
        This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
        /all/ peak memory stats.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    warnings.warn(
        "torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, "
        "which resets /all/ peak memory stats.",
        FutureWarning,
    )
    return reset_peak_memory_stats(device=device)


def reset_max_memory_cached(device: Union[Device, int] = None) -> None:
    r"""Reset the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.

    See :func:`~torch.cuda.max_memory_cached` for details.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. warning::
        This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
        /all/ peak memory stats.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    warnings.warn(
        "torch.cuda.reset_max_memory_cached now calls torch.cuda.reset_peak_memory_stats, "
        "which resets /all/ peak memory stats.",
        FutureWarning,
    )
    return reset_peak_memory_stats(device=device)


def memory_allocated(device: Union[Device, int] = None) -> int:
    r"""Return the current GPU memory occupied by tensors in bytes for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        This is likely less than the amount shown in `nvidia-smi` since some
        unused memory can be held by the caching allocator and some context
        needs to be created on GPU. See :ref:`cuda-memory-management` for more
        details about GPU memory management.
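
    Example (illustrative sketch; assumes a CUDA device is available)::

        >>> x = torch.empty(1024, 1024, device="cuda")  # ~4 MiB of float32
        >>> allocated = torch.cuda.memory_allocated()   # includes x, after rounding
        >>> reserved = torch.cuda.memory_reserved()     # cached segments, >= allocated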
    """
    return memory_stats(device=device).get("allocated_bytes.all.current", 0)


def max_memory_allocated(device: Union[Device, int] = None) -> int:
    r"""Return the maximum GPU memory occupied by tensors in bytes for a given device.

    By default, this returns the peak allocated memory since the beginning of
    this program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to
    reset the starting point in tracking this metric. For example, these two
    functions can measure the peak allocated memory usage of each iteration in a
    training loop, as sketched below.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
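
    Example (illustrative sketch of the per-iteration pattern mentioned above;
    ``model`` and ``data`` are placeholders)::

        >>> for inputs in data:
        ...     torch.cuda.reset_peak_memory_stats()
        ...     out = model(inputs)
        ...     peak = torch.cuda.max_memory_allocated()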
    """
    return memory_stats(device=device).get("allocated_bytes.all.peak", 0)


def memory_reserved(device: Union[Device, int] = None) -> int:
    r"""Return the current GPU memory managed by the caching allocator in bytes for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    return memory_stats(device=device).get("reserved_bytes.all.current", 0)


def max_memory_reserved(device: Union[Device, int] = None) -> int:
    r"""Return the maximum GPU memory managed by the caching allocator in bytes for a given device.

    By default, this returns the peak cached memory since the beginning of this
    program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to reset
    the starting point in tracking this metric. For example, these two functions
    can measure the peak cached memory amount of each iteration in a training
    loop.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    return memory_stats(device=device).get("reserved_bytes.all.peak", 0)


@deprecated(
    "`torch.cuda.memory_cached` has been renamed to `torch.cuda.memory_reserved`",
    category=FutureWarning,
)
def memory_cached(device: Union[Device, int] = None) -> int:
    r"""Deprecated; see :func:`~torch.cuda.memory_reserved`."""
    return memory_reserved(device=device)


@deprecated(
    "`torch.cuda.max_memory_cached` has been renamed to `torch.cuda.max_memory_reserved`",
    category=FutureWarning,
)
def max_memory_cached(device: Union[Device, int] = None) -> int:
    r"""Deprecated; see :func:`~torch.cuda.max_memory_reserved`."""
    return max_memory_reserved(device=device)


def memory_snapshot():
    r"""Return a snapshot of the CUDA memory allocator state across all devices.

    Interpreting the output of this function requires familiarity with the
    memory allocator internals.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
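
    Example (illustrative sketch)::

        >>> snapshot = torch.cuda.memory_snapshot()
        >>> # total bytes reserved across all segments returned by cudaMalloc
        >>> total_reserved = sum(seg["total_size"] for seg in snapshot)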
    """
    return torch._C._cuda_memorySnapshot()["segments"]


def memory_summary(device: Union[Device, int] = None, abbreviated: bool = False) -> str:
    r"""Return a human-readable printout of the current memory allocator statistics for a given device.

    This can be useful to display periodically during training, or when
    handling out-of-memory exceptions.

    Args:
        device (torch.device or int, optional): selected device. Returns
            printout for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).
        abbreviated (bool, optional): whether to return an abbreviated summary
            (default: False).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
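
    Example (illustrative sketch)::

        >>> # build a compact report suitable for periodic logging
        >>> report = torch.cuda.memory_summary(abbreviated=True)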
    """
    device = _get_device_index(device, optional=True)
    stats = memory_stats(device=device)

    def _format_size(sz, pref_sz):
        prefixes = ["B  ", "KiB", "MiB", "GiB", "TiB", "PiB"]
        prefix = prefixes[0]
        for new_prefix in prefixes[1:]:
            if pref_sz < 768 * 1024:
                break
            prefix = new_prefix
            sz //= 1024
            pref_sz /= 1024
        return f"{sz:6d} {prefix}"

    def _format_count(cnt, pref_cnt):
        prefixes = [" ", "K", "M"]
        prefix = prefixes[0]
        for new_prefix in prefixes[1:]:
            if pref_cnt < 750 * 1000:
                break
            prefix = new_prefix
            cnt //= 1000
            pref_cnt /= 1000
        return f"{cnt:7d} {prefix} "

    metrics_to_display = [
        ("allocated_bytes", "Allocated memory", _format_size),
        ("active_bytes", "Active memory", _format_size),
        ("requested_bytes", "Requested memory", _format_size),
        ("reserved_bytes", "GPU reserved memory", _format_size),
        ("inactive_split_bytes", "Non-releasable memory", _format_size),
        ("allocation", "Allocations", _format_count),
        ("active", "Active allocs", _format_count),
        ("segment", "GPU reserved segments", _format_count),
        ("inactive_split", "Non-releasable allocs", _format_count),
    ]

    lines = []
    lines.append("=" * 75)
    lines.append(" {_:16} PyTorch CUDA memory summary, device ID {device:<17d} ")
    lines.append("-" * 75)
    lines.append(
        "  {_:9} CUDA OOMs: {num_ooms:<12d} | {_:6} cudaMalloc retries: {num_alloc_retries:<8d}  "
    )
    lines.append("=" * 75)
    lines.append(
        "        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  "
    )

    for metric_key, metric_name, formatter in metrics_to_display:
        lines.append("-" * 75)
        submetrics = [("all", metric_name)]
        if not abbreviated:
            submetrics.append(("large_pool", "      from large pool"))
            submetrics.append(("small_pool", "      from small pool"))

        current_prefval, peak_prefval, allocated_prefval, freed_prefval = (
            None,
            None,
            None,
            None,
        )

        for submetric_key, submetric_name in submetrics:
            prefix = metric_key + "." + submetric_key + "."

            current = stats[prefix + "current"]
            peak = stats[prefix + "peak"]
            allocated = stats[prefix + "allocated"]
            freed = stats[prefix + "freed"]

            if current_prefval is None:
                current_prefval = current
                peak_prefval = peak
                allocated_prefval = allocated
                freed_prefval = freed

            lines.append(
                f" {submetric_name:<21} | {formatter(current, current_prefval)} | {formatter(peak, peak_prefval)} | "
                f"{formatter(allocated, allocated_prefval)} | {formatter(freed, freed_prefval)} ",
            )

    metrics_to_display = [
        ("oversize_allocations", "Oversize allocations", _format_count),
        ("oversize_segments", "Oversize GPU segments", _format_count),
    ]

    for metric_key, metric_name, formatter in metrics_to_display:
        lines.append("-" * 75)

        prefix = metric_key + "."

        current = stats[prefix + "current"]
        peak = stats[prefix + "peak"]
        allocated = stats[prefix + "allocated"]
        freed = stats[prefix + "freed"]

        lines.append(
            f" {metric_name:<21} | {formatter(current, current)} | {formatter(peak, peak)} | "
            f"{formatter(allocated, allocated)} | {formatter(freed, freed)} ",
        )

    lines.append("=" * 75)

    fmt_dict = {"_": "", "device": device}
    for k, v in stats.items():
        fmt_dict[k.replace(".", "-")] = v
    return "|" + "|\n|".join(lines).format(**fmt_dict) + "|\n"


def list_gpu_processes(device: Union[Device, int] = None) -> str:
    r"""Return a human-readable printout of the running processes and their GPU memory use for a given device.

    This can be useful to display periodically during training, or when
    handling out-of-memory exceptions.

    Args:
        device (torch.device or int, optional): selected device. Returns
            printout for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    if not torch.version.hip:
        try:
            import pynvml  # type: ignore[import]
        except ModuleNotFoundError:
            return "pynvml module not found, please install pynvml"
        from pynvml import NVMLError_DriverNotLoaded

        try:
            pynvml.nvmlInit()
        except NVMLError_DriverNotLoaded:
            return "cuda driver can't be loaded, is cuda enabled?"

        device = _get_nvml_device_index(device)
        handle = pynvml.nvmlDeviceGetHandleByIndex(device)
        procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
    else:
        try:
            import amdsmi  # type: ignore[import]
        except ModuleNotFoundError:
            return "amdsmi module not found, please install amdsmi"
        try:
            amdsmi.amdsmi_init()  # type: ignore[attr-defined]
        except amdsmi.AmdSmiException:  # type: ignore[attr-defined]
            return "amdsmi driver can't be loaded, is ROCm installed?"

        device = _get_amdsmi_device_index(device)

        try:
            handle = amdsmi.amdsmi_get_processor_handles()[device]  # type: ignore[attr-defined]
            procs = amdsmi.amdsmi_get_gpu_process_list(handle)  # type: ignore[attr-defined]
        except amdsmi.AmdSmiException:  # type: ignore[attr-defined]
            return "amdsmi cannot list processes from other users"

    lines = []
    lines.append(f"GPU:{device}")
    if len(procs) == 0:
        lines.append("no processes are running")
    for p in procs:
        if not torch.version.hip:
            mem = p.usedGpuMemory / (1024 * 1024)
            pid = p.pid
        else:
            try:
                proc_info = amdsmi.amdsmi_get_gpu_process_info(handle, p)  # type: ignore[possibly-undefined]
            except AttributeError:
                # https://github.com/ROCm/amdsmi/commit/c551c3caedbd903ba828e7fdffa5b56d475a15e7
                # is a BC-breaking change that removes amdsmi_get_gpu_process_info API from amdsmi
                proc_info = p
            mem = proc_info["memory_usage"]["vram_mem"] / (1024 * 1024)
            pid = proc_info["pid"]
        lines.append(f"process {pid:>10d} uses {mem:>12.3f} MB GPU memory")
    return "\n".join(lines)


def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]:
    r"""Return the global free and total GPU memory for a given device using cudaMemGetInfo.

    Args:
        device (torch.device or int or str, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default) or if the device index is not specified.

    .. note::
        See :ref:`cuda-memory-management` for more
        details about GPU memory management.
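
    Example (illustrative sketch; assumes a CUDA device is available)::

        >>> free_bytes, total_bytes = torch.cuda.mem_get_info()
        >>> free_fraction = free_bytes / total_bytes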
    """
    if device is None:
        device = torch.cuda.current_device()
    # optional=True allows `device = torch.device('cuda')` for which device.index is None
    device = _get_device_index(device, optional=True)
    return torch.cuda.cudart().cudaMemGetInfo(device)


def _record_memory_history_legacy(
    enabled: bool,
    record_context=True,
    trace_alloc_max_entries=1,
    trace_alloc_record_context=False,
    device: Union[Device, int] = None,
    record_context_cpp=False,
):
    _C._cuda_record_memory_history_legacy(
        enabled,
        record_context,
        trace_alloc_max_entries,
        trace_alloc_record_context,
        record_context_cpp,
    )


def _record_memory_history(enabled="all", *args, **kwargs):
    """Enable recording of stack traces associated with memory
    allocations, so you can tell what allocated any piece of memory in
    :func:`torch.cuda.memory._snapshot()`.

    In addition to keeping stack traces with each current allocation and free,
    this will also enable recording of a history of all alloc/free events.

    Use :func:`torch.cuda.memory._snapshot()` to retrieve this information,
    and the tools in `_memory_viz.py` to visualize snapshots.

    The Python trace collection is fast (2us per trace), so you may consider
    enabling this on production jobs if you anticipate ever having to debug
    memory issues.

    C++ trace collection is also fast (~50ns/frame), which for many typical programs
    works out to ~2us per trace, but can vary depending on stack depth.

    Args:
        enabled (Literal[None, "state", "all"], optional):
            `None`, disable recording memory history.
            `"state"`, keep information for currently allocated memory.
            `"all"`, additionally keep a history of all alloc/free calls.
            Defaults to "all".
        context (Literal[None, "state", "alloc", "all"], optional):
            `None`, do not record any tracebacks.
            `"state"`, record tracebacks for currently allocated memory.
            `"alloc"`, additionally keep tracebacks for alloc calls.
            `"all"`, additionally keep tracebacks for free calls.
            Defaults to "all".
        stacks (Literal["python", "all"], optional):
            `"python"`, include Python, TorchScript, and inductor frames in tracebacks.
            `"all"`, additionally include C++ frames.
            Defaults to "all".
        max_entries (int, optional): Keep a maximum of `max_entries`
            alloc/free events in the recorded history.
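
    Example (illustrative sketch of the record / snapshot workflow)::

        >>> torch.cuda.memory._record_memory_history(max_entries=100000)
        >>> # ... run the workload to be profiled ...
        >>> torch.cuda.memory._dump_snapshot("snapshot.pickle")
        >>> # open snapshot.pickle with the viewer at pytorch.org/memory_viz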
    """
    if isinstance(enabled, bool):
        return _record_memory_history_legacy(enabled, *args, **kwargs)
    else:
        return _record_memory_history_impl(enabled, *args, **kwargs)


def _record_memory_history_impl(
    enabled: Optional[str] = "all",
    context: Optional[str] = "all",
    stacks: str = "all",
    max_entries: int = sys.maxsize,
    device: Union[Device, int] = None,
):
    _C._cuda_record_memory_history(enabled, context, stacks, max_entries)


_record_memory_history.__signature__ = signature(_record_memory_history_impl)  # type: ignore[attr-defined]


def _snapshot(device: Union[Device, int] = None):
    """Save a snapshot of CUDA memory state at the time it was called.

    The state is represented as a dictionary with the following structure.

    .. code-block:: python

        class Snapshot(TypedDict):
            segments : List[Segment]
            device_traces: List[List[TraceEntry]]

        class Segment(TypedDict):
            # Segments are memory returned from a cudaMalloc call.
            # The size of reserved memory is the sum of all Segments.
            # Segments are cached and reused for future allocations.
            # If the reuse is smaller than the segment, the segment
            # is split into more than one Block.
            # empty_cache() frees Segments that are entirely inactive.
            address: int
            total_size: int #  cudaMalloc'd size of segment
            stream: int
            segment_type: Literal['small', 'large'] # 'large' (>1MB)
            allocated_size: int # size of memory in use
            active_size: int # size of memory in use or in active_awaiting_free state
            blocks : List[Block]

        class Block(TypedDict):
            # A piece of memory returned from the allocator, or
            # currently cached but inactive.
            size: int
            requested_size: int # size requested during malloc, may be smaller than
                                # size due to rounding
            address: int
            state: Literal['active_allocated', # used by a tensor
                        'active_awaiting_free', # waiting for another stream to finish using
                                                # this, then it will become free
                        'inactive',] # free for reuse
            frames: List[Frame] # stack trace from where the allocation occurred

        class Frame(TypedDict):
            filename: str
            line: int
            name: str

        class TraceEntry(TypedDict):
            # When `torch.cuda.memory._record_memory_history()` is enabled,
            # the snapshot will contain TraceEntry objects that record each
            # action the allocator took.
            action: Literal[
            'alloc'  # memory allocated
            'free_requested', # the allocator received a call to free memory
            'free_completed', # the memory that was requested to be freed is now
                            # able to be used in future allocation calls
            'segment_alloc', # the caching allocator asked cudaMalloc for more memory
                            # and added it as a segment in its cache
            'segment_free',  # the caching allocator called cudaFree to return memory
                            # to cuda, possibly to free up memory to
                            # allocate more segments or because empty_cache() was called
            'oom',          # the allocator threw an OOM exception. 'size' is
                            # the requested number of bytes that did not succeed
            'snapshot'      # the allocator generated a memory snapshot
                            # useful to correlate a previously taken
                            # snapshot with this trace
            ]
            addr: int # not present for OOM
            frames: List[Frame]
            size: int
            stream: int
            device_free: int # only present for OOM, the amount of
                            # memory cuda still reports to be free

    Returns:
        The Snapshot dictionary object
    """
    return _C._cuda_memorySnapshot()


def _dump_snapshot(filename="dump_snapshot.pickle"):
    """
    Save a pickled version of the `torch.cuda.memory._snapshot()` dictionary to a file.

    This file can be opened by the interactive snapshot viewer at pytorch.org/memory_viz

    Args:
        filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle".
    """
    s = _snapshot()
    with open(filename, "wb") as f:
        pickle.dump(s, f)


def _save_segment_usage(filename="output.svg", snapshot=None):
    if snapshot is None:
        snapshot = _snapshot()
    with open(filename, "w") as f:
        f.write(_segments(snapshot))


def _save_memory_usage(filename="output.svg", snapshot=None):
    if snapshot is None:
        snapshot = _snapshot()
    with open(filename, "w") as f:
        f.write(_memory(snapshot))


def _set_allocator_settings(env: str):
    return torch._C._cuda_cudaCachingAllocator_set_allocator_settings(env)


def get_allocator_backend() -> str:
    r"""Return a string describing the active allocator backend as set by
    ``PYTORCH_CUDA_ALLOC_CONF``. Currently available backends are
    ``native`` (PyTorch's native caching allocator) and ``cudaMallocAsync``
    (CUDA's built-in asynchronous allocator).

    .. note::
        See :ref:`cuda-memory-management` for details on choosing the allocator backend.
    """
    return torch._C._cuda_getAllocatorBackend()


class _CUDAAllocator:
    r"""Wrapper over internal CUDA memory allocators."""

    def __init__(self, allocator: torch._C._cuda_CUDAAllocator):
        self._allocator = allocator

    def allocator(self):
        return self._allocator


class CUDAPluggableAllocator(_CUDAAllocator):
    r"""CUDA memory allocator loaded from a `.so` file."""

    def __init__(self, path_to_so_file: str, alloc_fn_name: str, free_fn_name: str):
        r"""Memory allocators are compiled in `.so` files and loaded dynamically using ctypes.

        To change the active allocator use the :func:`torch.cuda.memory.change_current_allocator` function.

        Args:
            path_to_so_file(str): Path in the filesystem to the `.so` file containing
                the allocator functions
            alloc_fn_name(str): Name of the function to perform the memory allocation
                in the `.so` file. The signature must be:
                ``void* alloc_fn_name(ssize_t size, int device, cudaStream_t stream);``
            free_fn_name(str): Name of the function to perform the memory release
                in the `.so` file. The signature must be:
                ``void free_fn_name(void* ptr, size_t size, cudaStream_t stream);``

        .. warning::
            This is currently supported only on Unix operating systems.

        .. note::
            See :ref:`cuda-memory-management` for details on creating and using a custom allocator
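
        Example (illustrative sketch; ``alloc.so``, ``my_malloc`` and ``my_free`` are
        placeholders for a user-compiled allocator library)::

            >>> new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
            ...     "alloc.so", "my_malloc", "my_free"
            ... )
            >>> torch.cuda.memory.change_current_allocator(new_alloc)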
        """
        allocator = ctypes.CDLL(path_to_so_file)
        alloc_fn = ctypes.cast(getattr(allocator, alloc_fn_name), ctypes.c_void_p).value
        free_fn = ctypes.cast(getattr(allocator, free_fn_name), ctypes.c_void_p).value
        assert alloc_fn is not None
        assert free_fn is not None
        self._allocator = torch._C._cuda_customAllocator(alloc_fn, free_fn)


def change_current_allocator(allocator: _CUDAAllocator) -> None:
    r"""Change the currently used memory allocator to be the one provided.

    If the current allocator has already been used/initialized, this function will error.

    Args:
        allocator (torch.cuda.memory._CUDAAllocator): allocator to be set as the active one.

    .. note::
        See :ref:`cuda-memory-management` for details on creating and using a custom allocator
    """
    torch._C._cuda_changeCurrentAllocator(allocator.allocator())


def _get_current_allocator() -> _CUDAAllocator:
    r"""Return the allocator being currently used.

    .. note::
        See :ref:`cuda-memory-management` for details on creating and using a custom allocator
    """
    return _CUDAAllocator(torch._C._cuda_getAllocator())


class MemPool(_MemPool):
    r"""MemPool represents a pool of memory in a caching allocator. Currently,
    it's just the ID of the pool object maintained in the CUDACachingAllocator.

    Args:
        allocator(torch._C._cuda_CUDAAllocator, optional): a
            torch._C._cuda_CUDAAllocator object that can be used to
            define how memory gets allocated in the pool. If :attr:`allocator`
            is ``None`` (default), memory allocation follows the default/
            current configuration of the CUDACachingAllocator.

    """

    def __init__(self, allocator: Optional[_cuda_CUDAAllocator] = None):
        super().__init__(allocator, True)

    @property
    def id(self) -> Tuple[int, int]:
        r"""Returns the ID of this pool as a tuple of two ints."""
        return super().id

    @property
    def allocator(self) -> Optional[_cuda_CUDAAllocator]:
        r"""Returns the allocator this MemPool routes allocations to."""
        return super().allocator


class MemPoolContext(_MemPoolContext):
    r"""MemPoolContext holds the currently active pool and stashes the previous
    pool. On deletion it makes the previous pool active.

    Args:
        pool(torch.cuda.MemPool): a MemPool object to be made active so that
            allocations route to this pool.

    """

    def __init__(self, pool: MemPool):
        super().__init__(pool)

    @staticmethod
    def active_pool() -> Optional[_MemPool]:
        r"""Returns the active MemPool."""
        return _MemPoolContext.active_pool()


@contextlib.contextmanager
def use_mem_pool(pool: MemPool, device: Union[Device, int] = None):
    r"""A context manager that routes allocations to a given pool.

    Args:
        pool(torch.cuda.MemPool): a MemPool object to be made active so that
            allocations route to this pool.
        device (torch.device or int, optional): selected device. Uses MemPool on
            the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

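    Example (illustrative sketch; assumes a CUDA device is available)::

        >>> pool = torch.cuda.MemPool()
        >>> with torch.cuda.use_mem_pool(pool):
        ...     # allocations made inside the block are routed to ``pool``
        ...     x = torch.empty(1024, device="cuda")
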
    """
    ctx = MemPoolContext(pool)
    device_index = (
        torch.cuda.current_device() if device is None else _get_device_index(device)
    )
    _cuda_beginAllocateToPool(device_index, pool.id)
    try:
        yield
    finally:
        _cuda_endAllocateCurrentStreamToPool(device_index, pool.id)
        del ctx
