xref: /aosp_15_r20/external/pytorch/torch/_namedtensor_internals.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from collections import OrderedDict
3
4
5"""
6This file contains helper functions that implement experimental functionality
7for named tensors in python. All of these are experimental, unstable, and
8subject to change or deletion.
9"""
10
11
def check_serializing_named_tensor(tensor):
    """Refuse to serialize a tensor that carries dimension names.

    Returns ``None`` when ``tensor.has_names()`` is falsy.

    Raises:
        RuntimeError: if ``tensor.has_names()`` is truthy. The message tells
            the user to drop the names with ``tensor.rename(None)`` first.
    """
    if not tensor.has_names():
        return
    raise RuntimeError(
        "NYI: Named tensors don't support serialization. Please drop "
        "names via `tensor = tensor.rename(None)` before serialization."
    )
18
19
def build_dim_map(tensor):
    """Map each dimension of ``tensor`` to its current name, in dim order.

    The key is the dimension's name when it has one, and its integer index
    when it is unnamed (name ``None``). The value is the name itself, so it is
    ``None`` for unnamed dims. An ``OrderedDict`` is returned so that reading
    ``.values()`` yields the names in the original dimension order.
    """
    dim_map = OrderedDict()
    for position, dim_name in enumerate(tensor.names):
        # Unnamed dims cannot be keyed by name, so fall back to their index.
        key = position if dim_name is None else dim_name
        dim_map[key] = dim_name
    return dim_map
26
27
def unzip_namedshape(namedshape):
    """Split a named shape into parallel sequences of names and sizes.

    Args:
        namedshape: Either a mapping of ``{name: size}`` or a non-empty
            iterable of ``(name, size)`` pairs. The mapping may be an
            ``OrderedDict`` or a plain ``dict``, since dicts preserve
            insertion order. The iterable may be one-shot, such as a
            generator.

    Returns:
        A ``zip`` iterator yielding two tuples: first all names, then all sizes.

    Raises:
        RuntimeError: if ``namedshape`` is not iterable or is empty.
    """
    from collections.abc import Mapping

    if isinstance(namedshape, Mapping):
        # Iterating a mapping directly yields only its keys; we need the pairs.
        namedshape = namedshape.items()
    # A tuple always has __iter__, so checking for __iter__ alone is sufficient.
    if not hasattr(namedshape, "__iter__"):
        raise RuntimeError(
            f"Expected namedshape to be OrderedDict or iterable of tuples, got: {type(namedshape)}"
        )
    # Materialize so one-shot iterables (e.g. generators) support len() below
    # and can still be consumed by zip afterwards.
    namedshape = tuple(namedshape)
    if len(namedshape) == 0:
        raise RuntimeError("Expected namedshape to be non-empty.")
    return zip(*namedshape)
38
39
def namer_api_name(inplace):
    """Return the user-facing rename method name used in error messages.

    Returns ``"rename_"`` when ``inplace`` is truthy and ``"rename"`` otherwise.
    """
    return "rename_" if inplace else "rename"
46
def is_ellipsis(item):
    """Return True if ``item`` is the Ellipsis placeholder.

    Both the ``...`` object and the string ``"..."`` are accepted, so callers
    may spell the glob either way.
    """
    if item == Ellipsis:
        return True
    return item == "..."
50
def single_ellipsis_index(names, fn_name):
    """Locate the sole Ellipsis entry in ``names``.

    Args:
        names: Sequence of dimension names, possibly containing one ``...``
            (or ``"..."``) placeholder.
        fn_name: Name of the calling API, used only in the error message.

    Returns:
        The index of the Ellipsis, or ``None`` if ``names`` has none.

    Raises:
        RuntimeError: if ``names`` contains two or more Ellipsis entries.
    """
    positions = [pos for pos, entry in enumerate(names) if is_ellipsis(entry)]
    if not positions:
        return None
    if len(positions) > 1:
        raise RuntimeError(
            f"{fn_name}: More than one Ellipsis ('...') found in names ("
            f"{names}). This function supports up to one Ellipsis."
        )
    return positions[0]
61
62
def expand_single_ellipsis(numel_pre_glob, numel_post_glob, names):
    """Return the run of ``names`` that an Ellipsis stands for.

    Args:
        numel_pre_glob: Number of explicit names before the Ellipsis.
        numel_post_glob: Number of explicit names after the Ellipsis.
        names: The full sequence of names to glob from.

    Returns:
        ``names`` with its first ``numel_pre_glob`` and last
        ``numel_post_glob`` entries removed. When the explicit names
        outnumber ``names``, the result is empty; it does not wrap around to
        entries taken from the end of ``names``.
    """
    # Clamp the stop index at 0: a negative stop would be interpreted as an
    # offset from the end of `names` and return bogus entries.
    stop = max(len(names) - numel_post_glob, 0)
    return names[numel_pre_glob:stop]
