xref: /aosp_15_r20/external/pytorch/functorch/einops/rearrange.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport functools
4*da0073e9SAndroid Build Coastguard Workerfrom typing import Callable, Dict, List, Sequence, Tuple, Union
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Workerfrom functorch._C import dim as _C
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerfrom ._parsing import (
10*da0073e9SAndroid Build Coastguard Worker    _ellipsis,
11*da0073e9SAndroid Build Coastguard Worker    AnonymousAxis,
12*da0073e9SAndroid Build Coastguard Worker    comma_separate,
13*da0073e9SAndroid Build Coastguard Worker    parse_pattern,
14*da0073e9SAndroid Build Coastguard Worker    validate_rearrange_expressions,
15*da0073e9SAndroid Build Coastguard Worker)
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Worker
__all__ = ["rearrange"]

# Module-level alias for the first-class dimension factory from `functorch._C.dim`; the source code generated by
# `_create_rearrange_callable` calls `dims(n)` to create the first-class dims it binds the tensor to.
dims = _C.dims
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker
@functools.lru_cache(256)
def _create_rearrange_callable(
    tensor_ndim: int, pattern: str, **axes_lengths: int
) -> Callable[[torch.Tensor], torch.Tensor]:
    r"""Translate an `einops`-style pattern into a callable that performs the rearrange using first-class dimensions.

    Since an equivalent result is computed for tensors with the same number of dimensions, with the same pattern and
    specified axes lengths, this function can be memoized.

    The returned callable is generated as Python source code which binds the input tensor to freshly created
    first-class dims (splitting grouped axes such as ``(h1 h)``), reorders them according to the right-hand side of
    the pattern, and finally sums away any unitary anonymous axes that appeared on the left-hand side.

    Args:
        tensor_ndim (int): the number of dimensions in the tensor to rearrange
        pattern (str): the `einops`-style rearrangement pattern
        axes_lengths (int): any additional length specifications for dimensions

    Returns:
        Callable[[torch.Tensor], torch.Tensor]: a callable that performs the rearrangement

    Raises:
        ValueError: if the left-hand side of the pattern cannot be matched against a tensor with ``tensor_ndim``
            dimensions, or if the parsed left-hand side contains an unexpected dimension entry
    """
    left, right = parse_pattern(pattern, axes_lengths)
    validate_rearrange_expressions(left, right, axes_lengths)

    # an empty group in the left composition is a unitary anonymous axis: it binds one tensor dimension and gets its
    # own first-class dim, which is summed away at the end of the generated callable
    n_anon_dims = sum(not dim for dim in left.composition)
    if left.has_ellipsis:
        # Every composition entry other than the ellipsis binds exactly one tensor dimension (a parenthesized group
        # such as "(h1 h)" is a single tensor dimension even though it introduces several first-class dims), and the
        # ellipsis absorbs the remaining, possibly zero, dimensions. Counting identifiers here instead would wrongly
        # reject valid patterns that combine grouped axes with an ellipsis.
        pattern_ndim = len(left.composition) - 1
        if pattern_ndim > tensor_ndim:
            raise ValueError(
                f"Number of dimensions in pattern ({pattern_ndim}) must be less than or equal to the number of "
                f"dimensions in the tensor ({tensor_ndim})"
            )
        n_ellipsis_dims = tensor_ndim - pattern_ndim
        # the `- 1` drops the ellipsis entry; its dims are counted by `n_ellipsis_dims` instead
        n_named_dims = len(left.identifiers) - 1
    else:
        n_ellipsis_dims = 0
        n_named_dims = len(left.identifiers)

        if (pattern_ndim := len(left.composition)) != tensor_ndim:
            raise ValueError(
                f"Number of dimensions in pattern ({pattern_ndim}) must be equal to the number of dimensions in "
                f"the tensor ({tensor_ndim})"
            )
    # one first-class dim per named identifier, per ellipsis-matched tensor dim, and per unitary anonymous axis
    n_dims = n_named_dims + n_ellipsis_dims + n_anon_dims

    if n_dims == 0:
        # an identity rearrangement on a 0-dimension tensor
        return lambda tensor: tensor

    first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
    identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {}
    anon_axes: List[AnonymousAxis] = []

    # map the left-hand side identifiers to strings representing first class dims
    dims_i = 0
    for dimension in left.composition:
        if isinstance(dimension, list):
            for identifier in dimension:
                # non-unitary anon axes are not allowed in rearrange & unitary anon axes are represented as empty lists
                assert isinstance(identifier, str)
                identifier_dim_map[identifier] = (first_class_dims[dims_i],)
                dims_i += 1
            if not dimension:
                # unitary anonymous axis: give it a dedicated first-class dim and record it in the (local) parsed
                # composition so that `composition_to_dims` below emits it in the left-hand index
                anon_axis = AnonymousAxis("1")
                identifier_dim_map[anon_axis] = (first_class_dims[dims_i],)
                anon_axes.append(anon_axis)
                dimension.append(anon_axis)
                dims_i += 1
        elif dimension == _ellipsis:
            identifier = _ellipsis
            identifier_dim_map[identifier] = tuple(
                first_class_dims[dims_i + j] for j in range(n_ellipsis_dims)
            )
            dims_i += n_ellipsis_dims
        else:
            raise ValueError(f"Unexpected dimension: {dimension}")

    def composition_to_dims(
        composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]]
    ) -> List[Union[str, Tuple[str, ...]]]:
        """Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
        class dims."""
        dim_composition: List[Union[str, Tuple[str, ...]]] = []
        for dimension in composition:
            if isinstance(dimension, list):
                # a group becomes a tuple of dims (flattened/split as one tensor dimension)
                dim_composition.append(
                    tuple(
                        dim
                        for identifier in dimension
                        for dim in identifier_dim_map[identifier]
                    )
                )
            elif dimension == _ellipsis:
                # a bare ellipsis expands in place to its individual dims
                dim_composition.extend(identifier_dim_map[_ellipsis])
            else:
                raise ValueError(f"Unexpected dimension: {dimension}")
        return dim_composition

    left_dims = composition_to_dims(left.composition)
    right_dims = composition_to_dims(right.composition)
    anon_dims = tuple(identifier_dim_map[axis][0] for axis in anon_axes)
    specified_lengths = tuple(
        (identifier_dim_map[axis][0], length) for axis, length in axes_lengths.items()
    )

    custom_rearrange_callable_name = "do_rearrange"
    custom_rearrange_callable_code = (
        (
            f"def {custom_rearrange_callable_name}(tensor):\n"
            f"    {comma_separate(first_class_dims)} = dims({n_dims})\n"
        )
        + (
            "".join(
                f"    {dim}.size = {length}\n" for (dim, length) in specified_lengths
            )
            if specified_lengths
            else ""
        )
        + f"    tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n"
        + (
            f"    return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n"
            if anon_dims
            else "    return tensor\n"
        )
    )

    # Execute into an explicit namespace rather than relying on `exec` mutating this function's locals: since Python
    # 3.13 (PEP 667) `locals()` in a function returns an independent snapshot, so a definition made by a bare `exec`
    # cannot be retrieved through a later `locals()` call. `dims` is the only free name the generated source uses.
    namespace: Dict[str, Callable] = {"dims": dims}
    exec(custom_rearrange_callable_code, namespace)
    return namespace[custom_rearrange_callable_name]
