from functools import partial
from typing import no_type_check, Optional, Tuple

import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec


@no_type_check
def sync_grad_hook(grad, *, device_handle=None, compute_stream=None):
    # Gradients produced by async collectives arrive as AsyncCollectiveTensor;
    # wait for the pending collective to complete (on `compute_stream` when one
    # is provided) so downstream consumers see a materialized tensor.
    if isinstance(grad, AsyncCollectiveTensor):
        if compute_stream is not None:
            with device_handle.stream(compute_stream):
                grad = grad.wait()
        else:
            grad = grad.wait()

    return grad


def _flatten_tensor(
    tensor: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[DTensorSpec]]:
    # For a DTensor, hand back its local shard (marked as requiring grad) along
    # with the spec needed to reconstruct it later; plain tensors pass through.
    if isinstance(tensor, DTensor):
        tensor._local_tensor.requires_grad_()
        return tensor._local_tensor, tensor._spec
    return tensor, None


@no_type_check
def _unflatten_tensor(tensor, spec, *, device_handle=None, compute_stream=None):
    # Unflatten is mainly called every time FSDP all-gathers parameters.
    result = DTensor.from_local(
        tensor,
        spec.mesh,
        spec.placements,
        run_check=False,
        shape=spec.shape,
        stride=spec.stride,
    )
    if tensor.requires_grad:
        # Only register the hook if the tensor requires grad.
        tensor.register_hook(
            partial(
                sync_grad_hook,
                device_handle=device_handle,
                compute_stream=compute_stream,
            )
        )
    return result
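

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It
# assumes a process group is already initialized and `mesh` is a 1-D
# DeviceMesh; the variable names below are hypothetical.
#
#   from torch.distributed.tensor import distribute_tensor, Shard
#
#   dt = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
#   local, spec = _flatten_tensor(dt)       # plain local shard + its DTensorSpec
#   # ... FSDP stores / all-gathers the flat local tensor ...
#   dt_view = _unflatten_tensor(local, spec)  # DTensor rebuilt from the shard,
#                                             # with sync_grad_hook registered
# ---------------------------------------------------------------------------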