# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from dataclasses import dataclass
from typing import (
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import torch
from torch import Tensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import (
    OpSchema,
    OpStrategy,
    PlacementStrategy,
    RuntimeSchemaInfo,
    StrategyType,
)
from torch.distributed.tensor._ops.utils import (
    generate_redistribute_costs,
    normalize_dim,
    normalize_dims,
    prod,
    register_op_strategy,
)
from torch.distributed.tensor.placement_types import Placement, Replicate, Shard


aten = torch.ops.aten

Shape = Tuple[int, ...]


@dataclass
class DimSpec:
    """Specifies how an output dimension maps to an input dimension."""

    def inputs(self) -> Iterable["DimSpec"]:
        return ()


# Rules that map each dimension of the output to dimensions of the input tensor
DimMap = Tuple[DimSpec, ...]
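
# For illustration only (worked out by hand from the specs below): reshaping a
# (2, 3, 4) tensor into (2, 12) could be described by the DimMap
#   (InputDim(0), Flatten((InputDim(1), InputDim(2))))
# i.e. output dim 0 forwards input dim 0, and output dim 1 flattens input dims 1 and 2.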


@dataclass
class Singleton(DimSpec):
    """Output dimension is a singleton."""


@dataclass
class InputDim(DimSpec):
    """Output dimension maps directly to an input dimension."""

    input_dim: int


@dataclass
class Broadcast(DimSpec):
    """Output is the broadcast of a singleton input dimension."""

    dim: DimSpec
    dim_size: int

    @classmethod
    def new(cls, dim: DimSpec, dim_size: int) -> DimSpec:
        return Broadcast(dim, dim_size)

    def inputs(self) -> Iterable[DimSpec]:
        return (self.dim,)


@dataclass
class NewDim(DimSpec):
    """This is a new dimension created by the op."""

    size: int

    @classmethod
    def new(cls, size: int) -> DimSpec:
        return Singleton() if size == 1 else NewDim(size)


@dataclass
class Repeat(DimSpec):
96    """Output dimension is the input dimension repeated n-times."""

    input_dim: DimSpec
    times: int

    @classmethod
    def new(cls, dim: DimSpec, times: int) -> DimSpec:
        if times == 1:
            return dim
        elif isinstance(dim, Singleton):
            # repeating a singleton is the same as broadcasting it
            return Broadcast(dim, times)
        else:
            return Repeat(dim, times)

    def inputs(self) -> Iterable[DimSpec]:
        return (self.input_dim,)


@dataclass
class Flatten(DimSpec):
    """Flatten a set of input dimensions, ensuring right-most adjacent elements remain adjacent in the output."""

    input_dims: Sequence[DimSpec]

    @classmethod
    def new(cls, dims: Sequence[DimSpec]) -> DimSpec:
        if len(dims) == 0:
            # flattening a scalar leads to a singleton
            return Singleton()
        elif len(dims) == 1:
            # flattening a single dimension is no-op
            return dims[0]
        else:
            return Flatten(dims)

    def inputs(self) -> Iterable[DimSpec]:
        return self.input_dims


@dataclass
class Split(DimSpec):
    """
    This dimension is a member of a decomposition of the input dim.

    Note that input_dim itself could be a Flattened set of input dims.
    """

    input_dim: DimSpec
    group_shape: Shape
    split_id: int

    @classmethod
    def new(cls, dim: DimSpec, group_shape: Tuple[int, ...], idx: int) -> DimSpec:
        assert len(group_shape) > 0
        if len(group_shape) == 1:
            # not really a group, just return the input dim back
            assert idx == 0
            return dim
        elif group_shape[idx] == 1:
            return Singleton()
        else:
            # remove singletons from group
            # group_mapping = [(new_index, (shape, old_index)) ...]
            group_mapping = list(
                enumerate((s, i) for i, s in enumerate(group_shape) if s != 1)
            )
            new_group_shape = tuple(m[1][0] for m in group_mapping)
            new_idx = next(filter(lambda x: x[1][1] == idx, group_mapping))[0]
            return Split(dim, new_group_shape, new_idx)

    def inputs(self) -> Iterable[DimSpec]:
        return (self.input_dim,)
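
# For example (illustrative, values worked out by hand): Split.new drops
# singleton entries from the decomposition group, e.g.
#   Split.new(InputDim(0), (2, 1, 3), 2) -> Split(InputDim(0), (2, 3), 1)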


def dim_pad_left(ndim: int, min_dims: int) -> DimMap:
    return (Singleton(),) * max(0, min_dims - ndim) + tuple(
        InputDim(i) for i in range(ndim)
    )


def dim_atleast_3d(ndim: int) -> DimMap:
    if ndim == 0:
        return (Singleton(), Singleton(), Singleton())
    elif ndim == 1:
        return (Singleton(), InputDim(0), Singleton())
    elif ndim == 2:
        return (InputDim(0), InputDim(1), Singleton())
    else:
        return tuple(InputDim(i) for i in range(ndim))


def expand(input_shape: Shape, shape: Shape) -> DimMap:
    """Implement broadcast on multiple dimensions."""
    assert len(shape) >= len(input_shape)

    # 1. create padded input dimensions
    padded_input = dim_pad_left(len(input_shape), len(shape))
    # 2. check that input shapes are compatible
    mapping = []
    for p, desired_s in zip(padded_input, shape):
        if isinstance(p, Singleton):
            actual_s = 1
            assert desired_s >= 0
        else:
            assert isinstance(p, InputDim), f"DimSpec not supported in expand: {p}"
            actual_s = input_shape[p.input_dim]
            assert actual_s == 1 or desired_s == -1 or desired_s == actual_s
        mapping.append(
            p
            if desired_s in (1, -1) or desired_s == actual_s
            else Broadcast.new(p, desired_s)
        )
    return tuple(mapping)
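
# For example (illustrative, derived from the rules above):
#   expand((1, 3), (2, 3)) -> (Broadcast(InputDim(0), 2), InputDim(1))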


def normalize_sizes(sizes: Union[Shape, Tuple[Shape]]) -> Shape:
    if isinstance(sizes[0], int):
        return cast(Shape, sizes)
    elif len(sizes) == 1:
        return sizes[0]
    else:
        raise RuntimeError("Size must be int... or tuple")


def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap:
    if ndim == 0:
        return (Singleton(),)
    elif ndim == 1:
        return (InputDim(0),)
    else:
        # only flattening dims from start_dim to end_dim (inclusive)
        # other dims are passed through
        if end_dim < 0:
            end_dim += ndim
        results: List[DimSpec] = [InputDim(i) for i in range(start_dim)]
        results.append(
            Flatten.new(tuple(InputDim(i) for i in range(start_dim, end_dim + 1)))
        )
        results.extend([InputDim(i) for i in range(end_dim + 1, ndim)])
        return tuple(results)
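
# For example (illustrative):
#   dim_flatten(3)                         -> (Flatten((InputDim(0), InputDim(1), InputDim(2))),)
#   dim_flatten(4, start_dim=1, end_dim=2) -> (InputDim(0), Flatten((InputDim(1), InputDim(2))), InputDim(3))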


def dim_movedim(
    ndim: int,
    input: Union[int, Sequence[int]],
    destination: Union[int, Sequence[int]],
) -> DimMap:
    input = normalize_dims(input, ndim)
    destination = normalize_dims(destination, ndim)

    assert len(input) == len(destination)
    input_set = set(input)
    assert len(input_set) == len(input), "Found repeated input dims"
    assert len(set(destination)) == len(destination), "Found repeated output dims"
    assert max(input) < ndim
    assert max(destination) < ndim

    dest = [-1] * ndim
    for i, d in zip(input, destination):
        dest[d] = i

    unused_inputs_iter = iter(i for i in range(ndim) if i not in input_set)
    for i in range(ndim):
        if dest[i] == -1:
            dest[i] = next(unused_inputs_iter)

    return tuple(InputDim(i) for i in dest)
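
# For example (illustrative): dim_movedim(3, 0, 2) -> (InputDim(1), InputDim(2), InputDim(0)),
# i.e. input dim 0 moves to output position 2 and the remaining dims shift left.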


def dim_repeat(ndim: int, sizes: Shape) -> DimMap:
    sizes = normalize_sizes(sizes)
    assert (
        len(sizes) >= ndim
270    ), f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}."
    pad = len(sizes) - ndim
    return tuple(Repeat.new(Singleton(), s) for s in sizes[:pad]) + tuple(
        Repeat.new(InputDim(i), s) for i, s in enumerate(sizes[pad:])
    )
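
# For example (illustrative): dim_repeat(2, (3, 1, 2)) ->
#   (Broadcast(Singleton(), 3), InputDim(0), Repeat(InputDim(1), 2))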


def infer_size(total_size: int, sizes: Shape) -> Shape:
    """
    One dimension of the sizes passed to view may be "-1".

    Infer the size of this dimension given the total_size.
    """
    infers = [i for i, s in enumerate(sizes) if s == -1]
    size = prod(sizes)
    assert len(infers) <= 1, "can only infer one size"
    if infers:
        size = -size
        missing_size = total_size // size
        assert (
            total_size % size == 0
        ), f"size inferred for -1 is not integral: {sizes} should have {total_size} elements."
        return tuple(s if s != -1 else missing_size for s in sizes)
    assert size == total_size, f"sizes do not match {total_size} vs {size}"
    return sizes
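
# For example (illustrative): infer_size(12, (-1, 4)) -> (3, 4)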


def view_groups(from_size: Shape, to_size: Shape) -> DimMap:
    """
    Decompose a reshape operation into forwarding, flattening, or splitting dimensions for each output dimension.

    A view or reshape operation can be decomposed into a set of 3 types of smaller operations:
    1) Forward a dimension from input to output
    2) Flatten a set of dimensions into a single dimension
    3) Split one dimension into multiple dimensions

    view_groups identifies these operations and returns, for each output dimension,
    which operation was performed on the input dimensions. For example:

        view_groups([2, 3, 4], [2, 12]) -> (
            InputDim(0),
            Flatten((InputDim(1), InputDim(2)))
        )

    - output dimension 0 maps to input dimension 0
    - output dimension 1 maps to input dimensions 1 and 2, flattened together


        view_groups([2, 3], [3, 2]) -> (
            Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0),
            Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1),
        )

    - in the above, input is flattened into a single dimension and then split
      into two separate dimensions with different sizes from the input.
    """
    from_nelem = prod(from_size)
    to_size = infer_size(from_nelem, normalize_sizes(to_size))

    assert from_nelem == prod(to_size), "Total view shape does not add up"

    from_idx = 0
    to_idx = 0
    from_len = len(from_size)
    to_len = len(to_size)

    result_pp = []

    while from_idx < from_len or to_idx < to_len:
        from_group_dim, to_group_shape = [], []

        if from_idx >= from_len:
            f = 1
        else:
            f = from_size[from_idx]
            from_group_dim.append(from_idx)
            from_idx += 1

        if to_idx >= to_len:
            t = 1
        else:
            t = to_size[to_idx]
            to_group_shape.append(t)
            to_idx += 1

        # if exactly one of the groups is a singleton, emit it on its own and backtrack the other side
        if f == 1 and t != 1:
            # produces ([1], [])
            to_idx -= 1
            to_group_shape = []
        elif f != 1 and t == 1:
            # produces ([], [1])
            from_idx -= 1
            from_group_dim = []
        else:
            # produces ([1], [1]),  ([2], [2]), ([2,3], [6])
            while f != t:
                if f < t:
                    nf = from_size[from_idx]
                    from_group_dim.append(from_idx)
                    from_idx += 1
                    f *= nf
                else:
                    nt = to_size[to_idx]
                    to_group_shape.append(nt)
                    to_idx += 1
                    t *= nt

        if len(to_group_shape) > 0:
            flattened = Flatten.new(
                tuple(InputDim(fi) for fi in from_group_dim if from_size[fi] > 1)
            )
            result_pp += [
                Split.new(flattened, tuple(to_group_shape), i)
                for i in range(len(to_group_shape))
            ]

    return tuple(result_pp)


def dim_tile(ndim: int, dims: Tuple[int, ...]) -> DimMap:
    if len(dims) < ndim:
        dims = (1,) * (ndim - len(dims)) + dims
    return dim_repeat(ndim, dims)


def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap:
    dim1 = normalize_dim(dim1, ndim)
    dim2 = normalize_dim(dim2, ndim)
    assert dim1 < ndim
    assert dim2 < ndim
    dimmap = [InputDim(i) for i in range(ndim)]
    swapdim = dimmap[dim1]
    dimmap[dim1] = dimmap[dim2]
    dimmap[dim2] = swapdim
    return tuple(dimmap)


def dim_squeeze(shape: Shape, dim: Optional[int] = None) -> DimMap:
    # FIXME: this is wrong when dim=None and one of the dimensions
    # equals size of the mesh. For example squeeze(DTensor(tensor(4), Shard[0])) could
    # end up as squeeze(tensor(1)) if we have 4 devices; this would lead to
    # removal of a dimension that is not actually a singleton.
    return tuple(
        InputDim(i)
        for i, s in enumerate(shape)
        if s > 1 or (dim is not None and i != normalize_dim(dim, len(shape)))
    )


def dim_unsqueeze(ndim: int, dim: int) -> DimMap:
    dims = tuple(InputDim(i) for i in range(ndim))
    if dim < 0:
        dim += ndim + 1
    return dims[:dim] + (Singleton(),) + dims[dim:]
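
# For example (illustrative):
#   dim_squeeze((2, 1, 3)) -> (InputDim(0), InputDim(2))
#   dim_unsqueeze(2, 1)    -> (InputDim(0), Singleton(), InputDim(1))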


def dim_view_as_real(shape: Shape) -> DimMap:
    ndim = len(shape)
    results: List[DimSpec] = [InputDim(i) for i in range(ndim - 1)]
    # each complex number is split into two real numbers,
    # resulting in one more dimension of size 2
    results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 0))
    results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 1))
    return tuple(results)


def dim_reduction(
    ndim: int, dim_or_dims: Optional[Union[int, Sequence[int]]], keepdim: bool
) -> DimMap:
    """
    General fallback for reduction ops where Partial() does not apply.

    This will cause the incoming tensor to be replicated on the reducing dimensions.
444    """
445    if dim_or_dims is None:
446        dim_or_dims = tuple(range(ndim))
447    if isinstance(dim_or_dims, int):
448        dim_or_dims = (dim_or_dims,)
449    dim_or_dims = tuple(d if d >= 0 else d + ndim for d in dim_or_dims)
450    return tuple(
451        InputDim(i) if i not in dim_or_dims else Singleton()
452        for i in range(ndim)
453        if i not in dim_or_dims or keepdim
454    )
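
# For example (illustrative): reducing dim 1 of a 3-d tensor gives
#   dim_reduction(3, 1, keepdim=False) -> (InputDim(0), InputDim(2))
#   dim_reduction(3, 1, keepdim=True)  -> (InputDim(0), Singleton(), InputDim(2))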


dim_maps: Dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = {
    torch.atleast_1d: lambda x: dim_pad_left(x.ndim, 1),
    torch.atleast_2d: lambda x: dim_pad_left(x.ndim, 2),
    torch.atleast_3d: lambda x: dim_atleast_3d(x.ndim),
    torch.broadcast_to: lambda input, shape: expand(input.shape, shape),
    Tensor.expand: lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)),
    torch.flatten: lambda tensor: dim_flatten(tensor.ndim),
    torch.movedim: lambda input, source, destination: dim_movedim(
        input.ndim, source, destination
    ),
    torch.permute: lambda input, dims: tuple(
        InputDim(i) for i in normalize_dims(dims, input.ndim)
    ),
    torch.ravel: lambda tensor: dim_flatten(tensor.ndim),
    Tensor.repeat: lambda self, *sizes: dim_repeat(self.ndim, sizes),
    torch.reshape: lambda input, shape: view_groups(input.shape, shape),
    torch.squeeze: lambda input, dim=None: dim_squeeze(input.shape, dim),
    torch.tile: lambda input, dims: dim_tile(input.ndim, dims),
    torch.transpose: lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1),
    torch.unsqueeze: lambda input, dim: dim_unsqueeze(input.ndim, dim),
    Tensor.view: lambda input, *shape: view_groups(input.shape, shape),
    torch.view_as_complex: lambda input: dim_flatten(input.ndim, input.ndim - 2),
    torch.view_as_real: lambda input: dim_view_as_real(input.shape),
}
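
# Each entry maps a tensor op to a function that builds its DimMap from the op's
# arguments. For example (illustrative):
#   dim_maps[torch.transpose](torch.empty(2, 3), 0, 1) == (InputDim(1), InputDim(0))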


def propagate_shape_and_sharding(
    input_src_placements: Sequence[Placement],
    local_in_shape: Shape,
    rule: DimMap,
    mesh_sizes: Shape,
) -> Tuple[Sequence[Placement], Sequence[Placement]]:
    """
    Determine input target sharding and output sharding based on
    the given global tensor shape and input source sharding.

    Sharding propagation follows mapped dimensions:
    - An output dimension that maps directly to an input dimension is sharded in
      the same way as that input dimension.
    - An output dimension that is a flattened set of input dimensions can only be
      sharded if only the leftmost flattened dimension is sharded.
    - An output dimension that is a split of an input dimension can only be sharded
      if the leftmost split size is divisible by the mesh dimension.
499    """
500    assert len(input_src_placements) == len(mesh_sizes)
    # for each input dim, a boolean per mesh dim indicating whether the input dim
    # can remain sharded on that mesh dim
    mesh_ndim = len(mesh_sizes)
    shardable_dims: Dict[int, List[bool]] = {}

    # in case an input dimension disappears (e.g. collapsing, reduction)
    # we cannot shard in that dimension (we need a replication fall-back rule)
    seen_input_dims: Set[int] = set()

    def collect_used_inputs(cmd: DimSpec) -> None:
        if isinstance(cmd, InputDim):
            seen_input_dims.add(cmd.input_dim)
        for inp in cmd.inputs():
            collect_used_inputs(inp)

    for cmd in rule:
        collect_used_inputs(cmd)
    for dim in range(len(local_in_shape)):
        shardable_dims[dim] = [dim in seen_input_dims] * mesh_ndim

    def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]:
        if isinstance(cmd, InputDim):
            return cmd
        elif isinstance(cmd, Flatten):
            for dim in cmd.input_dims[1:]:
                if isinstance(dim, InputDim):
                    shardable_dims[dim.input_dim] = [False] * mesh_ndim
            dim0 = cmd.input_dims[0]
            return dim0 if isinstance(dim0, InputDim) else None
        elif isinstance(cmd, Split):
            in_dim = get_in_dim_to_shard(cmd.input_dim)
            out_size = cmd.group_shape[cmd.split_id]
            if cmd.split_id == 0 and in_dim is not None:
                # we need to check that the input dimension is divisible
                # by the size of the submesh we're sharding it on
                # NOTE: it would be possible to shard the same input dimension
                # on more than one mesh dimension. In that case, the dimension
                # needs to be divisible by the product of mesh sizes.
                # In order to keep the problem more tractable, we will not consider
                # double resharding as a suggestion (e.g. [Shard(0), Shard(0) ])
                # but we will allow it if that's the input and it's compatible

                # 1. is this dimension shardable on each individual mesh dim?
                shardable_dims[in_dim.input_dim] = [
                    out_size % mesh_dim_size == 0 for mesh_dim_size in mesh_sizes
                ]

                # 2. here we special case things like [Shard(0), Shard(0)]
                submesh_size = 1
                for size, shard in zip(mesh_sizes, input_src_placements):
                    if isinstance(shard, Shard) and shard.dim == in_dim.input_dim:
                        submesh_size *= size
                assert (
                    out_size % submesh_size == 0
                ), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}."

            # we will only shard our first component of the split
            return in_dim if cmd.split_id == 0 else None
        elif isinstance(cmd, Repeat):
            in_dim = get_in_dim_to_shard(cmd.input_dim)
            if in_dim is not None:
                shardable_dims[in_dim.input_dim] = [False] * mesh_ndim
            return None
        else:
            return None

    # for each output dim, find the corresponding input dim in terms of sharding prop
    shard_dim_map = {}
    for dim, cmd in enumerate(rule):
        in_dim = get_in_dim_to_shard(cmd)
        if in_dim is not None:
            shard_dim_map[in_dim.input_dim] = dim

    input_tgt_placements = [
        Replicate()
        if isinstance(p, Shard) and not shardable_dims[p.dim][mesh_dim]
        else p
        for mesh_dim, p in enumerate(input_src_placements)
    ]
    output_placements = [
        Shard(shard_dim_map[p.dim]) if isinstance(p, Shard) else p
        for p in input_tgt_placements
    ]

    return input_tgt_placements, output_placements
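
# Worked example (illustrative, derived by hand): flattening a (4, 3) tensor
# into (12,) on a 1-D mesh of size 2.
#   rule = view_groups((4, 3), (12,))   # (Flatten((InputDim(0), InputDim(1))),)
#   propagate_shape_and_sharding([Shard(0)], (4, 3), rule, (2,))
#     -> ([Shard(0)], [Shard(0)])       # leftmost flattened dim can stay sharded
#   propagate_shape_and_sharding([Shard(1)], (4, 3), rule, (2,))
#     -> ([Replicate()], [Replicate()]) # other flattened dims must be replicated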


def register_op_strategy_map(
    aten_op_overload: torch._ops.OpOverload,
    local_op_name: Callable[..., torch.Tensor],
    schema_info: Optional[RuntimeSchemaInfo] = None,
) -> None:
    dim_map: Callable[..., DimMap] = dim_maps[local_op_name]

    @register_op_strategy(aten_op_overload, schema_info=schema_info)
    def reshape_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
        rules = dim_map(*op_schema.args_schema, **op_schema.kwargs_schema)
        input_strategy = cast(OpStrategy, op_schema.args_schema[0])
        global_in_shape = input_strategy.shape
        assert global_in_shape is not None, "Shape required."

        output_strategy = OpStrategy([])
        for input_placement_strategy in input_strategy.strategies:
            input_src_spec = input_placement_strategy.output_spec

            input_tgt_placements, output_placements = propagate_shape_and_sharding(
                input_src_spec.placements,
                tuple(global_in_shape),
                rules,
                mesh.shape,
            )

            # TODO: optimize this. we shouldn't simply blindly replicate
            #       unshardable dims ...
            # FIXME: this can be wrong for situations where we have
            #        [Shard(0), Shard(0)]
            input_tgt_spec = DTensorSpec(
                placements=tuple(input_tgt_placements),
                mesh=input_src_spec.mesh,
                tensor_meta=input_src_spec.tensor_meta,
            )
            redistribute_costs = [
                generate_redistribute_costs(input_strategy, input_tgt_spec)
            ]

            output_spec = DTensorSpec(mesh=mesh, placements=tuple(output_placements))
            output_strategy.strategies.append(
                PlacementStrategy(
                    output_specs=output_spec,
                    input_specs=(input_tgt_spec,),
                    redistribute_cost=redistribute_costs,
                )
            )

        return output_strategy


register_op_strategy_map(aten.squeeze.default, torch.squeeze)
register_op_strategy_map(
    aten.squeeze.dim, torch.squeeze, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(
    aten.view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(
    aten.reshape.default, torch.reshape, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(
    aten._unsafe_view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(
    aten.unsqueeze.default, torch.unsqueeze, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(
    aten.expand.default, Tensor.expand, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(
    aten.permute.default, torch.permute, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(
    aten.repeat.default, Tensor.repeat, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(
    aten.transpose.int, torch.transpose, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(aten.view_as_complex.default, torch.view_as_complex)
register_op_strategy_map(aten.view_as_real.default, torch.view_as_real)