from __future__ import annotations

import functools
from typing import Callable, Dict, List, Sequence, Tuple, Union

import torch
from functorch._C import dim as _C

from ._parsing import (
    _ellipsis,
    AnonymousAxis,
    comma_separate,
    parse_pattern,
    validate_rearrange_expressions,
)


__all__ = ["rearrange"]

dims = _C.dims


@functools.lru_cache(256)
def _create_rearrange_callable(
    tensor_ndim: int, pattern: str, **axes_lengths: int
) -> Callable[[torch.Tensor], torch.Tensor]:
    r"""Translate an `einops`-style pattern into a callable that performs the rearrange using first-class dimensions.

    Since an equivalent result is computed for tensors with the same number of dimensions, the same pattern, and the
    same specified axes lengths, this function can be memoized.

    Args:
        tensor_ndim (int): the number of dimensions in the tensor to rearrange
        pattern (str): the `einops`-style rearrangement pattern
        axes_lengths (int): any additional length specifications for dimensions

    Returns:
        Callable[[torch.Tensor], torch.Tensor]: a callable that performs the rearrangement
    """
    left, right = parse_pattern(pattern, axes_lengths)
    validate_rearrange_expressions(left, right, axes_lengths)

    n_anon_dims = sum(not dim for dim in left.composition)
    if left.has_ellipsis:
        n_ellipsis_dims = tensor_ndim - (len(left.composition) - 1)
        n_named_dims = len(left.identifiers) - 1

        if (pattern_ndim := n_anon_dims + n_named_dims) > tensor_ndim:
            raise ValueError(
                f"Number of dimensions in pattern ({pattern_ndim}) must be less than or equal to the number of "
                f"dimensions in the tensor ({tensor_ndim})"
            )
    else:
        n_ellipsis_dims = 0
        n_named_dims = len(left.identifiers)

        if (pattern_ndim := len(left.composition)) != tensor_ndim:
            raise ValueError(
                f"Number of dimensions in pattern ({pattern_ndim}) must be equal to the number of dimensions in "
                f"the tensor ({tensor_ndim})"
            )
    n_dims = n_named_dims + n_ellipsis_dims + n_anon_dims

    if n_dims == 0:
        # an identity rearrangement on a 0-dimension tensor
        return lambda tensor: tensor

    first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
    identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {}
    anon_axes: List[AnonymousAxis] = []

    # map the left-hand side identifiers to strings representing first class dims
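    # Illustrative example (an assumption, not in the original source): for the
    # pattern "b (h w) -> b h w" on a 2-d tensor, first_class_dims is
    # ("d0", "d1", "d2") and the loop below builds
    # identifier_dim_map == {"b": ("d0",), "h": ("d1",), "w": ("d2",)}.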
    dims_i = 0
    for dimension in left.composition:
        if isinstance(dimension, list):
            for identifier in dimension:
                # non-unitary anon axes are not allowed in rearrange & unitary anon axes are represented as empty lists
                assert isinstance(identifier, str)
                identifier_dim_map[identifier] = (first_class_dims[dims_i],)
                dims_i += 1
            if not dimension:
                # unitary anonymous axis
                anon_axis = AnonymousAxis("1")
                identifier_dim_map[anon_axis] = (first_class_dims[dims_i],)
                anon_axes.append(anon_axis)
                dimension.append(anon_axis)
                dims_i += 1
        elif dimension == _ellipsis:
            identifier = _ellipsis
            identifier_dim_map[identifier] = tuple(
                first_class_dims[dims_i + j] for j in range(n_ellipsis_dims)
            )
            dims_i += n_ellipsis_dims
        else:
            raise ValueError(f"Unexpected dimension: {dimension}")

    def composition_to_dims(
        composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]]
    ) -> List[Union[str, Tuple[str, ...]]]:
        """Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
        class dims."""
        dim_composition: List[Union[str, Tuple[str, ...]]] = []
        for dimension in composition:
            if isinstance(dimension, list):
                dim_composition.append(
                    tuple(
                        dim
                        for identifier in dimension
                        for dim in identifier_dim_map[identifier]
                    )
                )
            elif dimension == _ellipsis:
                dim_composition.extend(identifier_dim_map[_ellipsis])
            else:
                raise ValueError(f"Unexpected dimension: {dimension}")
        return dim_composition

    left_dims = composition_to_dims(left.composition)
    right_dims = composition_to_dims(right.composition)
    anon_dims = tuple(identifier_dim_map[axis][0] for axis in anon_axes)
    specified_lengths = tuple(
        (identifier_dim_map[axis][0], length) for axis, length in axes_lengths.items()
    )
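
    # Illustrative sketch (an assumption, not emitted verbatim): for
    # rearrange(x, "b (h w) -> b h w", h=2) on a 2-d tensor, the source code
    # assembled below looks roughly like
    #
    #     def do_rearrange(tensor):
    #         d0, d1, d2 = dims(3)
    #         d1.size = 2
    #         tensor = tensor[d0, (d1, d2)].order(d0, d1, d2)
    #         return tensor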
    custom_rearrange_callable_name = "do_rearrange"
    custom_rearrange_callable_code = (
        (
            f"def {custom_rearrange_callable_name}(tensor):\n"
            f"    {comma_separate(first_class_dims)} = dims({n_dims})\n"
        )
        + (
            "".join(
                f"    {dim}.size = {length}\n" for (dim, length) in specified_lengths
            )
            if specified_lengths
            else ""
        )
        + f"    tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n"
        + (
            f"    return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n"
            if anon_dims
            else "    return tensor\n"
        )
    )

    exec(custom_rearrange_callable_code)
    return locals()[custom_rearrange_callable_name]


def rearrange(
    tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
    pattern: str,
    **axes_lengths: int,
) -> torch.Tensor:
    r"""A native implementation of `einops.rearrange`, a reader-friendly smart element reordering for multidimensional
    tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
    stack, concatenate and other operations.

    See: https://einops.rocks/api/rearrange/

    Args:
        tensor (Tensor or sequence of Tensor): the tensor(s) to rearrange
        pattern (str): the rearrangement pattern
        axes_lengths (int): any additional length specifications for dimensions

    Returns:
        Tensor: the rearranged tensor

    Examples:
        >>> # suppose we have a set of 32 images in "h w c" format (height-width-channel)
        >>> images = torch.randn((32, 30, 40, 3))

        >>> # keep the same layout (identity rearrangement)
        >>> rearrange(images, 'b h w c -> b h w c').shape
        torch.Size([32, 30, 40, 3])

        >>> # concatenate images along height (vertical axis), 960 = 32 * 30
        >>> rearrange(images, 'b h w c -> (b h) w c').shape
        torch.Size([960, 40, 3])

        >>> # concatenate images along width (horizontal axis), 1280 = 32 * 40
        >>> rearrange(images, 'b h w c -> h (b w) c').shape
        torch.Size([30, 1280, 3])

        >>> # reorder axes to "b c h w" format for deep learning
        >>> rearrange(images, 'b h w c -> b c h w').shape
        torch.Size([32, 3, 30, 40])

        >>> # flatten each image into a vector, 3600 = 30 * 40 * 3
        >>> rearrange(images, 'b h w c -> b (c h w)').shape
        torch.Size([32, 3600])

        >>> # split each image into 4 smaller ones (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
        >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
        torch.Size([128, 15, 20, 3])

        >>> # space-to-depth operation
        >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
        torch.Size([32, 15, 20, 12])
    """
    if not isinstance(tensor, torch.Tensor):
        tensor = torch.stack(tensor)

    rearrange_callable = _create_rearrange_callable(
        tensor.ndim, pattern, **axes_lengths
    )

    return rearrange_callable(tensor)
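

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative addition, not part of the original
    # module): a sequence of tensors is stacked along a new leading dimension
    # before the pattern is applied.
    x, y = torch.randn(3, 4), torch.randn(3, 4)
    assert rearrange([x, y], "b h w -> h (b w)").shape == torch.Size([3, 8])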