1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport functools 4*da0073e9SAndroid Build Coastguard Workerfrom typing import Callable, Dict, List, Sequence, Tuple, Union 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport torch 7*da0073e9SAndroid Build Coastguard Workerfrom functorch._C import dim as _C 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerfrom ._parsing import ( 10*da0073e9SAndroid Build Coastguard Worker _ellipsis, 11*da0073e9SAndroid Build Coastguard Worker AnonymousAxis, 12*da0073e9SAndroid Build Coastguard Worker comma_separate, 13*da0073e9SAndroid Build Coastguard Worker parse_pattern, 14*da0073e9SAndroid Build Coastguard Worker validate_rearrange_expressions, 15*da0073e9SAndroid Build Coastguard Worker) 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Worker__all__ = ["rearrange"] 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Workerdims = _C.dims 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker@functools.lru_cache(256) 24*da0073e9SAndroid Build Coastguard Workerdef _create_rearrange_callable( 25*da0073e9SAndroid Build Coastguard Worker tensor_ndim: int, pattern: str, **axes_lengths: int 26*da0073e9SAndroid Build Coastguard Worker) -> Callable[[torch.Tensor], torch.Tensor]: 27*da0073e9SAndroid Build Coastguard Worker r"""Translate an `einops`-style pattern into a callable that performs the rearrange using first-class dimensions. 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker Since the an equivalent result is computed for tensors with the same number of dimensions, with the same pattern and 30*da0073e9SAndroid Build Coastguard Worker specified axes lengths, this function can be memoized. 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker Args: 33*da0073e9SAndroid Build Coastguard Worker tensor_ndim (int): the number of dimensions in the tensor to rearrange 34*da0073e9SAndroid Build Coastguard Worker pattern (str): the `einops`-style rearrangement pattern 35*da0073e9SAndroid Build Coastguard Worker axes_lengths (int): any additional length specifications for dimensions 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Worker Returns: 38*da0073e9SAndroid Build Coastguard Worker Callable[[torch.Tensor], torch.Tensor]: a callable that performs the rearrangement 39*da0073e9SAndroid Build Coastguard Worker """ 40*da0073e9SAndroid Build Coastguard Worker left, right = parse_pattern(pattern, axes_lengths) 41*da0073e9SAndroid Build Coastguard Worker validate_rearrange_expressions(left, right, axes_lengths) 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker n_anon_dims = sum(not dim for dim in left.composition) 44*da0073e9SAndroid Build Coastguard Worker if left.has_ellipsis: 45*da0073e9SAndroid Build Coastguard Worker n_ellipsis_dims = tensor_ndim - (len(left.composition) - 1) 46*da0073e9SAndroid Build Coastguard Worker n_named_dims = len(left.identifiers) - 1 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Worker if (pattern_ndim := n_anon_dims + n_named_dims) > tensor_ndim: 49*da0073e9SAndroid Build Coastguard Worker raise ValueError( 50*da0073e9SAndroid Build Coastguard Worker f"Number of dimensions in pattern ({pattern_ndim}) must be less than or equal to the number of " 51*da0073e9SAndroid Build Coastguard Worker f"dimensions in the tensor ({tensor_ndim})" 52*da0073e9SAndroid Build Coastguard Worker ) 53*da0073e9SAndroid Build Coastguard Worker else: 54*da0073e9SAndroid Build Coastguard Worker n_ellipsis_dims = 0 55*da0073e9SAndroid Build Coastguard Worker n_named_dims = len(left.identifiers) 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker if (pattern_ndim := len(left.composition)) != tensor_ndim: 58*da0073e9SAndroid Build Coastguard Worker raise ValueError( 59*da0073e9SAndroid Build Coastguard Worker f"Number of dimensions in pattern ({pattern_ndim}) must be equal to the number of dimensions in " 60*da0073e9SAndroid Build Coastguard Worker f"the tensor ({tensor_ndim})" 61*da0073e9SAndroid Build Coastguard Worker ) 62*da0073e9SAndroid Build Coastguard Worker n_dims = n_named_dims + n_ellipsis_dims + n_anon_dims 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Worker if n_dims == 0: 65*da0073e9SAndroid Build Coastguard Worker # an identity rearrangement on a 0-dimension tensor 66*da0073e9SAndroid Build Coastguard Worker return lambda tensor: tensor 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims)) 69*da0073e9SAndroid Build Coastguard Worker identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {} 70*da0073e9SAndroid Build Coastguard Worker anon_axes: List[AnonymousAxis] = [] 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Worker # map the left-hand side identifiers to strings representing first class dims 73*da0073e9SAndroid Build Coastguard Worker dims_i = 0 74*da0073e9SAndroid Build Coastguard Worker for dimension in left.composition: 75*da0073e9SAndroid Build Coastguard Worker if isinstance(dimension, list): 76*da0073e9SAndroid Build Coastguard Worker for identifier in dimension: 77*da0073e9SAndroid Build Coastguard Worker # non-unitary anon axes are not allowed in rearrange & unitary anon axes are represented as empty lists 78*da0073e9SAndroid Build Coastguard Worker assert isinstance(identifier, str) 79*da0073e9SAndroid Build Coastguard Worker identifier_dim_map[identifier] = (first_class_dims[dims_i],) 80*da0073e9SAndroid Build Coastguard Worker dims_i += 1 81*da0073e9SAndroid Build Coastguard Worker if not dimension: 82*da0073e9SAndroid Build Coastguard Worker # unitary anonymous axis 83*da0073e9SAndroid Build Coastguard Worker anon_axis = AnonymousAxis("1") 84*da0073e9SAndroid Build Coastguard Worker identifier_dim_map[anon_axis] = (first_class_dims[dims_i],) 85*da0073e9SAndroid Build Coastguard Worker anon_axes.append(anon_axis) 86*da0073e9SAndroid Build Coastguard Worker dimension.append(anon_axis) 87*da0073e9SAndroid Build Coastguard Worker dims_i += 1 88*da0073e9SAndroid Build Coastguard Worker elif dimension == _ellipsis: 89*da0073e9SAndroid Build Coastguard Worker identifier = _ellipsis 90*da0073e9SAndroid Build Coastguard Worker identifier_dim_map[identifier] = tuple( 91*da0073e9SAndroid Build Coastguard Worker first_class_dims[dims_i + j] for j in range(n_ellipsis_dims) 92*da0073e9SAndroid Build Coastguard Worker ) 93*da0073e9SAndroid Build Coastguard Worker dims_i += n_ellipsis_dims 94*da0073e9SAndroid Build Coastguard Worker else: 95*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Unexpected dimension: {dimension}") 96*da0073e9SAndroid Build Coastguard Worker 97*da0073e9SAndroid Build Coastguard Worker def composition_to_dims( 98*da0073e9SAndroid Build Coastguard Worker composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]] 99*da0073e9SAndroid Build Coastguard Worker ) -> List[Union[str, Tuple[str, ...]]]: 100*da0073e9SAndroid Build Coastguard Worker """Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first 101*da0073e9SAndroid Build Coastguard Worker class dims.""" 102*da0073e9SAndroid Build Coastguard Worker dim_composition: List[Union[str, Tuple[str, ...]]] = [] 103*da0073e9SAndroid Build Coastguard Worker for dimension in composition: 104*da0073e9SAndroid Build Coastguard Worker if isinstance(dimension, list): 105*da0073e9SAndroid Build Coastguard Worker dim_composition.append( 106*da0073e9SAndroid Build Coastguard Worker tuple( 107*da0073e9SAndroid Build Coastguard Worker dim 108*da0073e9SAndroid Build Coastguard Worker for identifier in dimension 109*da0073e9SAndroid Build Coastguard Worker for dim in identifier_dim_map[identifier] 110*da0073e9SAndroid Build Coastguard Worker ) 111*da0073e9SAndroid Build Coastguard Worker ) 112*da0073e9SAndroid Build Coastguard Worker elif dimension == _ellipsis: 113*da0073e9SAndroid Build Coastguard Worker dim_composition.extend(identifier_dim_map[_ellipsis]) 114*da0073e9SAndroid Build Coastguard Worker else: 115*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Unexpected dimension: {dimension}") 116*da0073e9SAndroid Build Coastguard Worker return dim_composition 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker left_dims = composition_to_dims(left.composition) 119*da0073e9SAndroid Build Coastguard Worker right_dims = composition_to_dims(right.composition) 120*da0073e9SAndroid Build Coastguard Worker anon_dims = tuple(identifier_dim_map[axis][0] for axis in anon_axes) 121*da0073e9SAndroid Build Coastguard Worker specified_lengths = tuple( 122*da0073e9SAndroid Build Coastguard Worker (identifier_dim_map[axis][0], length) for axis, length in axes_lengths.items() 123*da0073e9SAndroid Build Coastguard Worker ) 124*da0073e9SAndroid Build Coastguard Worker 125*da0073e9SAndroid Build Coastguard Worker custom_rearrange_callable_name = "do_rearrange" 126*da0073e9SAndroid Build Coastguard Worker custom_rearrange_callable_code = ( 127*da0073e9SAndroid Build Coastguard Worker ( 128*da0073e9SAndroid Build Coastguard Worker f"def {custom_rearrange_callable_name}(tensor):\n" 129*da0073e9SAndroid Build Coastguard Worker f" {comma_separate(first_class_dims)} = dims({n_dims})\n" 130*da0073e9SAndroid Build Coastguard Worker ) 131*da0073e9SAndroid Build Coastguard Worker + ( 132*da0073e9SAndroid Build Coastguard Worker "".join( 133*da0073e9SAndroid Build Coastguard Worker f" {dim}.size = {length}\n" for (dim, length) in specified_lengths 134*da0073e9SAndroid Build Coastguard Worker ) 135*da0073e9SAndroid Build Coastguard Worker if specified_lengths 136*da0073e9SAndroid Build Coastguard Worker else "" 137*da0073e9SAndroid Build Coastguard Worker ) 138*da0073e9SAndroid Build Coastguard Worker + f" tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n" 139*da0073e9SAndroid Build Coastguard Worker + ( 140*da0073e9SAndroid Build Coastguard Worker f" return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n" 141*da0073e9SAndroid Build Coastguard Worker if anon_dims 142*da0073e9SAndroid Build Coastguard Worker else " return tensor\n" 143*da0073e9SAndroid Build Coastguard Worker ) 144*da0073e9SAndroid Build Coastguard Worker ) 145*da0073e9SAndroid Build Coastguard Worker 146*da0073e9SAndroid Build Coastguard Worker exec(custom_rearrange_callable_code) 147*da0073e9SAndroid Build Coastguard Worker return locals()[custom_rearrange_callable_name] 148*da0073e9SAndroid Build Coastguard Worker 149*da0073e9SAndroid Build Coastguard Worker 150*da0073e9SAndroid Build Coastguard Workerdef rearrange( 151*da0073e9SAndroid Build Coastguard Worker tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]], 152*da0073e9SAndroid Build Coastguard Worker pattern: str, 153*da0073e9SAndroid Build Coastguard Worker **axes_lengths: int, 154*da0073e9SAndroid Build Coastguard Worker) -> torch.Tensor: 155*da0073e9SAndroid Build Coastguard Worker r"""A native implementation of `einops.rearrange`, a reader-friendly smart element reordering for multidimensional 156*da0073e9SAndroid Build Coastguard Worker tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze, 157*da0073e9SAndroid Build Coastguard Worker stack, concatenate and other operations. 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker See: https://einops.rocks/api/rearrange/ 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker Args: 162*da0073e9SAndroid Build Coastguard Worker tensor (Tensor or sequence of Tensor): the tensor(s) to rearrange 163*da0073e9SAndroid Build Coastguard Worker pattern (str): the rearrangement pattern 164*da0073e9SAndroid Build Coastguard Worker axes_lengths (int): any additional length specifications for dimensions 165*da0073e9SAndroid Build Coastguard Worker 166*da0073e9SAndroid Build Coastguard Worker Returns: 167*da0073e9SAndroid Build Coastguard Worker Tensor: the rearranged tensor 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build Coastguard Worker Examples: 170*da0073e9SAndroid Build Coastguard Worker >>> # suppose we have a set of 32 images in "h w c" format (height-width-channel) 171*da0073e9SAndroid Build Coastguard Worker >>> images = torch.randn((32, 30, 40, 3)) 172*da0073e9SAndroid Build Coastguard Worker 173*da0073e9SAndroid Build Coastguard Worker >>> # stack along first (batch) axis, output is a single array 174*da0073e9SAndroid Build Coastguard Worker >>> rearrange(images, 'b h w c -> b h w c').shape 175*da0073e9SAndroid Build Coastguard Worker torch.Size([32, 30, 40, 3]) 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Worker >>> # concatenate images along height (vertical axis), 960 = 32 * 30 178*da0073e9SAndroid Build Coastguard Worker >>> rearrange(images, 'b h w c -> (b h) w c').shape 179*da0073e9SAndroid Build Coastguard Worker torch.Size([960, 40, 3]) 180*da0073e9SAndroid Build Coastguard Worker 181*da0073e9SAndroid Build Coastguard Worker >>> # concatenated images along horizontal axis, 1280 = 32 * 40 182*da0073e9SAndroid Build Coastguard Worker >>> rearrange(images, 'b h w c -> h (b w) c').shape 183*da0073e9SAndroid Build Coastguard Worker torch.Size([30, 1280, 3]) 184*da0073e9SAndroid Build Coastguard Worker 185*da0073e9SAndroid Build Coastguard Worker >>> # reordered axes to "b c h w" format for deep learning 186*da0073e9SAndroid Build Coastguard Worker >>> rearrange(images, 'b h w c -> b c h w').shape 187*da0073e9SAndroid Build Coastguard Worker torch.Size([32, 3, 30, 40]) 188*da0073e9SAndroid Build Coastguard Worker 189*da0073e9SAndroid Build Coastguard Worker >>> # flattened each image into a vector, 3600 = 30 * 40 * 3 190*da0073e9SAndroid Build Coastguard Worker >>> rearrange(images, 'b h w c -> b (c h w)').shape 191*da0073e9SAndroid Build Coastguard Worker torch.Size([32, 3600]) 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker >>> # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2 194*da0073e9SAndroid Build Coastguard Worker >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape 195*da0073e9SAndroid Build Coastguard Worker torch.Size([128, 15, 20, 3]) 196*da0073e9SAndroid Build Coastguard Worker 197*da0073e9SAndroid Build Coastguard Worker >>> # space-to-depth operation 198*da0073e9SAndroid Build Coastguard Worker >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape 199*da0073e9SAndroid Build Coastguard Worker torch.Size([32, 15, 20, 12]) 200*da0073e9SAndroid Build Coastguard Worker """ 201*da0073e9SAndroid Build Coastguard Worker if not isinstance(tensor, torch.Tensor): 202*da0073e9SAndroid Build Coastguard Worker tensor = torch.stack(tensor) 203*da0073e9SAndroid Build Coastguard Worker 204*da0073e9SAndroid Build Coastguard Worker rearrange_callable = _create_rearrange_callable( 205*da0073e9SAndroid Build Coastguard Worker tensor.ndim, pattern, **axes_lengths 206*da0073e9SAndroid Build Coastguard Worker ) 207*da0073e9SAndroid Build Coastguard Worker 208*da0073e9SAndroid Build Coastguard Worker return rearrange_callable(tensor) 209