from typing import List, Dict, Tuple, Optional

import torch
from torch import Tensor
from torch.autograd.grad_mode import no_grad
from typing_extensions import TypeAlias


def _get_foreach_kernels_supported_devices() -> List[str]:
    r"""Return the device type list that supports foreach kernels."""
    return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]


def _get_fused_kernels_supported_devices() -> List[str]:
    r"""Return the device type list that supports fused kernels in optimizer."""
    return ["mps", "cuda", "xpu", "cpu", torch._C._get_privateuse1_backend_name()]


TensorListList: TypeAlias = List[List[Optional[Tensor]]]
Indices: TypeAlias = List[int]
_foreach_supported_types = [torch.Tensor]


# This utility function splits tensors into groups keyed by device and dtype, which is useful
# before sending tensors off to a foreach implementation, since foreach kernels require all
# tensors in a call to share one device and dtype.
# If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified:
#   - tensorlists CAN be None
#   - all tensors in the first specified list cannot be None
#   - given an index i, all specified tensorlist[i]s match in dtype and device
# with_indices (bool, optional): whether to track previous indices as the last list per dictionary entry.
#   It comes in handy if there are Nones or literals in the tensorlists that are getting scattered out.
#   Whereas mutating a tensor in the resulting split-up tensorlists WILL propagate changes back to the
#   original input tensorlists, changes to Nones/literals WILL NOT propagate, and manual propagation
#   may be necessary. See torch/optim/sgd.py for an example, and the usage sketch at the bottom of
#   this module.
@no_grad()
def _group_tensors_by_device_and_dtype(
    tensorlistlist: TensorListList,
    with_indices: bool = False,
) -> Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]]:
    return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)


def _device_has_foreach_support(device: torch.device) -> bool:
    # "cpu" is appended because foreach ops also have a CPU implementation; foreach support
    # is reported as unavailable when scripting with TorchScript.
    return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting()


def _has_foreach_support(tensors: List[Optional[Tensor]], device: torch.device) -> bool:
    # The exact-type check intentionally excludes tensor subclasses, which may not
    # implement the foreach kernels.
    return _device_has_foreach_support(device) and all(
        t is None or type(t) in _foreach_supported_types for t in tensors
    )
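
# Minimal usage sketch (illustrative only, not part of the module's surface): the
# tensors and the `lr` constant below are made up for demonstration. The unpacking
# pattern mirrors how callers such as torch/optim/sgd.py consume the grouped result.
if __name__ == "__main__":
    params = [torch.randn(2, 2), torch.randn(3, dtype=torch.float64)]
    grads = [torch.ones(2, 2), torch.ones(3, dtype=torch.float64)]
    lr = 0.1

    # Group matching (param, grad) pairs by (device, dtype) so each group can be
    # handed to a single foreach kernel. with_indices=True appends the original
    # positions as the last element of each value.
    grouped = _group_tensors_by_device_and_dtype([params, grads], with_indices=True)
    for (device, _dtype), ((dev_params, dev_grads), indices) in grouped.items():
        # `indices` maps each position in dev_params back to its index in `params`;
        # it is how callers propagate results for Nones/literals by hand.
        if _has_foreach_support(dev_params, device):
            # One kernel launch per group instead of one per tensor.
            torch._foreach_add_(dev_params, dev_grads, alpha=-lr)
        else:
            # Fallback: plain per-tensor updates.
            for p, g in zip(dev_params, dev_grads):
                p.add_(g, alpha=-lr)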