# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import functools
import itertools
import operator
from typing import cast, Iterable, List, Optional, Sequence, Tuple, Union

import torch
from torch.distributed.tensor._api import DTensor
from torch.distributed.tensor._collective_utils import redistribute_cost
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import (
    OpSchema,
    OpStrategy,
    PlacementList,
    PlacementStrategy,
    RuntimeSchemaInfo,
)
from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
    Partial,
    Placement,
    Replicate,
    Shard,
)


# convenient wrapper to register sharding propagation rules
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def register_prop_rule(op, schema_info=None):
    # pyre-fixme[53]: Captured variable `func` is not annotated.
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def wrapper(impl):
        overloads = op if isinstance(op, list) else [op]
        for overload in overloads:
            DTensor._op_dispatcher.sharding_propagator.register_sharding_prop_rule(
                overload, impl, schema_info
            )
        return impl

    return wrapper


def register_op_strategy(op, schema_info=None):
    # pyre-fixme[53]: Captured variable `func` is not annotated.
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.

    # For every ATen op that accepts any args in this list,
    # the arg itself can impact the strides (and potentially the sharding strategy)
    # of the output tensor.
    # Thus, we will detect ATen schemas with any of these args and ensure
    # that they get specialized here.
    arg_names_that_require_specializing_cache_strategy = [
        "memory_format",
    ]

    def wrapper(impl):
        if isinstance(op, list):
            overloads = op
        else:
            overloads = [op]

        for overload in overloads:
            curr_schema_info = None
            if schema_info is None:
                specialized_args = [
                    a.name
                    for a in overload._schema.arguments
                    if a.name in arg_names_that_require_specializing_cache_strategy
                ]
                if any(specialized_args):
                    curr_schema_info = RuntimeSchemaInfo(
                        static_kwargkey=specialized_args
                    )
            else:
                curr_schema_info = schema_info
            DTensor._op_dispatcher.sharding_propagator.register_op_strategy(
                overload, impl, curr_schema_info
            )
        return impl

    return wrapper
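# Illustrative usage sketch (not part of the upstream module): both decorators
# accept a single ATen overload or a list of overloads. The signature expected
# of the decorated `impl` is defined by the sharding propagator, not here; the
# example below assumes the (mesh, op_schema) -> OpStrategy convention used by
# the strategy functions in this package, and `aten` stands for torch.ops.aten.
#
#   @register_op_strategy([aten.relu.default, aten.relu_.default])
#   def relu_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
#       ...  # build an OpStrategy from the input strategies in op_schema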
def as_list(
    x: Union[List[object], object]
    # pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
) -> Union[List[object], torch.fx.immutable_collections.immutable_list]:  # type: ignore[valid-type]
    # During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,
    # which is an object but treated as a list by the tracer. Therefore, keep
    # `immutable_list` intact here as well.
    if type(x) is list or isinstance(x, torch.fx.immutable_collections.immutable_list):
        return x
    else:
        return [x]


def normalize_dim(dim: int, ndim: int) -> int:
    return dim if dim >= 0 else dim + ndim


def normalize_dims(dims: Union[int, Sequence[int]], ndim: int) -> Sequence[int]:
    """Normalize a dim or a sequence of dims, so that they are all positive."""
    if isinstance(dims, int):
        dims = (normalize_dim(dims, ndim),)
    elif isinstance(dims, list):
        dims = [normalize_dim(dim, ndim) for dim in dims]
    elif isinstance(dims, tuple):
        dims = tuple([normalize_dim(dim, ndim) for dim in dims])
    return dims


def prod(xs: Iterable[int]) -> int:
    return functools.reduce(operator.mul, xs, 1)


def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
    """Check if the shape is shardable according to the spec."""
    # number of shards in each tensor dimension
    shards_map = [1] * len(shape)
    for i, placement in enumerate(spec.placements):
        if placement.is_shard():
            shard_dim = cast(Shard, placement).dim
            shards_map[shard_dim] *= spec.mesh.size(i)

    for i, dim_size in enumerate(shape):
        # TODO: maybe we should determine is_shardable based on
        # whether it's evenly sharded or not
        if shards_map[i] > 1 and dim_size < shards_map[i]:
            return False

    return True


def is_tensor_evenly_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
    """Check if the shape is evenly shardable according to the spec."""
    # number of shards in each tensor dimension
    shards_map = [1] * len(shape)
    for i, placement in enumerate(spec.placements):
        if placement.is_shard():
            shard_dim = cast(Shard, placement).dim
            shards_map[shard_dim] *= spec.mesh.size(i)

    for i, dim_size in enumerate(shape):
        if shards_map[i] > 1 and (dim_size % shards_map[i] != 0):
            return False

    return True


def is_tensor_dim_sharded(spec: DTensorSpec, dim: int) -> bool:
    """Return True if tensor dim is sharded."""
    return any(p.is_shard(dim) for p in spec.placements)


def is_tensor_partial(spec: DTensorSpec) -> bool:
    """Return True if tensor is partial on the mesh."""
    return any(p.is_partial() for p in spec.placements)


def infer_broadcast_dims_map(
    common_shape: torch.Size, input_shape: torch.Size
) -> List[int]:
    # Infer the broadcast dims map, which maps each common shape dim to the
    # corresponding input shape dim, following broadcast semantics.
    common_ndim = len(common_shape)
    input_ndim = len(input_shape)
    broadcast_dims_map = [-1] * common_ndim
    for idx in range(-1, -1 - input_ndim, -1):
        if input_shape[idx] == common_shape[idx]:
            broadcast_dims_map[common_ndim + idx] = input_ndim + idx
    return broadcast_dims_map


def map_placements_after_broadcast(
    placements: Tuple[Placement, ...],
    shape: torch.Size,
    broadcast_dims_map: List[int],
) -> Tuple[Placement, ...]:
    """Map each placement based on the output shape after broadcast."""
    new_placements: List[Placement] = []
    for placement in placements:
        if isinstance(placement, (Replicate, Partial)):
            new_placements.append(placement)
        else:
            assert isinstance(placement, Shard)
            shard_dim = normalize_dim(placement.dim, len(shape))
            new_shard_dim = broadcast_dims_map[shard_dim]
            if new_shard_dim != -1:
                # there's a map from the common shape shard dim to
                # the input shape shard dim before broadcasting,
                # so use that instead
                new_placements.append(Shard(new_shard_dim))
            else:
                # there's no map between the common shape shard dim and
                # the input shape shard dim before broadcasting, which
                # means implicit broadcasting happens in this dim, so we
                # can just mark it as Replicate and the implicit broadcast
                # will broadcast automatically to the sharded shape
                new_placements.append(Replicate())

    return tuple(new_placements)
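# Worked example (illustrative, not from the upstream file): broadcasting an
# input of shape (4, 6) to a common shape of (8, 4, 6) aligns trailing dims,
# so common dims 1 and 2 map to input dims 0 and 1 while common dim 0 has no
# counterpart:
#
#   infer_broadcast_dims_map(torch.Size([8, 4, 6]), torch.Size([4, 6]))
#   # -> [-1, 0, 1]
#
# Under that map, a Shard(1) placement on the common shape maps to Shard(0) on
# the input, while Shard(0) lands on the implicitly broadcast dim and is
# replaced by Replicate().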
def generate_redistribute_costs(
    src_strategy: OpStrategy, dst_spec: DTensorSpec
) -> List[float]:
    """Compute the cost of redistributing each strategy in ``src_strategy`` to ``dst_spec``."""
    redistribute_costs: List[float] = []
    for strat in src_strategy.strategies:
        redistribute_costs.append(redistribute_cost(strat.output_spec, dst_spec))

    return redistribute_costs


def expand_to_full_mesh_op_strategy(
    mesh: DeviceMesh,
    op_schema: OpSchema,
    single_mesh_dim_strategies: List[PlacementList],
    *,
    input_index: int = 1,
    inplace_op: bool = False,
) -> OpStrategy:
    # Expand single_mesh_dim_strategies (placements for a single mesh dim) to
    # strategies over the full mesh by taking their cross product across all
    # mesh dims. In each PlacementList, positions [0, input_index) describe the
    # output(s) and positions from input_index onward describe the inputs.
    all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim

    strategy_combs = itertools.product(*all_mesh_dim_strategies)

    all_strategies = []
    for strategy_comb in strategy_combs:
        spec_list: List[Optional[DTensorSpec]] = []
        for specs in zip(*strategy_comb):
            if specs[0] is not None:
                spec_list.append(DTensorSpec(mesh, specs))
            else:
                spec_list.append(None)

        input_specs: List[DTensorSpec] = [
            s for s in spec_list[input_index:] if isinstance(s, DTensorSpec)
        ]

        input_args_strategy = op_schema.args_strategy
        assert len(input_specs) == len(input_args_strategy)
        self_spec = input_args_strategy[0].strategies[0].output_spec

        if inplace_op and self_spec.placements != input_specs[0].placements:
            # for an inplace op, only add the placement strategy when the
            # first input spec matches the first argument's runtime sharding;
            # otherwise skip this combination
            continue

        # check that all inputs are shardable
        inputs_shardable = all(
            is_tensor_shardable(inp.shape, s)
            for inp, s in zip(input_args_strategy, input_specs)
        )

        # only add to the all_strategies list when all inputs are shardable
        if inputs_shardable:
            redistribute_cost = [
                generate_redistribute_costs(input_strategy, input_spec)
                for input_strategy, input_spec in zip(input_args_strategy, input_specs)
            ]
            if input_index > 1:
                output_specs = tuple(spec_list[:input_index])
            else:
                if spec_list[0] is not None:
                    output_specs = spec_list[0]  # type: ignore[assignment]
                else:
                    raise RuntimeError("output spec is None")
            strategy = PlacementStrategy(
                output_specs=output_specs,
                input_specs=input_specs,
                redistribute_cost=redistribute_cost,
            )
            all_strategies.append(strategy)

    return OpStrategy(all_strategies)
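# Illustrative sketch (hypothetical op, nothing registered here) of the
# `single_mesh_dim_strategies` layout that expand_to_full_mesh_op_strategy
# consumes: each inner PlacementList covers one mesh dim, with output
# placement(s) in positions [0, input_index) and input placements from
# input_index onward. With the default input_index=1, a three-entry list for a
# binary elementwise-style op reads as "output, lhs, rhs":
#
#   single_mesh_dim_strategies: List[PlacementList] = [
#       [Replicate(), Replicate(), Replicate()],  # everything replicated
#       [Shard(0), Shard(0), Shard(0)],           # everything sharded on dim 0
#   ]
#   strategy = expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies)
#
# The helper takes the cross product of these per-mesh-dim choices over
# mesh.ndim dims, drops combinations whose inputs are not shardable, and
# attaches redistribute costs for each surviving combination.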