1# mypy: allow-untyped-defs 2from functools import lru_cache as _lru_cache 3from typing import Optional, TYPE_CHECKING 4 5import torch 6from torch.library import Library as _Library 7 8 9__all__ = ["is_built", "is_available", "is_macos13_or_newer", "is_macos_or_newer"] 10 11 12def is_built() -> bool: 13 r"""Return whether PyTorch is built with MPS support. 14 15 Note that this doesn't necessarily mean MPS is available; just that 16 if this PyTorch binary were run a machine with working MPS drivers 17 and devices, we would be able to use it. 18 """ 19 return torch._C._has_mps 20 21 22@_lru_cache 23def is_available() -> bool: 24 r"""Return a bool indicating if MPS is currently available.""" 25 return torch._C._mps_is_available() 26 27 28@_lru_cache 29def is_macos_or_newer(major: int, minor: int) -> bool: 30 r"""Return a bool indicating whether MPS is running on given MacOS or newer.""" 31 return torch._C._mps_is_on_macos_or_newer(major, minor) 32 33 34@_lru_cache 35def is_macos13_or_newer(minor: int = 0) -> bool: 36 r"""Return a bool indicating whether MPS is running on MacOS 13 or newer.""" 37 return torch._C._mps_is_on_macos_or_newer(13, minor) 38 39 40_lib: Optional[_Library] = None 41 42 43def _init(): 44 r"""Register prims as implementation of var_mean and group_norm.""" 45 global _lib 46 47 if _lib is not None or not is_built(): 48 return 49 50 from torch._decomp.decompositions import native_group_norm_backward 51 from torch._refs import native_group_norm 52 53 _lib = _Library("aten", "IMPL") # noqa: TOR901 54 _lib.impl("native_group_norm", native_group_norm, "MPS") 55 _lib.impl("native_group_norm_backward", native_group_norm_backward, "MPS") 56