# mypy: allow-untyped-defs
import copy
from typing import Any, cast, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.distributed._shard.sharding_spec as shard_spec
import torch.distributed.distributed_c10d as c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
from torch.distributed.device_mesh import _mesh_resources
from torch.distributed.fsdp._common_utils import _set_fsdp_flattened
from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
from torch.distributed.remote_device import _remote_device
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
from torch.distributed.tensor.parallel._data_parallel_utils import (
    _flatten_tensor,
    _unflatten_tensor,
)


__all__ = ["DTensorExtensions"]


def _get_box(tensor: DTensor) -> Tuple[torch.Size, torch.Size]:
    device_mesh = tensor.device_mesh
    assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

    placement = tensor.placements[0]
    offsets = [0] * len(tensor.size())
    num_chunks = device_mesh.size(mesh_dim=0)

    if tensor.placements[0].is_shard():
        shard_dim = cast(DShard, placement).dim
        chunk_size = tensor.size(shard_dim) // num_chunks
        offsets[shard_dim] = chunk_size

    return (torch.Size(offsets), tensor._local_tensor.size())


def _get_box_for(tensor: DTensor, idx: int) -> Tuple[torch.Size, torch.Size]:
    offsets, size = _get_box(tensor)
    return (torch.Size([val * idx for val in offsets]), size)


def _get_local_box(tensor: DTensor) -> Tuple[torch.Size, torch.Size]:
    device_mesh = tensor.device_mesh
    coord = device_mesh.get_coordinate()
    assert coord is not None
    return _get_box_for(tensor, coord[0])
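
# A minimal illustration of the box helpers above (comments only, not executed), assuming
# a 1-D mesh of 4 ranks and a DTensor with global shape (8, 5) sharded on dim 0, so each
# local chunk has shape (2, 5):
#
#   dt = distribute_tensor(torch.randn(8, 5), mesh_1d, [DShard(0)])  # hypothetical setup
#   _get_box(dt)         # -> (torch.Size([2, 0]), torch.Size([2, 5])): per-chunk offset step, local size
#   _get_box_for(dt, 3)  # -> (torch.Size([6, 0]), torch.Size([2, 5])): box of the chunk at coordinate 3
#   _get_local_box(dt)   # box of this rank's own chunk, derived from its mesh coordinate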


def _create_shard_md_from_dt(dt: DTensor, current_rank: int) -> ShardMetadata:
    mesh = dt.device_mesh
    assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

    offsets, sizes = _get_local_box(dt)
    return ShardMetadata(
        shard_offsets=list(offsets),
        shard_sizes=list(sizes),
        placement=f"rank:{current_rank}/{dt._local_tensor.device}",
    )


def _create_sharded_tensor_md_from_dt(
    dt: DTensor, dt_pg: c10d.ProcessGroup
) -> ShardedTensorMetadata:
    # This is where it gets tricky: we have to produce a ShardedTensor that has full coverage
    # and yet has only one valid shard for the current rank.

    shards_md = []
    my_rank = dist.get_rank(dt_pg)
    scapegoat_rank = 0 if my_rank > 0 else 1

    if dt.placements[0].is_shard():
        shard_count = dt_pg.size()
    else:
        shard_count = 1

    for i in range(shard_count):
        offsets, sizes = _get_box_for(dt, i)
        shards_md.append(
            ShardMetadata(
                shard_offsets=list(offsets),
                shard_sizes=list(sizes),
                placement=(
                    f"rank:{scapegoat_rank if i > 0 else my_rank}/{dt._local_tensor.device}"
                ),
            )
        )

    return ShardedTensorMetadata(
        shards_metadata=shards_md,
        size=dt.size(),
        tensor_properties=TensorProperties(
            dtype=dt.dtype,
            layout=dt.layout,
            requires_grad=dt.requires_grad,
            # ignore memory_format and pin_memory as those are not supported by DT
        ),
    )
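
# Sketch of the metadata built above (illustration only), assuming a 4-rank dt_pg,
# my_rank == 2, and a dim-0 sharded DTensor:
#
#   shards_metadata[0].placement  -> "rank:2/<local device>"   # the one shard attributed to this rank
#   shards_metadata[1:]           -> placed on the scapegoat rank (rank 0, or rank 1 when my_rank == 0)
#
# Together the shard boxes cover dt.size() completely, which is what
# ShardedTensor._init_from_local_shards_and_global_metadata expects, while only a single
# shard is claimed by the current rank.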


def _get_dt_pg(dt: DTensor) -> c10d.ProcessGroup:
    mesh = dt.device_mesh
    assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
    return mesh.get_group()


def _rewrite_spec_if_needed(
    spec: shard_spec.ShardingSpec, tensor: torch.Tensor, rank: int
) -> shard_spec.ShardingSpec:
    """
    Rewrite ``spec`` to match the device of ``tensor``.

    FSDP.sharded_optim_state_dict sneakily ships the optimizer state to CPU, so if the original
    ShardingSpec produces CUDA metadata, ShardedTensor construction fails.
125    """
126    if not isinstance(spec, ChunkShardingSpec):
127        return spec
128
    # Check whether any placement for this rank needs to be rewritten.
    rewrite = False
    for p in spec.placements:
        p = cast(_remote_device, p)
        if p.rank() == rank and p.device() != tensor.device:
            rewrite = True
            break
    if rewrite:
        spec = copy.deepcopy(spec)
        for i, placement in enumerate(spec.placements):
            placement = cast(_remote_device, placement)
            if placement.rank() == rank and placement.device() != tensor.device:
                spec.placements[i] = _remote_device(f"rank:{rank}/{tensor.device}")

    return spec
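
# A minimal sketch of the rewrite (illustration only, not executed), assuming rank 1 whose
# optimizer state has been moved to CPU:
#
#   spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
#   cpu_tensor = torch.randn(4)  # lives on CPU
#   _rewrite_spec_if_needed(spec, cpu_tensor, rank=1)
#   # -> a deep copy whose placements[1] becomes "rank:1/cpu"; other ranks' placements are untouched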


def _chunk_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
) -> torch.Tensor:
    if type(tensor) is ShardedTensor:
        assert len(tensor.local_shards()) == 1

        inner_param = tensor.local_tensor()
        inner_st = _create_chunk_sharded_tensor(
            inner_param,
            rank,
            world_size,
            num_devices_per_node,
            pg,
        )

        outer_local_shard = tensor.local_shards()[0]
        shards: List[Shard] = [
            Shard(inner_st, copy.deepcopy(outer_local_shard.metadata))
        ]
        st_meta = copy.deepcopy(tensor.metadata())
        st_meta.tensor_properties.requires_grad = False

        st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
            shards,
            sharded_tensor_metadata=st_meta,
            process_group=tensor._process_group,
            init_rrefs=False,
        )
        return st_outer
    elif type(tensor) is DTensor:
        device_mesh = tensor.device_mesh
        assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

        inner_param = tensor._local_tensor

        inner_st = _create_chunk_sharded_tensor(
            inner_param,
            rank,
            world_size,
            torch.cuda.device_count(),
            pg,
        )

        dt_pg = _get_dt_pg(tensor)
        # We do this differently here: the outer ShardedTensor is built from this rank's single
        # local shard plus hand-crafted global metadata that covers the whole tensor.
        shards = [
            Shard(inner_st, _create_shard_md_from_dt(tensor, dist.get_rank(dt_pg)))
        ]

        st_meta = _create_sharded_tensor_md_from_dt(tensor, dt_pg)
        st_meta.tensor_properties.requires_grad = False

        st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
            shards,
            sharded_tensor_metadata=st_meta,
            process_group=dt_pg,
            init_rrefs=False,
        )

        return st_outer
    else:
        return _create_chunk_sharded_tensor(
            tensor,
            rank,
            world_size,
            num_devices_per_node,
            pg,
        )
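
# Sketch of the nesting produced by the DTensor branch above (illustration only):
#
#   st_outer                  # ShardedTensor over dt_pg (the TP group), global metadata from
#                             # _create_sharded_tensor_md_from_dt
#   └── local shard .tensor   # inner ShardedTensor over `pg` (the FSDP group), produced by
#                             # _create_chunk_sharded_tensor from the DTensor's local tensor
#
# _pre_load_state_dict below unwraps exactly this two-level structure when loading a state dict.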


def _chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
) -> DTensor:
    """
    Shard a tensor to chunks along the first dimension.

    Each rank gets its corresponding chunk as the local tensor used to create a DTensor.
229    """
230    root_mesh = _mesh_resources.get_root_mesh(device_mesh)
231    if root_mesh is None:
232        raise RuntimeError("No parent device_mesh is found for FSDP device_mesh.")
233    if root_mesh.ndim < 2:
        raise RuntimeError(
            f"Found parent device_mesh of ndim={root_mesh.ndim}, "
            "but meshes must be at least 2D."
        )

    # We need to explicitly call .detach() to return a new tensor detached from the current graph.
    tensor = tensor.clone().detach()

    # When a layer is not involved in TP, the tensor will not be a DTensor.
    # e.g. when a layer is not specified in the parallelize_plan, TP has no effect on that layer.
    # e.g. when you apply PairwiseParallel to a 3-layer model, TP has no effect on the third layer.
    if isinstance(tensor, torch.Tensor) and not isinstance(tensor, DTensor):
        # A plain tensor is replicated across the TP dimension and sharded across the FSDP dimension.
        # TP is the inner dimension and FSDP is the outer dimension.
        # Therefore, the shard placements for the tensor are (Shard(0), Replicate()).
        replicate_placements = [Replicate() for _ in range(root_mesh.ndim)]
        shard_placements = [Replicate() for _ in range(root_mesh.ndim)]
        shard_placements[0] = DShard(0)  # type: ignore[call-overload]

        return DTensor.from_local(
            tensor, root_mesh, replicate_placements, run_check=False
        ).redistribute(
            device_mesh=root_mesh,
            placements=shard_placements,
        )

    else:
        tp_placements = tensor.placements
        tp_placement = tp_placements[0]

        tensor = tensor.to_local()

        # A DTensor is sharded across the TP dimension first and then sharded across the FSDP dimension.
        # TP is the inner dimension and FSDP is the outer dimension.
        # Therefore, the shard placements for the tensor are (Shard(0), tp_placement).
        # For higher-dimensional meshes, it is replicated across the other dimensions. For example, with
        # HSDP the shard placements for the tensor are (Replicate(), Shard(0), tp_placement).
        replicate_placements = [Replicate() for _ in range(root_mesh.ndim)]
        replicate_placements[-1] = tp_placement  # type: ignore[call-overload]
        shard_placements = [Replicate() for i in range(root_mesh.ndim)]  # type: ignore[misc]
        shard_placements[-2] = DShard(0)  # type: ignore[call-overload]
        shard_placements[-1] = tp_placement  # type: ignore[call-overload]

        return DTensor.from_local(
            tensor, root_mesh, replicate_placements, run_check=False
        ).redistribute(
            device_mesh=root_mesh,
            placements=shard_placements,
        )
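
# Placement sketches for the two branches above (illustration only), assuming a 2-D
# root mesh (FSDP, TP) and a 3-D root mesh (Replicate, FSDP, TP) for HSDP + TP:
#
#   plain tensor, 2-D mesh:  from_local with [Replicate(), Replicate()]
#                            -> redistribute to [DShard(0), Replicate()]
#   DTensor,      2-D mesh:  from_local with [Replicate(), tp_placement]
#                            -> redistribute to [DShard(0), tp_placement]
#   DTensor,      3-D mesh:  from_local with [Replicate(), Replicate(), tp_placement]
#                            -> redistribute to [Replicate(), DShard(0), tp_placement]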


def _pre_load_state_dict(
    tensor: torch.Tensor,
) -> Tuple[torch.Tensor, List[Shard]]:
    shards = cast(ShardedTensor, tensor).local_shards()
    if len(shards) == 1 and type(shards[0].tensor) is ShardedTensor:
        inner_tensor = shards[0].tensor
        shards = inner_tensor.local_shards()  # pyre-ignore[16]
        tensor = inner_tensor

    return (tensor, shards if len(shards) > 0 else [])
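
# Sketch of the unwrapping (illustration only), assuming the nested ShardedTensor that
# _chunk_tensor produces for the DTensor case:
#
#   inner, shards = _pre_load_state_dict(st_outer)
#   # inner  -> the inner ShardedTensor (chunked over the FSDP group)
#   # shards -> that inner tensor's local shards, or [] on ranks without a local shard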


def _all_gather_dtensor(
    tensor: DTensor,
    parent_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
    """All gather a DTensor in its FSDP dimension and return the local tensor."""
    assert parent_mesh == tensor.device_mesh

    placements = list(copy.deepcopy(tensor.placements))
    # FSDP + TP: [Shard(0), tp_placement] -> [Replicate(), tp_placement]
    # HSDP + TP: [Replicate(), Shard(0), tp_placement] -> [Replicate(), Replicate(), tp_placement]
    for i in range(0, len(placements) - 1):
        placements[i] = Replicate()
    tensor = tensor.redistribute(
        device_mesh=tensor.device_mesh,
        placements=placements,
    )

    return tensor.to_local()


class DTensorExtensions(FSDPExtensions):
    """
    DTensorExtensions is the TensorFlattener extension needed for 2D FSDP + TP.

    This is the implementation for FSDPExtensions defined in
    https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fsdp_extensions.py
    """

    def __init__(self, device_handle) -> None:
        super().__init__()
        self.compute_stream = None
        self.device_handle = device_handle
        # We have to disable dynamo by wrapping the method here rather than using the decorator,
        # as the decorator form triggers a build failure with torch deploy...
        self.post_unflatten_transform = torch._dynamo.disable(self.post_unflatten_transform)  # type: ignore[method-assign]

    def pre_flatten_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        return _flatten_tensor(tensor)

    def post_unflatten_transform(
        self, tensor: torch.Tensor, param_extension: Any
    ) -> torch.Tensor:
        stream = self.compute_stream or self.device_handle.current_stream()
        with self.device_handle.stream(stream):
            # At runtime we run the unflatten call on the compute stream, since the
            # unflattened tensor may be used by fwd/bwd computations that need to
            # synchronize properly.
            # TODO: this is a short-term fix; get_unflat_views should happen directly
            # on the compute stream.
            result = _unflatten_tensor(
                tensor,
                param_extension,
                device_handle=self.device_handle,
                compute_stream=self.compute_stream,
            )
            _set_fsdp_flattened(result)
            return result

    def chunk_tensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        world_size: int,
        num_devices_per_node: int,
        pg: dist.ProcessGroup,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        return _chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg)

    def chunk_dtensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        device_mesh: DeviceMesh,
    ) -> torch.Tensor:
        return _chunk_dtensor(tensor, rank, device_mesh)

    def pre_load_state_dict_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[Shard]]:
        return _pre_load_state_dict(tensor)

    def all_gather_dtensor(
        self,
        tensor: DTensor,
        parent_mesh: Optional[DeviceMesh],
    ) -> torch.Tensor:
        return _all_gather_dtensor(tensor, parent_mesh)