# mypy: allow-untyped-defs
import functools
import operator
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Dict, List, TYPE_CHECKING

import torch
import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.op_registry_utils import _decorator_func

from ._internals import (
    check_tensor,
    get_chunked_dim_size,
    get_split_size,
    validate_non_overlapping_shards_metadata,
)


if TYPE_CHECKING:
    # Only import ShardedTensor during type checking and exclude it
    # at run time to resolve a circular dependency.
    from torch.distributed._shard.sharded_tensor import ShardedTensor


class PlacementSpec(ABC):  # noqa: B024
    """
    Base class representing the placement of an entity. Subclasses of this
    class can be used to specify customized placements which might not be
    covered by existing APIs.
    """


@dataclass
class DevicePlacementSpec(PlacementSpec):
    """
    Associates placement of an entity with a single device.

    Args:
        device(:class:`torch.distributed._remote_device`): The device to place the entity on.
    """

    device: torch.distributed._remote_device

    def __post_init__(self):
        if not isinstance(self.device, torch.distributed._remote_device):
            self.device = torch.distributed._remote_device(self.device)


class ShardingSpec(ABC):
    """
    Base class representing sharding specifications.
    """

    @abstractmethod
    def build_metadata(
        self,
        tensor_sizes: torch.Size,
        tensor_properties: sharded_tensor_meta.TensorProperties,
    ) -> sharded_tensor_meta.ShardedTensorMetadata:
        """
        Given a global tensor size, define how a tensor of this shape should be
        sharded across ranks and return the corresponding ShardedTensorMetadata.

        Args:
            tensor_sizes (:class:`torch.Size`):
                The tensor shape to shard on, a `torch.Size` object that represents the
                tensor shape to be sharded according to the ShardingSpec.
            tensor_properties(:class:`torch.distributed._shard.sharded_tensor.TensorProperties`):
                Tensor properties used to create a ShardedTensor.

        Returns:
            A :class:`ShardedTensorMetadata` object that encodes the information about
            the layout of the ShardedTensor and its properties.
        """

    @abstractmethod
    def shard(
        self, tensor: torch.Tensor, src_rank: int = 0, process_group=None
    ) -> "ShardedTensor":
        """
        Given a global tensor on src_rank, shard this tensor
        across ranks within the process group and return a ShardedTensor.

        Args:
            tensor (:class:`torch.Tensor`): The tensor to be sharded.

        Keyword args:
            src_rank (int, optional): The source rank which is used as the ground truth of
                the data for the parameter that would be sharded and scattered
                across the rest of the ranks.
                Default: 0.
            process_group (ProcessGroup, optional): The process group to work on. If None,
                the default process group will be used.

        Returns:
            A :class:`ShardedTensor` sharded from the given tensor.
        """


# Ops customized for a particular ShardingSpec.
_CUSTOM_SHARDING_SPEC_OPS: Dict[str, Dict[Callable, Callable]] = {}
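

# NOTE: Illustrative sketch, not part of the original module. It shows how a
# DevicePlacementSpec might be constructed; the helper name below is
# hypothetical, and "rank:0/cuda:0" is one of the string forms accepted by
# torch.distributed._remote_device.
def _example_device_placement() -> DevicePlacementSpec:
    # __post_init__ would also accept the raw string and wrap it itself.
    return DevicePlacementSpec(torch.distributed._remote_device("rank:0/cuda:0"))
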

def _has_custom_op(sharding_spec, op):
    """
    Returns whether or not the ShardingSpec has a custom op implementation.
    """
    class_name = type(sharding_spec).__qualname__
    return (
        class_name in _CUSTOM_SHARDING_SPEC_OPS
        and op in _CUSTOM_SHARDING_SPEC_OPS[class_name]
    )


def _dispatch_custom_op(
    sharding_spec, op: Callable, types, args, kwargs, process_group
):
    """
    Calls the custom op for this ShardingSpec if it exists.
    """
    class_name = type(sharding_spec).__qualname__
    if not _has_custom_op(sharding_spec, op):
        raise RuntimeError(f"Custom op: {op} not registered for {class_name}")
    func = _CUSTOM_SHARDING_SPEC_OPS[class_name][op]
    return func(types, args, kwargs, process_group)


def custom_sharding_spec_op(sharding_spec_class, func):
    """
    Decorator to allow custom registration of ops.

    Args:
        sharding_spec_class(type): The ShardingSpec for which we need to add this custom op.
        func(Callable): The op to override (ex: torch.bmm)
    """
    class_name = sharding_spec_class.__qualname__
    if class_name not in _CUSTOM_SHARDING_SPEC_OPS:
        _CUSTOM_SHARDING_SPEC_OPS[class_name] = {}
    return functools.partial(
        _decorator_func, op=func, op_table=_CUSTOM_SHARDING_SPEC_OPS[class_name]
    )


@dataclass
class EnumerableShardingSpec(ShardingSpec):
    """
    This is a type of ShardingSpec that allows users to specify a generic
    sharding scheme by enumerating exactly how each shard is laid out.

    Args:
        shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing
            each shard. Note that none of the shards should overlap.
    """

    shards: List[ShardMetadata]

    def __post_init__(self):
        if len(self.shards) == 0:
            raise ValueError(f"Empty shard list provided: {self.shards}")

        # Validate that each shard has the same rank (number of dimensions).
        rank = -1
        for shard in self.shards:
            if rank != -1 and rank != len(shard.shard_offsets):
                raise ValueError(
                    f"Found inconsistent ranks for shards: {rank} and {len(shard.shard_offsets)}"
                )
            rank = len(shard.shard_offsets)

        validate_non_overlapping_shards_metadata(self.shards)

    def build_metadata(
        self,
        tensor_sizes: torch.Size,
        tensor_properties: sharded_tensor_meta.TensorProperties,
    ) -> sharded_tensor_meta.ShardedTensorMetadata:
        # check if shards form a valid tensor
        check_tensor(self.shards, tensor_sizes)
        return sharded_tensor_meta.ShardedTensorMetadata(
            self.shards, tensor_sizes, tensor_properties
        )

    def shard(
        self, tensor: torch.Tensor, src_rank: int = 0, process_group=None
    ) -> "ShardedTensor":
        # TODO: figure out a generic and efficient way to scatter the shards for EnumerableShardingSpec
        raise NotImplementedError("EnumerableShardingSpec.shard not implemented yet!")
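

# NOTE: Illustrative sketch, not part of the original module; the helper name is
# hypothetical. It describes a 10 x 5 tensor split row-wise across two ranks and
# turns the spec into metadata via build_metadata(), assuming the usual defaults
# for the TensorProperties fields not passed explicitly.
def _example_enumerable_spec() -> sharded_tensor_meta.ShardedTensorMetadata:
    spec = EnumerableShardingSpec(
        [
            ShardMetadata(
                shard_offsets=[0, 0], shard_sizes=[5, 5], placement="rank:0/cuda:0"
            ),
            ShardMetadata(
                shard_offsets=[5, 0], shard_sizes=[5, 5], placement="rank:1/cuda:1"
            ),
        ]
    )
    # The two shards tile the full 10 x 5 tensor, so check_tensor() inside
    # build_metadata() passes and a ShardedTensorMetadata is returned.
    return spec.build_metadata(
        torch.Size([10, 5]), sharded_tensor_meta.TensorProperties(dtype=torch.float32)
    )
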

def _infer_sharding_spec_from_shards_metadata(shards_metadata):
    """
    Infer the sharding spec from the metadata of each shard of a ShardedTensor.
    If the tensor is sharded along only one dimension, we verify whether it
    matches a ChunkShardingSpec: take the total length along that dimension,
    perform a chunk sharding with the given placements, and check whether the
    resulting chunk sizes and offsets match the given shards_metadata. If not,
    we assume it's enum sharded.

    Args:
        shards_metadata (List[ShardMetadata]): List of metadata of local shards.

    Returns:
        A :class:`torch.distributed._shard.sharding_spec.ShardingSpec` object of sharding
        spec for one sharded tensor.
    """
    placements = []
    chunk_sharding_dim = None
    chunk_offset_list = []
    shard_size_list = []
    shard_offset_list = []
    # collect local shard metadata from the global sharded_tensor_metadata
    for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
        placements.append(shard_metadata.placement)
        local_offsets = shard_metadata.shard_offsets
        chunk_offset_list.append(sum(local_offsets))
        shard_size_list.append(shard_metadata.shard_sizes)
        shard_offset_list.append(shard_metadata.shard_offsets)
        shard_dims = [idx for idx, e in enumerate(local_offsets) if e != 0]
        # If the offset is [0, 0, ..., 0] (all zeros),
        # we cannot decide how the tensor is sharded.
        if len(shard_dims) == 0:
            continue
        # If the offset is [0, N, ..., 0, M, 0, ..., 0],
        # we are sure it's sharded by more than one dimension.
        if len(shard_dims) != 1:
            chunk_sharding_dim = None
            break
        # If the offset is [0, 0, ..., 0, M, 0, ..., 0], i.e., it's sharded along just
        # one dimension, we need to make sure all ranks share the same dimension.
        if not chunk_sharding_dim:
            chunk_sharding_dim = shard_dims[0]
        elif chunk_sharding_dim != shard_dims[0]:
            chunk_sharding_dim = None
            break

    if chunk_sharding_dim is not None:
        # Ensure we infer the correct placement order from offsets
        placements = [
            x
            for _, x in sorted(
                zip(chunk_offset_list, placements), key=operator.itemgetter(0)
            )
        ]

        from .chunk_sharding_spec import ChunkShardingSpec

        chunk_spec = ChunkShardingSpec(
            dim=chunk_sharding_dim,
            placements=placements,
        )

        shard_sizes = sorted([x[chunk_sharding_dim] for x in shard_size_list])
        shard_total_length = sum(shard_sizes)
        shard_offsets = sorted([x[chunk_sharding_dim] for x in shard_offset_list])

        chunks = len(placements)
        split_size = get_split_size(shard_total_length, chunks)
        chunk_shard_sizes = sorted(
            [
                get_chunked_dim_size(shard_total_length, split_size, idx)
                for idx in range(chunks)
            ]
        )
        # Should match ChunkShardingSpec offsets calculation
        chunk_shard_offsets = [split_size * idx for idx in range(chunks)]
        if shard_sizes == chunk_shard_sizes and shard_offsets == chunk_shard_offsets:
            return chunk_spec
    return EnumerableShardingSpec(shards_metadata)
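

# NOTE: Illustrative sketch, not part of the original module; the helper name is
# hypothetical. Two equal shards split along dim 0 at offsets 0 and 6 match the
# sizes and offsets a ChunkShardingSpec would produce for a length-12 dimension
# over two ranks, so the inference above returns a ChunkShardingSpec(dim=0, ...).
# Shards that do not line up with that chunking fall back to an
# EnumerableShardingSpec instead.
def _example_infer_chunk_spec():
    shards_metadata = [
        ShardMetadata(
            shard_offsets=[0, 0], shard_sizes=[6, 4], placement="rank:0/cuda:0"
        ),
        ShardMetadata(
            shard_offsets=[6, 0], shard_sizes=[6, 4], placement="rank:1/cuda:1"
        ),
    ]
    return _infer_sharding_spec_from_shards_metadata(shards_metadata)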