# xref: /aosp_15_r20/external/pytorch/torch/distributed/tensor/parallel/_data_parallel_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
from functools import partial
from typing import no_type_check, Optional, Tuple

import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec


@no_type_check
def sync_grad_hook(grad, *, device_handle=None, compute_stream=None):
    # Wait on async collective gradients so downstream code sees a plain
    # tensor; if a compute stream is given, the wait is enqueued on it.
    if isinstance(grad, AsyncCollectiveTensor):
        if compute_stream is not None:
            with device_handle.stream(compute_stream):
                grad = grad.wait()
        else:
            grad = grad.wait()

    return grad


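# Illustrative sketch (not part of the original file): gradients that are not
# AsyncCollectiveTensors pass through sync_grad_hook unchanged, so the hook is
# safe to register unconditionally. The helper name below is hypothetical.
def _example_sync_grad_hook_passthrough() -> None:
    plain_grad = torch.ones(4)
    # A plain tensor is not an AsyncCollectiveTensor, so the hook is a no-op.
    assert sync_grad_hook(plain_grad) is plain_grad

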
def _flatten_tensor(
    tensor: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[DTensorSpec]]:
    # For a DTensor, return the local shard (marked as requiring grad) plus the
    # spec needed to rebuild the DTensor later; plain tensors pass through.
    if isinstance(tensor, DTensor):
        tensor._local_tensor.requires_grad_()
        return tensor._local_tensor, tensor._spec
    return tensor, None


@no_type_check
def _unflatten_tensor(tensor, spec, *, device_handle=None, compute_stream=None):
    # unflatten is mainly called every time FSDP all-gathers the parameters.
    result = DTensor.from_local(
        tensor,
        spec.mesh,
        spec.placements,
        run_check=False,
        shape=spec.shape,
        stride=spec.stride,
    )
    if tensor.requires_grad:
        # only register the hook if the tensor requires grad
        tensor.register_hook(
            partial(
                sync_grad_hook,
                device_handle=device_handle,
                compute_stream=compute_stream,
            )
        )
    return result
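

# Illustrative sketch (not part of the original file): how an FSDP-style caller
# might round-trip a DTensor parameter through the helpers above. Assumes the
# process group and the DTensor's DeviceMesh are already initialized; the
# function name below is hypothetical.
def _example_flatten_unflatten_roundtrip(dtensor: DTensor) -> DTensor:
    # Flatten: keep only the local shard plus the spec needed to rebuild it.
    local_tensor, spec = _flatten_tensor(dtensor)
    assert spec is not None
    # Unflatten: rebuild the DTensor from the local shard; a gradient hook is
    # registered so async collective grads are waited on before use.
    return _unflatten_tensor(local_tensor, spec)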