# mypy: allow-untyped-defs
import collections.abc
import copy
from typing import List, Optional, Sequence, TYPE_CHECKING

import torch
from torch.distributed import distributed_c10d as c10d, rpc
from torch.distributed._shard.sharding_spec._internals import (
    check_tensor,
    validate_non_overlapping_shards_metadata,
)

from .metadata import ShardedTensorMetadata, TensorProperties
from .shard import Shard


if TYPE_CHECKING:
    from torch.distributed._shard.metadata import ShardMetadata

def _parse_and_validate_remote_device(pg, remote_device):
    """
    Validate ``remote_device`` against the given process group (or the current RPC
    agent, for worker-name placements) and return a ``(rank_or_worker_id, device)`` pair.
    """
    if remote_device is None:
        raise ValueError("remote device is None")

    worker_name = remote_device.worker_name()
    rank = remote_device.rank()
    device = remote_device.device()

    # Validate the rank; skip validation if the current process is not part of the process group.
    if rank is not None and not c10d._rank_not_in_group(pg):
        pg_global_ranks = c10d.get_process_group_ranks(pg)
        if rank not in pg_global_ranks:
            raise ValueError(
                f"Global rank {rank} does not exist in input process group: {pg_global_ranks}"
            )

    if worker_name is not None:
        if not rpc._is_current_rpc_agent_set():
            raise RuntimeError(
                f"RPC framework needs to be initialized for using worker names: {worker_name}"
            )

        workers = rpc._get_current_rpc_agent().get_worker_infos()
        for worker in workers:
            if worker.name == worker_name:
                return worker.id, device

        raise ValueError(f"Invalid worker name: {worker_name}")

    return rank, device


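# Illustrative usage sketch for a rank-based placement (assumes a default process
# group has already been created with torch.distributed.init_process_group; the
# placement string is hypothetical):
#
#     from torch.distributed._remote_device import _remote_device
#
#     pg = c10d._get_default_group()
#     rank, device = _parse_and_validate_remote_device(pg, _remote_device("rank:0/cuda:0"))
#     # on a group containing global rank 0: rank == 0, device == torch.device("cuda:0")

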
def _validate_output_tensor_for_gather(
    my_rank: int,
    dst_rank: int,
    size: torch.Size,
    dst_tensor: Optional[torch.Tensor],
) -> None:
    if dst_rank == my_rank:
        if dst_tensor is None:
            raise ValueError(
                f"Argument ``dst_tensor`` must be specified on destination rank {dst_rank}"
            )
        if tuple(size) != tuple(dst_tensor.size()):
            raise ValueError(
                f"Argument ``dst_tensor`` has size {tuple(dst_tensor.size())}, "
                f"but should be {tuple(size)}"
            )
    elif dst_tensor is not None:
        # explicit ``is not None`` check: truth-testing a multi-element tensor raises
        raise ValueError(
            "Argument ``dst_tensor`` must NOT be specified on non-destination ranks."
        )


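# Illustrative sketch of the checks performed around a gather, using hypothetical
# ranks and sizes (no process group is needed to exercise this helper):
#
#     out = torch.empty(4, 4)
#     _validate_output_tensor_for_gather(
#         my_rank=0, dst_rank=0, size=torch.Size([4, 4]), dst_tensor=out
#     )  # passes: correct output tensor on the destination rank
#     _validate_output_tensor_for_gather(
#         my_rank=1, dst_rank=0, size=torch.Size([4, 4]), dst_tensor=out
#     )  # raises ValueError: dst_tensor must not be set on non-destination ranks

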
def _flatten_tensor_size(size) -> torch.Size:
    """
    Check that the given tensor size is valid, then flatten it and return a torch.Size object.
    """
    if len(size) == 1 and isinstance(size[0], collections.abc.Sequence):
        dims = list(*size)
    else:
        dims = list(size)

    for dim in dims:
        if not isinstance(dim, int):
            raise TypeError(f"size has to be a sequence of ints, found: {dims}")

    return torch.Size(dims)


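# Illustrative sketch: both calling conventions accepted by _flatten_tensor_size
# produce the same torch.Size (values are hypothetical):
#
#     _flatten_tensor_size((4, 8))     # a plain sequence of ints     -> torch.Size([4, 8])
#     _flatten_tensor_size(((4, 8),))  # a 1-tuple wrapping a sequence -> torch.Size([4, 8])
#     _flatten_tensor_size(("4", 8))   # raises TypeError: size has to be a sequence of ints

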
def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True):
    if is_local:
        assert isinstance(ranks, int)
        if expected != actual:
            raise ValueError(
                f"Local shards' tensor {prop_name} property needs to be the same on rank:{ranks}! "
                f"Found one local shard tensor {prop_name}={expected}, "
                f"the other local shard tensor {prop_name}={actual}."
            )
    else:
        # cross-rank comparison; the ranks list should contain exactly two ranks
        assert len(ranks) == 2
        if expected != actual:
            raise ValueError(
                f"ShardedTensor {prop_name} property does not match across ranks! "
                f"Found {prop_name}={expected} on rank:{ranks[0]}, "
                f"and {prop_name}={actual} on rank:{ranks[1]}."
            )


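# Illustrative sketch of the two reporting modes of _raise_if_mismatch, using
# hypothetical dtype values:
#
#     _raise_if_mismatch(torch.float32, torch.float32, "dtype", 0)   # no-op, values match
#     _raise_if_mismatch(torch.float32, torch.float16, "dtype", 0)   # local mismatch -> ValueError
#     _raise_if_mismatch(
#         torch.float32, torch.float16, "dtype", [0, 1], is_local=False
#     )  # cross-rank mismatch -> ValueError naming both ranks

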
def build_metadata_from_local_shards(
    local_shards: List[Shard],
    global_size: torch.Size,
    current_rank: int,
    pg: c10d.ProcessGroup,
) -> ShardedTensorMetadata:
    """
    Validate the local shards on ``current_rank`` and build a rank-local
    ShardedTensorMetadata from their metadata and tensor properties.
    """
    assert len(local_shards) > 0, "must have local shards!"
    local_shard_metadatas: List[ShardMetadata] = []

    first_shard_dtype = local_shards[0].tensor.dtype
    first_shard_layout = local_shards[0].tensor.layout
    first_shard_requires_grad = local_shards[0].tensor.requires_grad
    first_shard_is_pinned = local_shards[0].tensor.is_pinned()

    # 1). Validate local tensors and their associated metadata
    for local_shard in local_shards:
        local_shard_tensor = local_shard.tensor
        local_shard_meta = local_shard.metadata
        local_shard_metadatas.append(local_shard_meta)
        rank, local_device = _parse_and_validate_remote_device(
            pg, local_shard_meta.placement
        )

        if (
            local_shard_tensor.layout != torch.strided
            or local_shard_tensor.layout != first_shard_layout
        ):
            raise ValueError(
                f"Only torch.strided layout is currently supported, but found "
                f"{local_shard_tensor.layout} on rank:{current_rank}!"
            )

        if not local_shard_tensor.is_contiguous():
            raise ValueError(
                "Only torch.contiguous_format memory_format is currently supported!"
            )

        if rank != current_rank:
            raise ValueError(
                f"Local shard metadata's rank does not match the rank in its process group! "
                f"Found current rank in the process group: {current_rank}, "
                f"local ShardMetadata placement's rank: {rank}"
            )
        if local_shard_tensor.device != local_device:
            raise ValueError(
                f"Local shard tensor device does not match the local Shard's placement! "
                f"Found local shard tensor device: {local_shard_tensor.device}, "
                f"local shard metadata placement device: {local_device}"
            )

        _raise_if_mismatch(
            local_shard_meta.shard_sizes,
            list(local_shard_tensor.size()),
            "size",
            current_rank,
        )
        _raise_if_mismatch(
            local_shard_tensor.is_pinned(),
            first_shard_is_pinned,
            "pin_memory",
            current_rank,
        )
        _raise_if_mismatch(
            local_shard_tensor.dtype, first_shard_dtype, "dtype", current_rank
        )
        _raise_if_mismatch(
            local_shard_tensor.requires_grad,
            first_shard_requires_grad,
            "requires_grad",
            current_rank,
        )

    # 2). Build a "local" ShardedTensorMetadata with all local shards on this rank; the
    #     caller then all_gathers these local metadata objects from all ranks.
    local_tensor_properties = TensorProperties(
        dtype=first_shard_dtype,
        layout=first_shard_layout,
        requires_grad=first_shard_requires_grad,
        memory_format=torch.contiguous_format,
        pin_memory=first_shard_is_pinned,
    )

    local_sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=local_shard_metadatas,
        size=global_size,
        tensor_properties=local_tensor_properties,
    )

    return local_sharded_tensor_metadata


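# Illustrative sketch: building the rank-local metadata for one shard of a
# hypothetical 16 x 4 tensor sharded across 2 ranks. Assumes torch.distributed has
# been initialized and that this code runs on rank 0.
#
#     from torch.distributed._shard.metadata import ShardMetadata
#
#     pg = c10d._get_default_group()
#     shard = Shard(
#         tensor=torch.zeros(8, 4),
#         metadata=ShardMetadata(
#             shard_offsets=[0, 0], shard_sizes=[8, 4], placement="rank:0/cpu"
#         ),
#     )
#     local_meta = build_metadata_from_local_shards(
#         [shard], torch.Size([16, 4]), current_rank=0, pg=pg
#     )

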
def build_global_metadata(
    gathered_metadatas: Sequence[Optional[ShardedTensorMetadata]],
) -> ShardedTensorMetadata:
    """
    Merge the per-rank ShardedTensorMetadata gathered from all ranks into a single
    global ShardedTensorMetadata, validating consistency across ranks.
    """
    global_sharded_tensor_metadata = None
    global_metadata_rank = 0

    for rank, rank_metadata in enumerate(gathered_metadatas):
        if rank_metadata is None:
            continue

        if global_sharded_tensor_metadata is None:
            global_sharded_tensor_metadata = copy.deepcopy(rank_metadata)
            global_metadata_rank = rank
        else:
            _raise_if_mismatch(
                global_sharded_tensor_metadata.size,
                rank_metadata.size,
                "global_size",
                [global_metadata_rank, rank],
                is_local=False,
            )

            # no need to check layout and memory format, as they were already checked
            # during the local shard validation stage
            _raise_if_mismatch(
                global_sharded_tensor_metadata.tensor_properties.dtype,
                rank_metadata.tensor_properties.dtype,
                "dtype",
                [global_metadata_rank, rank],
                is_local=False,
            )

            _raise_if_mismatch(
                global_sharded_tensor_metadata.tensor_properties.requires_grad,
                rank_metadata.tensor_properties.requires_grad,
                "requires_grad",
                [global_metadata_rank, rank],
                is_local=False,
            )

            _raise_if_mismatch(
                global_sharded_tensor_metadata.tensor_properties.pin_memory,
                rank_metadata.tensor_properties.pin_memory,
                "pin_memory",
                [global_metadata_rank, rank],
                is_local=False,
            )
            # all validations passed; extend the shards metadata
            global_sharded_tensor_metadata.shards_metadata.extend(
                rank_metadata.shards_metadata
            )

    if global_sharded_tensor_metadata is not None:
        # check that shards_metadata contains no overlapping shards
        validate_non_overlapping_shards_metadata(
            global_sharded_tensor_metadata.shards_metadata
        )

        # check that the shards_metadata is compatible with the global size of the sharded tensor
        check_tensor(
            global_sharded_tensor_metadata.shards_metadata,
            global_sharded_tensor_metadata.size,
        )
    else:
        raise ValueError("ShardedTensor has no local shards on any rank!")

    return global_sharded_tensor_metadata

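
# Illustrative sketch: merging per-rank metadata into the global metadata,
# continuing the hypothetical 16 x 4 / 2-rank example above. Assumes an initialized
# process group ``pg`` and the rank-local ``local_meta`` built on each rank.
#
#     gathered = [None] * pg.size()
#     torch.distributed.all_gather_object(gathered, local_meta, group=pg)
#     global_meta = build_global_metadata(gathered)
#     # global_meta.shards_metadata now lists one 8 x 4 shard per rank and
#     # global_meta.size == torch.Size([16, 4])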