xref: /aosp_15_r20/external/pytorch/functorch/einops/rearrange.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import functools
4from typing import Callable, Dict, List, Sequence, Tuple, Union
5
6import torch
7from functorch._C import dim as _C
8
9from ._parsing import (
10    _ellipsis,
11    AnonymousAxis,
12    comma_separate,
13    parse_pattern,
14    validate_rearrange_expressions,
15)
16
17
# Public API of this module.
__all__ = ["rearrange"]

# Module-level alias for the first-class-dimension factory. The source code generated at runtime by
# `_create_rearrange_callable` calls `dims(n)` and is executed against this module's globals, so this name must
# stay defined at module scope.
dims = _C.dims
21
22
@functools.lru_cache(256)
def _create_rearrange_callable(
    tensor_ndim: int, pattern: str, **axes_lengths: int
) -> Callable[[torch.Tensor], torch.Tensor]:
    r"""Translate an `einops`-style pattern into a callable that performs the rearrange using first-class dimensions.

    Since an equivalent result is computed for tensors with the same number of dimensions, with the same pattern and
    specified axes lengths, this function can be memoized.

    Args:
        tensor_ndim (int): the number of dimensions in the tensor to rearrange
        pattern (str): the `einops`-style rearrangement pattern
        axes_lengths (int): any additional length specifications for dimensions

    Returns:
        Callable[[torch.Tensor], torch.Tensor]: a callable that performs the rearrangement

    Raises:
        ValueError: if the left-hand side of the pattern needs more tensor dimensions than ``tensor_ndim`` (when it
            contains an ellipsis) or a different number of them (when it does not), or if the parsed composition
            contains an unexpected entry. ``parse_pattern`` / ``validate_rearrange_expressions`` may raise their own
            errors for malformed patterns.
    """
    left, right = parse_pattern(pattern, axes_lengths)
    validate_rearrange_expressions(left, right, axes_lengths)

    # unitary anonymous axes are represented as empty lists in the composition (see the mapping loop below)
    n_anon_dims = sum(not dim for dim in left.composition)
    if left.has_ellipsis:
        # every composition entry other than the ellipsis consumes exactly one tensor dimension; the ellipsis
        # absorbs whatever is left over
        n_ellipsis_dims = tensor_ndim - (len(left.composition) - 1)
        # the `- 1` discounts the ellipsis entry among the identifiers
        n_named_dims = len(left.identifiers) - 1

        # Count consumed tensor dimensions (composition entries), not identifiers: a composed axis such as `(h w)`
        # names two identifiers but consumes a single tensor dimension, so counting identifiers would wrongly
        # reject valid patterns like "... (h w) -> ... h w" on a 1-d tensor.
        if (pattern_ndim := len(left.composition) - 1) > tensor_ndim:
            raise ValueError(
                f"Number of dimensions in pattern ({pattern_ndim}) must be less than or equal to the number of "
                f"dimensions in the tensor ({tensor_ndim})"
            )
    else:
        n_ellipsis_dims = 0
        n_named_dims = len(left.identifiers)

        if (pattern_ndim := len(left.composition)) != tensor_ndim:
            raise ValueError(
                f"Number of dimensions in pattern ({pattern_ndim}) must be equal to the number of dimensions in "
                f"the tensor ({tensor_ndim})"
            )
    # one first-class dim per named identifier, per ellipsis-covered dimension and per unitary anonymous axis
    n_dims = n_named_dims + n_ellipsis_dims + n_anon_dims

    if n_dims == 0:
        # an identity rearrangement on a 0-dimension tensor
        return lambda tensor: tensor

    first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
    identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {}
    anon_axes: List[AnonymousAxis] = []

    # map the left-hand side identifiers to strings representing first class dims
    dims_i = 0
    for dimension in left.composition:
        if isinstance(dimension, list):
            for identifier in dimension:
                # non-unitary anon axes are not allowed in rearrange & unitary anon axes are represented as empty lists
                assert isinstance(identifier, str)
                identifier_dim_map[identifier] = (first_class_dims[dims_i],)
                dims_i += 1
            if not dimension:
                # unitary anonymous axis: give it its own first-class dim and record it in the (mutated) composition
                # so that `composition_to_dims` below can look it up
                anon_axis = AnonymousAxis("1")
                identifier_dim_map[anon_axis] = (first_class_dims[dims_i],)
                anon_axes.append(anon_axis)
                dimension.append(anon_axis)
                dims_i += 1
        elif dimension == _ellipsis:
            identifier = _ellipsis
            identifier_dim_map[identifier] = tuple(
                first_class_dims[dims_i + j] for j in range(n_ellipsis_dims)
            )
            dims_i += n_ellipsis_dims
        else:
            raise ValueError(f"Unexpected dimension: {dimension}")

    def composition_to_dims(
        composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]]
    ) -> List[Union[str, Tuple[str, ...]]]:
        """Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
        class dims."""
        dim_composition: List[Union[str, Tuple[str, ...]]] = []
        for dimension in composition:
            if isinstance(dimension, list):
                dim_composition.append(
                    tuple(
                        dim
                        for identifier in dimension
                        for dim in identifier_dim_map[identifier]
                    )
                )
            elif dimension == _ellipsis:
                dim_composition.extend(identifier_dim_map[_ellipsis])
            else:
                raise ValueError(f"Unexpected dimension: {dimension}")
        return dim_composition

    left_dims = composition_to_dims(left.composition)
    right_dims = composition_to_dims(right.composition)
    anon_dims = tuple(identifier_dim_map[axis][0] for axis in anon_axes)
    specified_lengths = tuple(
        (identifier_dim_map[axis][0], length) for axis, length in axes_lengths.items()
    )

    custom_rearrange_callable_name = "do_rearrange"
    custom_rearrange_callable_code = (
        (
            f"def {custom_rearrange_callable_name}(tensor):\n"
            f"    {comma_separate(first_class_dims)} = dims({n_dims})\n"
        )
        + (
            "".join(
                f"    {dim}.size = {length}\n" for (dim, length) in specified_lengths
            )
            if specified_lengths
            else ""
        )
        + f"    tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n"
        + (
            f"    return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n"
            if anon_dims
            else "    return tensor\n"
        )
    )

    # Execute into an explicit namespace rather than reading the result back via `locals()`: since Python 3.13
    # (PEP 667) `locals()` inside a function returns a fresh snapshot, so a name bound by a bare `exec` would not be
    # found there. Module globals are passed so the generated function still resolves `dims`.
    namespace: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = {}
    exec(custom_rearrange_callable_code, globals(), namespace)
    return namespace[custom_rearrange_callable_name]
