xref: /aosp_15_r20/external/pytorch/torch/backends/mps/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from functools import lru_cache as _lru_cache
3from typing import Optional, TYPE_CHECKING
4
5import torch
6from torch.library import Library as _Library
7
8
9__all__ = ["is_built", "is_available", "is_macos13_or_newer", "is_macos_or_newer"]
10
11
12def is_built() -> bool:
13    r"""Return whether PyTorch is built with MPS support.
14
15    Note that this doesn't necessarily mean MPS is available; just that
16    if this PyTorch binary were run a machine with working MPS drivers
17    and devices, we would be able to use it.
18    """
19    return torch._C._has_mps
20
21
@_lru_cache
def is_available() -> bool:
    r"""Report whether the MPS backend can be used right now.

    The answer comes from the C extension and is memoized by the
    ``lru_cache`` decorator, so only the first call queries it.
    """
    mps_usable = torch._C._mps_is_available()
    return mps_usable
26
27
@_lru_cache
def is_macos_or_newer(major: int, minor: int) -> bool:
    r"""Report whether MPS is running on macOS ``major.minor`` or a later release.

    The version check is delegated to the C extension; results are memoized
    per call arguments by the ``lru_cache`` decorator.
    """
    on_requested_or_newer = torch._C._mps_is_on_macos_or_newer(major, minor)
    return on_requested_or_newer
32
33
@_lru_cache
def is_macos13_or_newer(minor: int = 0) -> bool:
    r"""Report whether MPS is running on macOS 13.``minor`` or a later release.

    Shorthand for a macOS 13 version check. ``minor`` defaults to ``0``, so
    the bare call asks about macOS 13.0 or newer. Results are memoized by the
    ``lru_cache`` decorator.
    """
    macos_major = 13
    return torch._C._mps_is_on_macos_or_newer(macos_major, minor)
38
39
# Library handle that owns the MPS group-norm registrations made by ``_init``.
# It stays ``None`` until ``_init`` registers them; a non-None value also marks
# initialization as done, so repeated ``_init`` calls are no-ops.
_lib: Optional[_Library] = None


def _init():
    r"""Register Python implementations of group norm for the MPS backend.

    Binds ``torch._refs.native_group_norm`` as the ``MPS`` implementation of
    ``aten::native_group_norm``. Also binds
    ``torch._decomp.decompositions.native_group_norm_backward`` as the ``MPS``
    implementation of ``aten::native_group_norm_backward``.

    The call is idempotent. It returns immediately, registering nothing, if a
    previous call already registered the implementations or if PyTorch was
    built without MPS support.
    """
    global _lib

    # Skip on non-MPS builds, or if a previous call already registered.
    if _lib is not None or not is_built():
        return

    # Imported lazily; presumably to avoid import cycles / import-time cost
    # when the MPS backend is never initialized (NOTE(review): confirm).
    from torch._decomp.decompositions import native_group_norm_backward
    from torch._refs import native_group_norm

    # Kept at module scope: it doubles as the "already initialized" flag above,
    # and presumably must stay referenced for the registrations to persist.
    _lib = _Library("aten", "IMPL")  # noqa: TOR901
    _lib.impl("native_group_norm", native_group_norm, "MPS")
    _lib.impl("native_group_norm_backward", native_group_norm_backward, "MPS")
56