65
66
67def replace_ellipsis_by_position(ellipsis_idx, names, tensor_names):
68    globbed_names = expand_single_ellipsis(
69        ellipsis_idx, len(names) - ellipsis_idx - 1, tensor_names
70    )
71    return names[:ellipsis_idx] + globbed_names + names[ellipsis_idx + 1 :]
72
73
def resolve_ellipsis(names, tensor_names, fn_name):
    """Expand a single ``...`` in ``names`` using entries from ``tensor_names``.

    ``names`` is returned unchanged (same object) when it contains no
    Ellipsis. If it contains more than one, ``single_ellipsis_index`` raises
    ``RuntimeError``, with ``fn_name`` in the message.
    """
    glob_at = single_ellipsis_index(names, fn_name)
    if glob_at is not None:
        return replace_ellipsis_by_position(glob_at, names, tensor_names)
    return names
82
83
def update_names_with_list(tensor, names, inplace):
    """Apply positional names to ``tensor`` (``tensor.rename(*names)``).

    A lone ``None`` (``tensor.rename(None)``) is forwarded to
    ``tensor._update_names`` as ``None``. Otherwise, any ``...`` in ``names``
    is first expanded against ``tensor.names`` and the result is forwarded.
    """
    clears_names = len(names) == 1 and names[0] is None
    if clears_names:
        # tensor.rename(None) is the documented way to drop all names.
        return tensor._update_names(None, inplace)
    resolved = resolve_ellipsis(names, tensor.names, namer_api_name(inplace))
    return tensor._update_names(resolved, inplace)
92
93
def update_names_with_mapping(tensor, rename_map, inplace):
    """Rename dims of ``tensor`` by name (``tensor.rename(**rename_map)``).

    Args:
        tensor: The tensor whose names are updated.
        rename_map: Mapping of existing dim name -> new name.
        inplace: Forwarded to ``tensor._update_names``. It also selects
            ``rename_`` vs ``rename`` for the error message.

    Returns:
        The result of ``tensor._update_names`` called with the full tuple of
        names, in dim order. Dims not mentioned in ``rename_map`` keep their
        current name.

    Raises:
        RuntimeError: if a key of ``rename_map`` is not a dim of ``tensor``.
    """
    dim_map = build_dim_map(tensor)
    # Iterate pairs directly instead of re-indexing rename_map per key. Keys of
    # dim_map stay the *old* names, so swaps like {N: C, C: N} work.
    for old_dim, new_dim in rename_map.items():
        if old_dim not in dim_map:
            raise RuntimeError(
                f"{namer_api_name(inplace)}: Tried to rename dim '{old_dim}' to dim "
                f"{new_dim} in Tensor[{tensor.names}] but dim '{old_dim}' does not exist"
            )
        dim_map[old_dim] = new_dim
    return tensor._update_names(tuple(dim_map.values()), inplace)
106
107
def update_names(tensor, names, rename_map, inplace):
    """Implement ``tensor.rename`` / ``tensor.rename_``. There are two usages.

    ``tensor.rename(*names)`` returns a view on ``tensor`` with named dims
    ``names``. ``names`` must have length ``tensor.dim()``, unless it contains
    ``'...'``. In that case the Ellipsis is expanded greedily to the
    corresponding names from ``tensor.names``.

    For example,
    ```
    >>> # xdoctest: +SKIP
    >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
    >>> x.rename('...', 'height', 'width').names
    ('N', 'C', 'height', 'width')

    >>> # xdoctest: +SKIP
    >>> x.rename('batch', '...', 'width').names
    ('batch', 'C', 'H', 'width')

    ```

    ``tensor.rename(**rename_map)`` returns a view on ``tensor`` that has
    renamed dims as specified in the mapping ``rename_map``.

    For example,
    ```
    >>> # xdoctest: +SKIP
    >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
    >>> x.rename(W='width', H='height').names
    ('N', 'C', 'height', 'width')

    ```

    Finally, tensor.rename has an in-place version called tensor.rename_.

    Raises:
        RuntimeError: if both positional ``names`` and a non-empty
            ``rename_map`` are given.
    """
    has_names = len(names) > 0
    has_rename_pairs = bool(rename_map)
    if has_names and has_rename_pairs:
        api_name = namer_api_name(inplace)
        raise RuntimeError(
            f"{api_name}: This function takes either positional "
            f"args or keyword args, but not both. Use tensor.{api_name}(*names) "
            f"to name dims and tensor.{api_name}(**rename_map) to rename "
            "dims."
        )

    # Only one of the two forms can be present here.
    if has_rename_pairs:
        return update_names_with_mapping(tensor, rename_map, inplace)
    # Positional names. This also covers the empty tensor.rename(*[]) call,
    # which is valid for a 0-dim tensor.
    return update_names_with_list(tensor, names, inplace)
160