# mypy: allow-untyped-defs
r"""
This package provides an interface for accessing the MPS (Metal Performance Shaders) backend in Python.
Metal is Apple's API for programming the GPU (graphics processing unit). Using MPS allows increased
performance by running work on the Metal GPU(s).
See https://developer.apple.com/documentation/metalperformanceshaders for more details.
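
A minimal usage sketch (assumes an MPS-enabled macOS build of PyTorch)::

    import torch

    if torch.mps.is_available():
        x = torch.randn(3, device="mps")  # allocate directly on the MPS device
        torch.mps.synchronize()           # wait for queued kernels to finish
        print(x.cpu())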
"""
from typing import Union

import torch
from torch import Tensor


_is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False)
_default_mps_generator: torch._C.Generator = None  # type: ignore[assignment]


# local helper function (not public or exported)
def _get_default_mps_generator() -> torch._C.Generator:
    global _default_mps_generator
    if _default_mps_generator is None:
        _default_mps_generator = torch._C._mps_get_default_generator()
    return _default_mps_generator


def device_count() -> int:
    r"""Returns the number of available MPS devices."""
    return int(torch._C._has_mps and torch._C._mps_is_available())


def synchronize() -> None:
    r"""Waits for all kernels in all streams on an MPS device to complete."""
    return torch._C._mps_deviceSynchronize()


def get_rng_state(device: Union[int, str, torch.device] = "mps") -> Tensor:
    r"""Returns the random number generator state as a ByteTensor.

    Args:
        device (torch.device or int, optional): The device to return the RNG state of.
            Default: ``'mps'`` (i.e., ``torch.device('mps')``, the current MPS device).
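
    Example (a minimal sketch, assuming MPS is available)::

        >>> state = torch.mps.get_rng_state()  # snapshot the current generator state
        >>> _ = torch.randn(2, device="mps")   # advance the generator
        >>> torch.mps.set_rng_state(state)     # restore the snapshot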
    """
    return _get_default_mps_generator().get_state()


def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "mps"
) -> None:
    r"""Sets the random number generator state.

    Args:
        new_state (torch.ByteTensor): The desired state.
        device (torch.device or int, optional): The device to set the RNG state for.
            Default: ``'mps'`` (i.e., ``torch.device('mps')``, the current MPS device).
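
    Example (a minimal sketch, assuming MPS is available)::

        >>> saved = torch.mps.get_rng_state()
        >>> a = torch.randn(2, device="mps")
        >>> torch.mps.set_rng_state(saved)     # rewind the generator
        >>> b = torch.randn(2, device="mps")   # b reproduces a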
    """
    new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
    _get_default_mps_generator().set_state(new_state_copy)


def manual_seed(seed: int) -> None:
    r"""Sets the seed for generating random numbers.

    Args:
        seed (int): The desired seed.
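
    Example (a minimal sketch, assuming MPS is available)::

        >>> torch.mps.manual_seed(42)
        >>> a = torch.randn(2, device="mps")
        >>> torch.mps.manual_seed(42)
        >>> b = torch.randn(2, device="mps")   # same values as a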
    """
    # torch.mps.manual_seed() can be called from the global
    # torch.manual_seed() in torch/random.py, so we need to make
    # sure MPS is available (otherwise we just return without
    # erroring out).
    if not torch._C._has_mps:
        return
    seed = int(seed)
    _get_default_mps_generator().manual_seed(seed)


def seed() -> None:
    r"""Sets the seed for generating random numbers to a random number."""
    _get_default_mps_generator().seed()


def empty_cache() -> None:
    r"""Releases all unoccupied cached memory currently held by the caching
    allocator so that it can be used by other GPU applications.
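
    Example (a minimal sketch, assuming MPS is available)::

        >>> x = torch.empty(1024, 1024, device="mps")
        >>> del x                        # the block goes back to the caching allocator
        >>> torch.mps.empty_cache()      # return cached blocks to the system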
    """
    torch._C._mps_emptyCache()


def set_per_process_memory_fraction(fraction) -> None:
    r"""Set the memory fraction for limiting the process's memory allocation on the MPS device.
    The allowed value equals the fraction multiplied by the recommended maximum device memory
    (obtained from Metal API device.recommendedMaxWorkingSetSize).
    If a process tries to allocate more than the allowed value, an out of
    memory error is raised by the allocator.

    Args:
        fraction (float): Range: 0~2. Allowed memory equals total_memory * fraction.

    .. note::
       Passing 0 to fraction means unlimited allocations
       (may cause system failure if out of memory).
       Passing a fraction greater than 1.0 allows limits beyond the value
       returned from device.recommendedMaxWorkingSetSize.
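
    Example (a minimal sketch)::

        >>> # cap this process at half of the recommended working-set size
        >>> torch.mps.set_per_process_memory_fraction(0.5)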
    """

    if not isinstance(fraction, float):
        raise TypeError("Invalid type for fraction argument, must be `float`")
    if fraction < 0 or fraction > 2:
        raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~2")

    torch._C._mps_setMemoryFraction(fraction)


def current_allocated_memory() -> int:
    r"""Returns the current GPU memory occupied by tensors in bytes.

    .. note::
       The returned size does not include cached allocations in
       memory pools of MPSAllocator.
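
    Example (a minimal sketch, assuming MPS is available)::

        >>> x = torch.empty(1024, 1024, device="mps")
        >>> live = torch.mps.current_allocated_memory()   # bytes held by live tensors
        >>> total = torch.mps.driver_allocated_memory()   # also counts cached pool memory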
    """
    return torch._C._mps_currentAllocatedMemory()


def driver_allocated_memory() -> int:
    r"""Returns total GPU memory allocated by Metal driver for the process in bytes.

    .. note::
       The returned size includes cached allocations in MPSAllocator pools
       as well as allocations from MPS/MPSGraph frameworks.
    """
    return torch._C._mps_driverAllocatedMemory()


def recommended_max_memory() -> int:
    r"""Returns the recommended maximum working set size for GPU memory in bytes.

    .. note::
       The value is the recommended maximum working set size for Metal,
       as returned from device.recommendedMaxWorkingSetSize.
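
    Example (a minimal sketch)::

        >>> budget = torch.mps.recommended_max_memory()          # device budget in bytes
        >>> frac = torch.mps.driver_allocated_memory() / budget  # fraction already in use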
    """
    return torch._C._mps_recommendedMaxMemory()


def is_available() -> bool:
    r"""Returns a bool indicating if MPS is currently available."""
    return device_count() > 0


from . import profiler
from .event import Event


__all__ = [
    "device_count",
    "get_rng_state",
    "manual_seed",
    "seed",
    "set_rng_state",
    "synchronize",
    "empty_cache",
    "set_per_process_memory_fraction",
    "current_allocated_memory",
    "driver_allocated_memory",
    "Event",
    "profiler",
    "recommended_max_memory",
    "is_available",
]