148
149
def rearrange(
    tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
    pattern: str,
    **axes_lengths: int,
) -> torch.Tensor:
    r"""A native implementation of `einops.rearrange`, a reader-friendly smart element reordering for multidimensional
    tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
    stack, concatenate and other operations.

    See: https://einops.rocks/api/rearrange/

    Args:
        tensor (Tensor or sequence of Tensor): the tensor to rearrange; a list or tuple of tensors is first stacked
            along a new leading dimension with ``torch.stack``
        pattern (str): the rearrangement pattern
        axes_lengths (int): any additional length specifications for dimensions

    Returns:
        Tensor: the rearranged tensor

    Examples:
        >>> # suppose we have a set of 32 images in "h w c" format (height-width-channel)
        >>> images = torch.randn((32, 30, 40, 3))

        >>> # identity pattern: the batch of images comes back with the same shape
        >>> rearrange(images, 'b h w c -> b h w c').shape
        torch.Size([32, 30, 40, 3])

        >>> # a list of tensors is stacked along a new first (batch) axis
        >>> rearrange([images[0], images[1]], 'b h w c -> b h w c').shape
        torch.Size([2, 30, 40, 3])

        >>> # concatenate images along height (vertical axis), 960 = 32 * 30
        >>> rearrange(images, 'b h w c -> (b h) w c').shape
        torch.Size([960, 40, 3])

        >>> # concatenate images along width (horizontal axis), 1280 = 32 * 40
        >>> rearrange(images, 'b h w c -> h (b w) c').shape
        torch.Size([30, 1280, 3])

        >>> # reorder axes to "b c h w" format for deep learning
        >>> rearrange(images, 'b h w c -> b c h w').shape
        torch.Size([32, 3, 30, 40])

        >>> # flatten each image into a vector, 3600 = 30 * 40 * 3
        >>> rearrange(images, 'b h w c -> b (c h w)').shape
        torch.Size([32, 3600])

        >>> # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
        >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
        torch.Size([128, 15, 20, 3])

        >>> # space-to-depth operation
        >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
        torch.Size([32, 15, 20, 12])
    """
    # sequences of tensors are combined into one tensor before the pattern is applied
    source = tensor if isinstance(tensor, torch.Tensor) else torch.stack(tensor)

    # the translated callable is memoized per (ndim, pattern, axes_lengths)
    do_rearrange = _create_rearrange_callable(source.ndim, pattern, **axes_lengths)
    return do_rearrange(source)
209