from typing import List, Dict, Tuple, Optional

import torch
from torch import Tensor
from torch.autograd.grad_mode import no_grad
from typing_extensions import TypeAlias

def _get_foreach_kernels_supported_devices() -> List[str]:
    r"""Return the list of device types that support foreach kernels."""
    return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]

def _get_fused_kernels_supported_devices() -> List[str]:
    r"""Return the list of device types that support fused kernels in optimizers."""
    return ["mps", "cuda", "xpu", "cpu", torch._C._get_privateuse1_backend_name()]

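# Usage sketch (hypothetical caller, not part of this module): an optimizer can gate its fast
# paths on these lists, e.g.
#
#     if param.device.type in _get_fused_kernels_supported_devices():
#         ...  # dispatch to the fused optimizer kernel
#     elif param.device.type in _get_foreach_kernels_supported_devices():
#         ...  # dispatch to the foreach (multi-tensor) implementation
#     else:
#         ...  # fall back to the single-tensor for-loop path
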
TensorListList: TypeAlias = List[List[Optional[Tensor]]]
Indices: TypeAlias = List[int]
_foreach_supported_types = [torch.Tensor]


# This util function splits tensors into groups by device and dtype, which is useful before sending
# tensors off to a foreach implementation, which requires all tensors to be on the same device and
# of the same dtype.
# If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified:
#   - tensorlists CAN be None
#   - all tensors in the first specified list cannot be None
#   - given an index i, all specified tensorlist[i]s match in dtype and device
# with_indices (bool, optional): whether to also track each tensor's original index, returned per
#   dictionary entry alongside the split-up tensorlists. This comes in handy when there are Nones or
#   literals in the tensorlists that are getting scattered out: whereas mutating a tensor in the
#   resulting split-up tensorlists WILL propagate changes back to the original input tensorlists,
#   replacing Nones/literals WILL NOT propagate, so manual propagation may be necessary. See
#   torch/optim/sgd.py for an example.
@no_grad()
def _group_tensors_by_device_and_dtype(
    tensorlistlist: TensorListList,
    with_indices: bool = False,
) -> Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]]:
    return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)

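# Shape sketch (hypothetical tensors, following the return annotation above): given two parallel
# tensorlists where p0/g0 are float32 CUDA tensors and p1/g1 are float16 CPU tensors,
#
#     grouped = _group_tensors_by_device_and_dtype([[p0, p1], [g0, g1]], with_indices=True)
#
# would map each (device, dtype) pair to its slice of the inputs plus the original positions,
# roughly:
#
#     (device('cuda:0'), torch.float32) -> ([[p0], [g0]], [0])
#     (device('cpu'), torch.float16)    -> ([[p1], [g1]], [1])
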
def _device_has_foreach_support(device: torch.device) -> bool:
    return (
        device.type in (_get_foreach_kernels_supported_devices() + ["cpu"])
        and not torch.jit.is_scripting()
    )


def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
    return _device_has_foreach_support(device) and all(
        t is None or type(t) in _foreach_supported_types for t in tensors
    )

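# Usage sketch (hypothetical caller, not part of this module): utilities such as
# torch.nn.utils.clip_grad_norm_ consult this helper before taking a torch._foreach_* fast path,
# along the lines of
#
#     if _has_foreach_support(grads, device):
#         norms = torch._foreach_norm(grads)  # one multi-tensor kernel launch
#     else:
#         norms = [torch.linalg.vector_norm(g) for g in grads]  # per-tensor fallback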