xref: /aosp_15_r20/external/pytorch/torch/xpu/memory.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import collections
2from typing import Any, Dict, Union
3
4import torch
5from torch.types import Device
6
7from . import _get_device_index, is_initialized
8
9
# Device specifiers accepted by this module's functions: a torch.device, a
# device string (e.g. "xpu:0"), an integer index, or None for the current device.
_device_t = Union[Device, str, int, None]
11
12
def empty_cache() -> None:
    r"""Release all unoccupied cached memory currently held by the caching
    allocator so that those can be used in other XPU application.

    .. note::
        :func:`~torch.xpu.empty_cache` doesn't increase the amount of XPU
        memory available for PyTorch. However, it may help reduce fragmentation
        of XPU memory in certain cases.
    """
    # Nothing to release if the XPU runtime was never initialized.
    if not is_initialized():
        return
    torch._C._xpu_emptyCache()
24
25
def reset_peak_memory_stats(device: _device_t = None) -> None:
    r"""Reset the "peak" stats tracked by the XPU memory allocator.

    See :func:`~torch.xpu.memory_stats` for details. Peak stats correspond to the
    `"peak"` key in each individual stat dict.

    Args:
        device (torch.device or int or str, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    device = _get_device_index(device, optional=True)
    # The signature promises ``-> None``; don't propagate the C binding's
    # return value (previously this was ``return torch._C...``, which
    # contradicted the annotation).
    torch._C._xpu_resetPeakMemoryStats(device)
39
40
def reset_accumulated_memory_stats(device: _device_t = None) -> None:
    r"""Reset the "accumulated" (historical) stats tracked by the XPU memory allocator.

    See :func:`~torch.xpu.memory_stats` for details. Accumulated stats correspond to
    the `"allocated"` and `"freed"` keys in each individual stat dict.

    Args:
        device (torch.device or int or str, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    device = _get_device_index(device, optional=True)
    # The signature promises ``-> None``; don't propagate the C binding's
    # return value (previously this was ``return torch._C...``, which
    # contradicted the annotation).
    torch._C._xpu_resetAccumulatedMemoryStats(device)
54
55
def memory_stats_as_nested_dict(device: _device_t = None) -> Dict[str, Any]:
    r"""Return the result of :func:`~torch.xpu.memory_stats` as a nested dictionary."""
    # An uninitialized runtime has no allocator state to report.
    if is_initialized():
        return torch._C._xpu_memoryStats(_get_device_index(device, optional=True))
    return {}
62
63
def memory_stats(device: _device_t = None) -> Dict[str, Any]:
    r"""Return a dictionary of XPU memory allocator statistics for a given device.

    The return value of this function is a dictionary of statistics, each of
    which is a non-negative integer.

    Core statistics:

    - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of allocated memory.
    - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of reserved memory.
    - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of active memory.
    - ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      memory requested by client code, compare this with allocated_bytes to check if
      allocation rounding adds too much overhead.

    For these core statistics, values are broken down as follows.

    Pool type:

    - ``all``: combined statistics across all memory pools.
    - ``large_pool``: statistics for the large allocation pool (for size >= 1MB allocations).
    - ``small_pool``: statistics for the small allocation pool (for size < 1MB allocations).

    Metric type:

    - ``current``: current value of this metric.
    - ``peak``: maximum value of this metric.
    - ``allocated``: historical total increase in this metric.
    - ``freed``: historical total decrease in this metric.

    Args:
        device (torch.device or int or str, optional): selected device. Returns
            statistics for the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).
    """

    def _flatten(prefix: str, node: Any):
        # Leaves become ("dotted.path", value) pairs; dicts recurse with the
        # key appended to the dotted prefix.
        if not isinstance(node, dict):
            yield prefix, node
            return
        sep = "." if prefix else ""
        for key, child in node.items():
            yield from _flatten(prefix + sep + key, child)

    nested = memory_stats_as_nested_dict(device=device)
    flat = sorted(_flatten("", nested))
    return collections.OrderedDict(flat)
118
119
def memory_allocated(device: _device_t = None) -> int:
    r"""Return the current GPU memory occupied by tensors in bytes for a given device.

    Args:
        device (torch.device or int or str, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        This is likely less than the amount shown in `xpu-smi` since some
        unused memory can be held by the caching allocator and some context
        needs to be created on GPU.
    """
    stats = memory_stats(device=device)
    # Missing key (e.g. uninitialized runtime) reads as zero bytes.
    return stats.get("allocated_bytes.all.current", 0)
134
135
def max_memory_allocated(device: _device_t = None) -> int:
    r"""Return the maximum GPU memory occupied by tensors in bytes for a given device.

    By default, this returns the peak allocated memory since the beginning of
    this program. :func:`~torch.xpu.reset_peak_memory_stats` can be used to
    reset the starting point in tracking this metric. For example, these two
    functions can measure the peak allocated memory usage of each iteration in a
    training loop.

    Args:
        device (torch.device or int or str, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    stats = memory_stats(device=device)
    # Missing key (e.g. uninitialized runtime) reads as zero bytes.
    return stats.get("allocated_bytes.all.peak", 0)
151
152
def memory_reserved(device: _device_t = None) -> int:
    r"""Return the current GPU memory managed by the caching allocator in bytes for a given device.

    Args:
        device (torch.device or int or str, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    stats = memory_stats(device=device)
    # Missing key (e.g. uninitialized runtime) reads as zero bytes.
    return stats.get("reserved_bytes.all.current", 0)
162
163
def max_memory_reserved(device: _device_t = None) -> int:
    r"""Return the maximum GPU memory managed by the caching allocator in bytes for a given device.

    By default, this returns the peak cached memory since the beginning of this
    program. :func:`~torch.xpu.reset_peak_memory_stats` can be used to reset
    the starting point in tracking this metric. For example, these two functions
    can measure the peak cached memory amount of each iteration in a training
    loop.

    Args:
        device (torch.device or int or str, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    stats = memory_stats(device=device)
    # Missing key (e.g. uninitialized runtime) reads as zero bytes.
    return stats.get("reserved_bytes.all.peak", 0)
179
180
# Public API of this module; kept alphabetically sorted.
__all__ = [
    "empty_cache",
    "max_memory_allocated",
    "max_memory_reserved",
    "memory_allocated",
    "memory_reserved",
    "memory_stats",
    "memory_stats_as_nested_dict",
    "reset_accumulated_memory_stats",
    "reset_peak_memory_stats",
]
192