148*da0073e9SAndroid Build Coastguard Worker
149*da0073e9SAndroid Build Coastguard Worker
def rearrange(
    tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
    pattern: str,
    **axes_lengths: int,
) -> torch.Tensor:
    r"""A native implementation of `einops.rearrange`, a reader-friendly smart element reordering for multidimensional
    tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
    stack, concatenate and other operations.

    See: https://einops.rocks/api/rearrange/

    Args:
        tensor (Tensor or sequence of Tensor): the tensor(s) to rearrange; a list or tuple of tensors is first stacked
            along a new leading dimension
        pattern (str): the rearrangement pattern
        axes_lengths (int): any additional length specifications for dimensions

    Returns:
        Tensor: the rearranged tensor

    Examples:
        >>> # suppose we have a set of 32 images in "h w c" format (height-width-channel)
        >>> images = torch.randn((32, 30, 40, 3))

        >>> # stack a list of images along the first (batch) axis, output is a single tensor
        >>> rearrange(list(images), 'b h w c -> b h w c').shape
        torch.Size([32, 30, 40, 3])

        >>> # concatenate images along height (vertical axis), 960 = 32 * 30
        >>> rearrange(images, 'b h w c -> (b h) w c').shape
        torch.Size([960, 40, 3])

        >>> # concatenate images along width (horizontal axis), 1280 = 32 * 40
        >>> rearrange(images, 'b h w c -> h (b w) c').shape
        torch.Size([30, 1280, 3])

        >>> # reorder axes to "b c h w" format for deep learning
        >>> rearrange(images, 'b h w c -> b c h w').shape
        torch.Size([32, 3, 30, 40])

        >>> # flatten each image into a vector, 3600 = 30 * 40 * 3
        >>> rearrange(images, 'b h w c -> b (c h w)').shape
        torch.Size([32, 3600])

        >>> # split each image into 4 smaller images (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
        >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
        torch.Size([128, 15, 20, 3])

        >>> # space-to-depth operation
        >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
        torch.Size([32, 15, 20, 12])
    """
    # anything that is not already a Tensor is treated as a sequence of tensors and stacked along a new dim 0
    source = tensor if isinstance(tensor, torch.Tensor) else torch.stack(tensor)
    # the translated callable depends only on (ndim, pattern, axes_lengths), which is what the helper memoizes on
    return _create_rearrange_callable(source.ndim, pattern, **axes_lengths)(source)
209