from typing import cast, List, Sequence, Tuple

import torch
import torch.distributed.tensor._api as dtensor
from torch._prims_common import ShapeType
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.placement_types import (
    _StridedShard,
    Partial,
    Placement,
    Replicate,
    Shard,
)


# TODO: audit existing code base to see if we can safely remove this API.
def compute_local_shape(
    global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
) -> Tuple[int, ...]:
    """
    Compute the shape of a local shard of the given DTensor on its current
    coordinate of the mesh.
    """
    my_coordinate = mesh.get_coordinate()

    if my_coordinate is None:
        # if rank not in the mesh, return empty shape
        return (0,)
    else:
        local_shape = list(global_shape)  # start with global shape
        ndim = len(global_shape)
        for idx, placement in enumerate(placements):
            mesh_dim_size = mesh.size(idx)
            if isinstance(placement, Shard):
                shard_dim = placement.dim
                assert (
                    shard_dim < ndim
                ), f"Sharding dim {shard_dim} greater than tensor ndim {ndim}"
                local_shard_size, _ = placement._local_shard_size_on_dim(
                    local_shape[shard_dim], mesh_dim_size, my_coordinate[idx]
                )
                assert isinstance(local_shard_size, int)
                local_shape[shard_dim] = local_shard_size

        return tuple(local_shape)


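# The following is a minimal standalone sketch (a hypothetical helper, not used by
# this module) of the per-dimension arithmetic behind compute_local_shape, assuming
# the torch.chunk-style uneven sharding used by Shard: chunks of
# ceil(dim_size / num_chunks) rows until the dimension is exhausted, then empty
# shards on the remaining ranks.
def _example_shard_size_on_dim(dim_size: int, num_chunks: int, chunk_idx: int) -> int:
    full_chunk = (dim_size + num_chunks - 1) // num_chunks  # ceil division
    return max(0, min(full_chunk, dim_size - chunk_idx * full_chunk))


# e.g. a dim of size 8 over a mesh dim of size 4 gives 2 rows on every rank, while a
# dim of size 2 over a mesh dim of size 4 gives sizes [1, 1, 0, 0] across ranks.

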
def compute_local_shape_and_global_offset(
    global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """
    Compute the local tensor shape and the global offsets into the original tensor
    of a DTensor on its current global rank. This is useful for checkpointing purposes.

    Example (2 hosts with 4 GPUs each):
    # Below is a DeviceMesh with mesh_shape of (2, 4)
    mesh = DeviceMesh(
        device_type="cuda",
        mesh=[
            [0, 1, 2, 3],
            [4, 5, 6, 7],
        ],
    )

    Let's say we distribute a global_tensor of shape (8, 4) over the above DeviceMesh
    with placements of [Shard(0), Shard(0)].
    The local shape and global offset will be as follows:
    rank0 -- local_shape:[1, 4], global_offset:[0, 0]
    rank1 -- local_shape:[1, 4], global_offset:[1, 0]
    rank2 -- local_shape:[1, 4], global_offset:[2, 0]
    rank3 -- local_shape:[1, 4], global_offset:[3, 0]
    rank4 -- local_shape:[1, 4], global_offset:[4, 0]
    rank5 -- local_shape:[1, 4], global_offset:[5, 0]
    rank6 -- local_shape:[1, 4], global_offset:[6, 0]
    rank7 -- local_shape:[1, 4], global_offset:[7, 0]

    Let's say we distribute a global_tensor of shape (2,) over the above DeviceMesh
    with placements of [Shard(0)]. Not all ranks will have a non-empty local tensor.
    The local shape and global offset will be as follows:
    rank0 -- local_shape:[1,], global_offset:[0,]
    rank1 -- local_shape:[1,], global_offset:[1,]
    rank2 -- local_shape:[0,], global_offset:[2,]
    rank3 -- local_shape:[0,], global_offset:[2,]
    rank4 -- local_shape:[0,], global_offset:[2,]
    rank5 -- local_shape:[0,], global_offset:[2,]
    rank6 -- local_shape:[0,], global_offset:[2,]
    rank7 -- local_shape:[0,], global_offset:[2,]
    """
    my_coordinate = mesh.get_coordinate()

    if my_coordinate is None:
        # if rank not in the mesh, return empty offset
        return ((), ())
    else:
        local_shape = list(global_shape)
        global_offset = [0] * len(global_shape)
        shard_idx_stride_by_mesh_dim = [
            [0] * mesh.ndim for _ in range(len(global_shape))
        ]  # index by (shard_dim, mesh_dim)
        num_shards_by_tensor_dim = [1] * len(global_shape)

        for idx, placement in enumerate(placements):
            mesh_dim_size = mesh.size(idx)
            if isinstance(placement, Shard):
                shard_dim = placement.dim
                local_offset = [0] * len(global_shape)
                assert shard_dim < len(
                    local_shape
                ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
                shard_size, shard_offset = placement._local_shard_size_on_dim(
                    local_shape[shard_dim],
                    mesh_dim_size,
                    my_coordinate[idx],
                    return_offset=True,
                )

                local_shape[shard_dim] = shard_size
                local_offset[shard_dim] = shard_offset

                # On a given dimension, if local_offset[shard_dim] is smaller than
                # global_offset[shard_dim], this dimension has already been sharded by a
                # previous placement. In that case we cannot simply replace
                # global_offset[shard_dim] with local_offset[shard_dim]; instead, we add
                # local_offset[shard_dim] to the existing global_offset[shard_dim].
                if global_offset[shard_dim] <= local_offset[shard_dim]:
                    global_offset[shard_dim] = local_offset[shard_dim]
                else:
                    global_offset[shard_dim] += local_offset[shard_dim]

                num_shards_by_tensor_dim[shard_dim] *= mesh_dim_size

        # NOTE: the offset computation relies on the local shard index and is
        # straightforward when strided sharding is not present. To compute it
        # correctly, we assume that the ``_StridedShard.split_factor`` field encodes
        # how many partitions each local tensor will be further split into when
        # sharding on higher mesh dimensions. However, this number is only correct if
        # the DTensor is not sharded again after the strided sharding completes. For
        # example, [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] is a
        # placements list where the DTensor's dim-0 is first sharded on device mesh
        # dim-0, then on device mesh dim-2, and last on mesh dim-1. We define the
        # "_StridedShard(0, split_factor=2), Shard(0)" part as the strided sharding
        # part, because strided sharding happens on mesh dim-1 and is caused by the
        # fact that sharding on mesh dim-2 occurred ahead of it. In this case, there
        # is no further sharding after the strided sharding part, so ``split_factor``
        # correctly encodes the number. Another example is
        # [_StridedShard(0, split_factor=2), Shard(0), Shard(0)], where the DTensor's
        # dim-0 is first sharded on mesh dim-1, then on mesh dim-0, and last on mesh
        # dim-2. This violates our assumption that no further sharding occurs after
        # the strided sharding part, so ``split_factor`` does not correctly encode the
        # number of further splits. So far, the only case where a _StridedShard
        # placement appears is FSDP2 + TP on a 2D mesh, and the case above can only
        # happen on a mesh of 3 or more dimensions.
        # TODO: change this function to correctly address this.
        # TODO: this logic can be applied to contiguous sharding as well
        strided_sharding = any(isinstance(p, _StridedShard) for p in placements)
        if strided_sharding:
            strided_part_seen = [False] * len(global_shape)
            strided_part_end = [False] * len(global_shape)
            for idx, placement in enumerate(placements):
                mesh_dim_size = mesh.size(idx)
                if isinstance(placement, Shard):
                    shard_dim = placement.dim

                    if strided_part_end[shard_dim]:
                        raise NotImplementedError(
                            f"Strided sharding does not allow Shard() to appear after "
                            f"the strided part has ended. {placement} at idx {idx} in "
                            f"{placements} violates this assumption."
                        )

                    if strided_part_seen[shard_dim]:
                        strided_part_end[shard_dim] = True

                    if isinstance(placement, _StridedShard):
                        strided_part_seen[shard_dim] = True
                        shard_idx_stride_by_mesh_dim[shard_dim][
                            idx
                        ] = num_shards_by_tensor_dim[shard_dim] // (
                            placement.split_factor * mesh_dim_size
                        )
                    else:
                        num_shards_by_tensor_dim[shard_dim] //= mesh_dim_size
                        shard_idx_stride_by_mesh_dim[shard_dim][
                            idx
                        ] = num_shards_by_tensor_dim[shard_dim]

            shard_idx = [
                sum([x * y for x, y in zip(shard_idx_stride, my_coordinate)])
                for shard_dim, shard_idx_stride in enumerate(
                    shard_idx_stride_by_mesh_dim
                )
            ]

            global_offset = [x * y for x, y in zip(local_shape, shard_idx)]

        return tuple(local_shape), tuple(global_offset)


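# A worked sketch (a hypothetical helper, not used by this module) of the offset
# arithmetic from the first docstring example above: an (8, 4) tensor placed with
# [Shard(0), Shard(0)] on a (2, 4) mesh. Sharding on mesh dim 0 leaves 4 rows per
# group at row offset 4 * i; sharding again on mesh dim 1 leaves 1 row per rank at
# an additional offset of j, so a rank at coordinate (i, j) owns row 4 * i + j.
def _example_shape_and_offset(
    coord_i: int, coord_j: int
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    rows_per_group = 8 // 2              # after sharding on mesh dim 0 (size 2)
    rows_per_rank = rows_per_group // 4  # after sharding on mesh dim 1 (size 4)
    row_offset = coord_i * rows_per_group + coord_j * rows_per_rank
    return (rows_per_rank, 4), (row_offset, 0)


# e.g. coordinate (1, 1) (rank5 above) -> local_shape (1, 4), global_offset (5, 0).

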
def compute_global_tensor_info(
    tensor: torch.Tensor, mesh: DeviceMesh, placements: Sequence[Placement]
) -> Tuple[List[int], List[int]]:
    """
    Compute the global size and stride of a DTensor from the given local tensor.
    The local size is multiplied by `world_size` per sharding dim.
    The local stride is multiplied by `world_size` per sharding dim, as long as the
    dimension is outside the sharding dim.

    For example, suppose we have a local tensor with size (4, 8, 2) and stride (16, 1, 8).
    If the DTensor placements are [Shard(2)] and world_size is 2,
    then the global size is (4, 8, 4) and the stride is (16 * 2, 1, 8).

    Args:
        tensor (:class:`torch.Tensor`):
            Local tensor which DTensor will be constructed from.
        mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
            of devices for the DTensor.
        placements (Sequence[:class:`Placement`]):
            The attribute of the DTensor that describes its layout
            on the mesh topology.

    Return:
        tensor_shape: A list of ints which specifies the size of the DTensor which is
            built on top of the local tensor.
        tensor_stride: A list of ints which specifies the stride of the DTensor.
    """
    tensor_shape = list(tensor.size())
    tensor_stride = list(tensor.stride())
    for idx, placement in enumerate(placements):
        mesh_dim_size = mesh.size(idx)
        if placement.is_shard():
            shard_placement = cast(Shard, placement)
            if shard_placement.dim < 0:
                raise AssertionError(
                    "Shard placements should have negative dims normalized in "
                    f"the user-facing APIs: {shard_placement}"
                )
            shard_dim = shard_placement.dim

            assert (
                shard_dim < tensor.ndim
            ), f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}."

            local_dim_size = tensor_shape[shard_dim]
            tensor_shape[shard_dim] = local_dim_size * mesh_dim_size

            # recover the tensor stride by scaling every stride that is greater than
            # or equal to the current stride on the shard_dim
            for i in range(len(tensor_stride)):
                if i != shard_dim and tensor_stride[i] >= tensor_stride[shard_dim]:
                    # rescale the stride by the shard size
                    tensor_stride[i] = tensor_stride[i] * mesh_dim_size
        elif not isinstance(placement, (Replicate, Partial)):
            raise RuntimeError(f"placement type {type(placement)} not supported!")
    return tensor_shape, tensor_stride


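# A minimal sketch (a hypothetical helper, not used by this module) of the size and
# stride recovery from the docstring example above: a local tensor of size (4, 8, 2)
# and stride (16, 1, 8), sharded with [Shard(2)] on a mesh dim of size 2. The sharded
# dim grows by the mesh dim size, and every stride greater than or equal to the
# sharded dim's stride (i.e. any dim laid out outside it in memory) grows with it.
def _example_global_size_stride() -> Tuple[List[int], List[int]]:
    local_size, local_stride = [4, 8, 2], [16, 1, 8]
    shard_dim, mesh_dim_size = 2, 2
    global_size = list(local_size)
    global_size[shard_dim] *= mesh_dim_size  # 2 -> 4
    shard_stride = local_stride[shard_dim]   # 8
    global_stride = [
        s * mesh_dim_size if i != shard_dim and s >= shard_stride else s
        for i, s in enumerate(local_stride)
    ]
    return global_size, global_stride  # ([4, 8, 4], [32, 1, 8])

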
def try_find_mesh_from_args(
    op_call: torch._ops.OpOverload, args: Sequence[object]
) -> DeviceMesh:
    """
    Find the device mesh object from args.
    It raises a ValueError if no mesh is found.
    NOTE: we can optimize this search if needed
    """
    for arg in args:
        if isinstance(arg, (dtensor.DTensor, DTensorSpec)):
            return arg.device_mesh
        elif (
            isinstance(arg, (list, tuple))
            and len(arg) > 0
            and isinstance(arg[0], (dtensor.DTensor, DTensorSpec))
        ):
            return arg[0].device_mesh

    raise ValueError(f"Cannot find device mesh from args for op: {op_call}.")


def compute_local_stride(
    global_stride: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
) -> Tuple[int, ...]:
    """
    Compute the stride of a local tensor shard, given the global stride of the DTensor.
    NOTE: Currently this function assumes the DTensor is evenly shardable.
    """
    stride_divisors = [1] * len(global_stride)
    for mesh_idx, p in enumerate(placements):
        if p.is_shard():
            i = cast(Shard, p).dim
            # tensor dimension i is sharded on mesh dimension mesh_idx,
            # so we need to divide all the strides larger than stride[i]
            # (by the submesh size)
            for j in range(len(global_stride)):
                if global_stride[j] > global_stride[i]:
                    stride_divisors[j] *= mesh.size(mesh_idx)
    return tuple(
        global_stride[i] // stride_divisors[i] for i in range(len(global_stride))
    )


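# A small sketch (a hypothetical helper, not used by this module) of the stride
# arithmetic above: a contiguous (8, 4) global tensor has stride (4, 1). With
# Shard(1) on a mesh dim of size 2, dim 1 shrinks to 2 locally, so every stride
# strictly larger than stride[1] is divided by 2, and the local (8, 2) shard ends
# up contiguous with stride (2, 1).
def _example_local_stride() -> Tuple[int, ...]:
    global_stride, shard_dim, mesh_dim_size = (4, 1), 1, 2
    return tuple(
        s // mesh_dim_size if s > global_stride[shard_dim] else s
        for s in global_stride
    )  # (2, 1)

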
def normalize_to_torch_size(size) -> torch.Size:  # type: ignore[no-untyped-def]
    """
    Unify the various accepted types of the size argument into a torch.Size.
    Acceptable types include:
        int, Sequence[int], Tuple[int], Tuple[Sequence[int]],
        or torch.Size
    """
    if isinstance(size, torch.Size):
        return size

    if isinstance(size, int):
        torch_size = [size]
    elif len(size) == 1 and isinstance(size[0], Sequence):
        torch_size = list(size[0])
    else:
        torch_size = list(size)
    return torch.Size(torch_size)
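

# Usage sketch (a hypothetical helper, not called anywhere in this module): every
# accepted form of the size argument normalizes to the same torch.Size.
def _normalize_size_examples() -> None:
    assert normalize_to_torch_size(4) == torch.Size([4])
    assert normalize_to_torch_size([2, 3]) == torch.Size([2, 3])
    assert normalize_to_torch_size(((2, 3),)) == torch.Size([2, 3])
    assert normalize_to_torch_size(torch.Size([2, 3])) == torch.Size([2, 3])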
317