1# mypy: allow-untyped-decorators 2# mypy: allow-untyped-defs 3# Copyright (c) Meta Platforms, Inc. and affiliates 4from dataclasses import dataclass 5from typing import ( 6 Callable, 7 cast, 8 Dict, 9 Iterable, 10 List, 11 Optional, 12 Sequence, 13 Set, 14 Tuple, 15 Union, 16) 17 18import torch 19from torch import Tensor 20from torch.distributed.device_mesh import DeviceMesh 21from torch.distributed.tensor._dtensor_spec import DTensorSpec 22from torch.distributed.tensor._op_schema import ( 23 OpSchema, 24 OpStrategy, 25 PlacementStrategy, 26 RuntimeSchemaInfo, 27 StrategyType, 28) 29from torch.distributed.tensor._ops.utils import ( 30 generate_redistribute_costs, 31 normalize_dim, 32 normalize_dims, 33 prod, 34 register_op_strategy, 35) 36from torch.distributed.tensor.placement_types import Placement, Replicate, Shard 37 38 39aten = torch.ops.aten 40 41Shape = Tuple[int, ...] 42 43 44@dataclass 45class DimSpec: 46 """Specifies how an output dimension maps to an input dimension.""" 47 48 def inputs(self) -> Iterable["DimSpec"]: 49 return () 50 51 52# Rules that map each dimension of the output to dimensions of the input tensor 53DimMap = Tuple[DimSpec, ...] 54 55 56@dataclass 57class Singleton(DimSpec): 58 """Output dimension is a singleton.""" 59 60 61@dataclass 62class InputDim(DimSpec): 63 """Output dimension maps directly to an input dimension.""" 64 65 input_dim: int 66 67 68@dataclass 69class Broadcast(DimSpec): 70 """Output is the broadcast of a singleton input dimension.""" 71 72 dim: DimSpec 73 dim_size: int 74 75 @classmethod 76 def new(cls, dim: DimSpec, dim_size: int) -> DimSpec: 77 return Broadcast(dim, dim_size) 78 79 def inputs(self) -> Iterable[DimSpec]: 80 return (self.dim,) 81 82 83@dataclass 84class NewDim(DimSpec): 85 """This is a new dimension created by the op.""" 86 87 size: int 88 89 @classmethod 90 def new(cls, size: int) -> DimSpec: 91 return Singleton() if size == 1 else NewDim(size) 92 93 94@dataclass 95class Repeat(DimSpec): 96 """Output dimension is the input dimension repeated n-times.""" 97 98 input_dim: DimSpec 99 times: int 100 101 @classmethod 102 def new(cls, dim: DimSpec, times: int) -> DimSpec: 103 if times == 1: 104 return dim 105 elif isinstance(dim, Singleton): 106 # repeating a singleton is the same as broadcasting it 107 return Broadcast(dim, times) 108 else: 109 return Repeat(dim, times) 110 111 def inputs(self) -> Iterable[DimSpec]: 112 return (self.input_dim,) 113 114 115@dataclass 116class Flatten(DimSpec): 117 """Flatten a set of input dimensions, ensuring right-most adjacent elements remain adjacent in the output.""" 118 119 input_dims: Sequence[DimSpec] 120 121 @classmethod 122 def new(cls, dims: Sequence[DimSpec]) -> DimSpec: 123 if len(dims) == 0: 124 # flattening a scalar leads to a singleton 125 return Singleton() 126 elif len(dims) == 1: 127 # flattening a single dimension is no-op 128 return dims[0] 129 else: 130 return Flatten(dims) 131 132 def inputs(self) -> Iterable[DimSpec]: 133 return self.input_dims 134 135 136@dataclass 137class Split(DimSpec): 138 """ 139 This dimension is a member of a decomposition of the input dim. 140 141 Note that input_dim itself could be a Flattened set of input dims. 142 """ 143 144 input_dim: DimSpec 145 group_shape: Shape 146 split_id: int 147 148 @classmethod 149 def new(cls, dim: DimSpec, group_shape: Tuple[int, ...], idx: int) -> DimSpec: 150 assert len(group_shape) > 0 151 if len(group_shape) == 1: 152 # not really a group, just return the input dim back 153 assert idx == 0 154 return dim 155 elif group_shape[idx] == 1: 156 return Singleton() 157 else: 158 # remove singletons from group 159 # group_mapping = [(new_index, (shape, old_index)) ...] 160 group_mapping = list( 161 enumerate((s, i) for i, s in enumerate(group_shape) if s != 1) 162 ) 163 new_group_shape = tuple(m[1][0] for m in group_mapping) 164 new_idx = next(filter(lambda x: x[1][1] == idx, group_mapping))[0] 165 return Split(dim, new_group_shape, new_idx) 166 167 def inputs(self) -> Iterable[DimSpec]: 168 return (self.input_dim,) 169 170 171def dim_pad_left(ndim: int, min_dims: int) -> DimMap: 172 return (Singleton(),) * max(0, min_dims - ndim) + tuple( 173 InputDim(i) for i in range(ndim) 174 ) 175 176 177def dim_atleast_3d(ndim: int) -> DimMap: 178 if ndim == 0: 179 return (Singleton(), Singleton(), Singleton()) 180 elif ndim == 1: 181 return (Singleton(), InputDim(0), Singleton()) 182 elif ndim == 2: 183 return (InputDim(0), InputDim(1), Singleton()) 184 else: 185 return tuple(InputDim(i) for i in range(ndim)) 186 187 188def expand(input_shape: Shape, shape: Shape) -> DimMap: 189 """Implement broadcast on multiple dimensions.""" 190 assert len(shape) >= len(input_shape) 191 192 # 1. create padded input dimensions 193 padded_input = dim_pad_left(len(input_shape), len(shape)) 194 # 2. check that input shapes are compatible 195 mapping = [] 196 for p, desired_s in zip(padded_input, shape): 197 if isinstance(p, Singleton): 198 actual_s = 1 199 assert desired_s >= 0 200 else: 201 assert isinstance(p, InputDim), f"DimSpec not supported in expand: {p}" 202 actual_s = input_shape[p.input_dim] 203 assert actual_s == 1 or desired_s == -1 or desired_s == actual_s 204 mapping.append( 205 p 206 if desired_s in (1, -1) or desired_s == actual_s 207 else Broadcast.new(p, desired_s) 208 ) 209 return tuple(mapping) 210 211 212def normalize_sizes(sizes: Union[Shape, Tuple[Shape]]) -> Shape: 213 if isinstance(sizes[0], int): 214 return cast(Shape, sizes) 215 elif len(sizes) == 1: 216 return sizes[0] 217 else: 218 raise RuntimeError("Size must be int... or tuple") 219 220 221def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap: 222 if ndim == 0: 223 return (Singleton(),) 224 elif ndim == 1: 225 return (InputDim(0),) 226 else: 227 # only flattening dims from start_dim to end_dim (inclusive) 228 # other dims are passed through 229 if end_dim < 0: 230 end_dim += ndim 231 results: List[DimSpec] = [InputDim(i) for i in range(start_dim)] 232 results.append( 233 Flatten.new(tuple(InputDim(i) for i in range(start_dim, end_dim + 1))) 234 ) 235 results.extend([InputDim(i) for i in range(end_dim + 1, ndim)]) 236 return tuple(results) 237 238 239def dim_movedim( 240 ndim: int, 241 input: Union[int, Sequence[int]], 242 destination: Union[int, Sequence[int]], 243) -> DimMap: 244 input = normalize_dims(input, ndim) 245 destination = normalize_dims(destination, ndim) 246 247 assert len(input) == len(destination) 248 input_set = set(input) 249 assert len(input_set) == len(input), "Found repeated input dims" 250 assert len(set(destination)) == len(destination), "Found repeated output dims" 251 assert max(input) < ndim 252 assert max(destination) < ndim 253 254 dest = [-1] * ndim 255 for i, d in zip(input, destination): 256 dest[d] = i 257 258 unused_inputs_iter = iter(i for i in range(ndim) if i not in input_set) 259 for i in range(ndim): 260 if dest[i] == -1: 261 dest[i] = next(unused_inputs_iter) 262 263 return tuple(InputDim(i) for i in dest) 264 265 266def dim_repeat(ndim: int, sizes: Shape) -> DimMap: 267 sizes = normalize_sizes(sizes) 268 assert ( 269 len(sizes) >= ndim 270 ), f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}." 271 pad = len(sizes) - ndim 272 return tuple(Repeat.new(Singleton(), s) for s in sizes[:pad]) + tuple( 273 Repeat.new(InputDim(i), s) for i, s in enumerate(sizes[pad:]) 274 ) 275 276 277def infer_size(total_size: int, sizes: Shape) -> Shape: 278 """ 279 One dimension input to view may be "-1". 280 281 Infer the size of this dimension given the total_size. 282 """ 283 infers = [i for i, s in enumerate(sizes) if s == -1] 284 size = prod(sizes) 285 assert len(infers) <= 1, "can only infer one size" 286 if infers: 287 size = -size 288 missing_size = total_size // size 289 assert ( 290 total_size % size == 0 291 ), f"size inferred for -1 is not integral {sizes} should have {total_size} elements." 292 return tuple(s if s != -1 else missing_size for s in sizes) 293 assert size == total_size, f"sizes do not match {total_size} vs {size}" 294 return sizes 295 296 297def view_groups(from_size: Shape, to_size: Shape) -> DimMap: 298 """ 299 Decompose a reshape operation into forwarding, flattening, or splitting dimensions for each output dimension. 300 301 A view or reshape operation can be decomposed into a set of 3 types of smaller operations: 302 1) Forward a dimension from input to output 303 2) Flatten a set of dimensions into a single dimension 304 3) Split one dimension into multiple dimensions 305 306 view_groups identifies these operations and returns, for each output dimension, what 307 is operation was performed in the input dimension. For example: 308 309 view_groups([2, 3, 4], [2, 12]) -> ( 310 InputDim(0), 311 Flatten((InputDim(1), InputDim(2))) 312 ) 313 314 - ouptut dimension 0 maps to input dimension 0 315 - output dimension 1 maps to a flattened input dimensions 1 and 2 316 317 318 view_groups([2, 3], [3, 2]) -> ( 319 Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0), 320 Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1), 321 ) 322 323 - in the above, input is flattened into a single dimension and then split 324 into two separate dimensions with different sizes from the input. 325 """ 326 from_nelem = prod(from_size) 327 to_size = infer_size(from_nelem, normalize_sizes(to_size)) 328 329 assert from_nelem == prod(to_size), "Total view shape does not add up" 330 331 from_idx = 0 332 to_idx = 0 333 from_len = len(from_size) 334 to_len = len(to_size) 335 336 result_pp = [] 337 338 while from_idx < from_len or to_idx < to_len: 339 from_group_dim, to_group_shape = [], [] 340 341 if from_idx >= from_len: 342 f = 1 343 else: 344 f = from_size[from_idx] 345 from_group_dim.append(from_idx) 346 from_idx += 1 347 348 if to_idx >= to_len: 349 t = 1 350 else: 351 t = to_size[to_idx] 352 to_group_shape.append(t) 353 to_idx += 1 354 355 # if any of the groups is singleton, great, we need to backtrack though 356 if f == 1 and t != 1: 357 # produces ([1], []) 358 to_idx -= 1 359 to_group_shape = [] 360 elif f != 1 and t == 1: 361 # produces ([], [1]) 362 from_idx -= 1 363 from_group_dim = [] 364 else: 365 # produces ([1], [1]), ([2], [2]), ([2,3], [6]) 366 while f != t: 367 if f < t: 368 nf = from_size[from_idx] 369 from_group_dim.append(from_idx) 370 from_idx += 1 371 f *= nf 372 else: 373 nt = to_size[to_idx] 374 to_group_shape.append(nt) 375 to_idx += 1 376 t *= nt 377 378 if len(to_group_shape) > 0: 379 flattened = Flatten.new( 380 tuple(InputDim(fi) for fi in from_group_dim if from_size[fi] >= 1) 381 ) 382 result_pp += [ 383 Split.new(flattened, tuple(to_group_shape), i) 384 for i in range(len(to_group_shape)) 385 ] 386 387 return tuple(result_pp) 388 389 390def dim_tile(ndim: int, dims: Tuple[int, ...]) -> DimMap: 391 if len(dims) < ndim: 392 dims = (1,) * (ndim - len(dims)) + dims 393 return dim_repeat(ndim, dims) 394 395 396def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap: 397 dim1 = normalize_dim(dim1, ndim) 398 dim2 = normalize_dim(dim2, ndim) 399 assert dim1 < ndim 400 assert dim2 < ndim 401 dimmap = [InputDim(i) for i in range(ndim)] 402 swapdim = dimmap[dim1] 403 dimmap[dim1] = dimmap[dim2] 404 dimmap[dim2] = swapdim 405 return tuple(dimmap) 406 407 408def dim_squeeze(shape: Shape, dim: Optional[int] = None) -> DimMap: 409 # FIXME: this is wrong when dim=None and one of the dimensions 410 # equals size of the mesh. For example squeeze(DTensor(tensor(4), Shard[0])) could 411 # end up as squeeze(tensor(1)) if we have 4 devices; this would lead to 412 # removal of a dimension that is not actually a singleton. 413 return tuple( 414 InputDim(i) 415 for i, s in enumerate(shape) 416 if s > 1 or (dim is not None and i != normalize_dim(dim, len(shape))) 417 ) 418 419 420def dim_unsqueeze(ndim: int, dim: int) -> DimMap: 421 dims = tuple(InputDim(i) for i in range(ndim)) 422 if dim < 0: 423 dim += ndim + 1 424 return dims[:dim] + (Singleton(),) + dims[dim:] 425 426 427def dim_view_as_real(shape: Shape) -> DimMap: 428 ndim = len(shape) 429 results: List[DimSpec] = [InputDim(i) for i in range(ndim - 1)] 430 # each complex number is split into two real numbers, 431 # resulting in one more dimension of size 2 432 results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 0)) 433 results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 1)) 434 return tuple(results) 435 436 437def dim_reduction( 438 ndim: int, dim_or_dims: Optional[Union[int, Sequence[int]]], keepdim: bool 439) -> DimMap: 440 """ 441 General fallback for reduction ops where Partial() does not apply. 442 443 This will cause incoming tensor to be replicated on the reducing dimensions. 444 """ 445 if dim_or_dims is None: 446 dim_or_dims = tuple(range(ndim)) 447 if isinstance(dim_or_dims, int): 448 dim_or_dims = (dim_or_dims,) 449 dim_or_dims = tuple(d if d >= 0 else d + ndim for d in dim_or_dims) 450 return tuple( 451 InputDim(i) if i not in dim_or_dims else Singleton() 452 for i in range(ndim) 453 if i not in dim_or_dims or keepdim 454 ) 455 456 457dim_maps: Dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = { 458 torch.atleast_1d: lambda x: dim_pad_left(x.ndim, 1), 459 torch.atleast_2d: lambda x: dim_pad_left(x.ndim, 2), 460 torch.atleast_3d: lambda x: dim_atleast_3d(x.ndim), 461 torch.broadcast_to: lambda input, shape: expand(input.shape, shape), 462 Tensor.expand: lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)), 463 torch.flatten: lambda tensor: dim_flatten(tensor.ndim), 464 torch.movedim: lambda input, source, destination: dim_movedim( 465 input.ndim, source, destination 466 ), 467 torch.permute: lambda input, dims: tuple( 468 InputDim(i) for i in normalize_dims(dims, input.ndim) 469 ), 470 torch.ravel: lambda tensor: dim_flatten(tensor.ndim), 471 Tensor.repeat: lambda self, *sizes: dim_repeat(self.ndim, sizes), 472 torch.reshape: lambda input, shape: view_groups(input.shape, shape), 473 torch.squeeze: lambda input, dim=None: dim_squeeze(input.shape, dim), 474 torch.tile: lambda input, dims: dim_tile(input.ndim, dims), 475 torch.transpose: lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1), 476 torch.unsqueeze: lambda input, dim: dim_unsqueeze(input.ndim, dim), 477 Tensor.view: lambda input, *shape: view_groups(input.shape, shape), 478 torch.view_as_complex: lambda input: dim_flatten(input.ndim, input.ndim - 2), 479 torch.view_as_real: lambda input: dim_view_as_real(input.shape), 480} 481 482 483def propagate_shape_and_sharding( 484 input_src_placements: Sequence[Placement], 485 local_in_shape: Shape, 486 rule: DimMap, 487 mesh_sizes: Shape, 488) -> Tuple[Sequence[Placement], Sequence[Placement]]: 489 """ 490 Determine input target sharding and output sharding based on 491 given global tensor shape and input source sharding. 492 493 Sharding propagation follows mapped dimensions: 494 - An output dimension that maps directly to an input dimension is sharded equally 495 - An output dimension that is a flattened set of input dimensions can only be 496 sharded if only the leftmost flattened dimension is sharded. 497 - An output dimension that is a split of the input dimension can only be sharded 498 if the leftmost split size is divisible by the mesh dimension 499 """ 500 assert len(input_src_placements) == len(mesh_sizes) 501 # for each input dim, for each mesh dim, provides a list of possible shardable dimensions 502 mesh_ndim = len(mesh_sizes) 503 shardable_dims: Dict[int, List[bool]] = {} 504 505 # in case an input dimension disappears (e.g. collapsing, reduction) 506 # we cannot shard in that dimension (we need a replication fall-back rule) 507 seen_input_dims: Set[int] = set() 508 509 def collect_used_inputs(cmd: DimSpec) -> None: 510 if isinstance(cmd, InputDim): 511 seen_input_dims.add(cmd.input_dim) 512 for inp in cmd.inputs(): 513 collect_used_inputs(inp) 514 515 for cmd in rule: 516 collect_used_inputs(cmd) 517 for dim in range(len(local_in_shape)): 518 shardable_dims[dim] = [dim in seen_input_dims] * mesh_ndim 519 520 def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: 521 if isinstance(cmd, InputDim): 522 return cmd 523 elif isinstance(cmd, Flatten): 524 for dim in cmd.input_dims[1:]: 525 if isinstance(dim, InputDim): 526 shardable_dims[dim.input_dim] = [False] * mesh_ndim 527 dim0 = cmd.input_dims[0] 528 return dim0 if isinstance(dim0, InputDim) else None 529 elif isinstance(cmd, Split): 530 in_dim = get_in_dim_to_shard(cmd.input_dim) 531 out_size = cmd.group_shape[cmd.split_id] 532 if cmd.split_id == 0 and in_dim is not None: 533 # we need to check that the input dimension is divisible 534 # by the size of the submesh we're sharding it on 535 # NOTE: it would be possible to shard the same input dimension 536 # on more than one mesh dimension. In that case, the dimension 537 # needs to be divisible by the product of mesh sizes. 538 # In order to keep the problem more tractable, we will not consider 539 # double resharding as a suggestion (e.g. [Shard(0), Shard(0) ]) 540 # but we will allow it if that's the input and it's compatible 541 542 # 1. is this dimension shardable on each individual mesh dim? 543 shardable_dims[in_dim.input_dim] = [ 544 out_size % mesh_dim_size == 0 for mesh_dim_size in mesh_sizes 545 ] 546 547 # 2. here we special case things like [Shard(0), Shard(0)] 548 submesh_size = 1 549 for size, shard in zip(mesh_sizes, input_src_placements): 550 if isinstance(shard, Shard) and shard.dim == in_dim: 551 submesh_size *= size 552 assert ( 553 out_size % submesh_size == 0 554 ), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}." 555 556 # we will only shard our first component of the split 557 return in_dim if cmd.split_id == 0 else None 558 elif isinstance(cmd, Repeat): 559 in_dim = get_in_dim_to_shard(cmd.input_dim) 560 if in_dim is not None: 561 shardable_dims[in_dim.input_dim] = [False] * mesh_ndim 562 return None 563 else: 564 return None 565 566 # for each output dim, find the corresponding input dim in terms of sharding prop 567 shard_dim_map = {} 568 for dim, cmd in enumerate(rule): 569 in_dim = get_in_dim_to_shard(cmd) 570 if in_dim is not None: 571 shard_dim_map[in_dim.input_dim] = dim 572 573 input_tgt_placements = [ 574 Replicate() 575 if isinstance(p, Shard) and not shardable_dims[p.dim][mesh_dim] 576 else p 577 for mesh_dim, p in enumerate(input_src_placements) 578 ] 579 output_placements = [ 580 Shard(shard_dim_map[p.dim]) if isinstance(p, Shard) else p 581 for p in input_tgt_placements 582 ] 583 584 return input_tgt_placements, output_placements 585 586 587def register_op_strategy_map( 588 aten_op_overload: torch._ops.OpOverload, 589 local_op_name: Callable[..., torch.Tensor], 590 schema_info: Optional[RuntimeSchemaInfo] = None, 591) -> None: 592 dim_map: Callable[..., DimMap] = dim_maps[local_op_name] 593 594 @register_op_strategy(aten_op_overload, schema_info=schema_info) 595 def reshape_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: 596 rules = dim_map(*op_schema.args_schema, **op_schema.kwargs_schema) 597 input_strategy = cast(OpStrategy, op_schema.args_schema[0]) 598 global_in_shape = input_strategy.shape 599 assert global_in_shape is not None, "Shape required." 600 601 output_strategy = OpStrategy([]) 602 for input_placement_strategy in input_strategy.strategies: 603 input_src_spec = input_placement_strategy.output_spec 604 605 input_tgt_placements, output_placements = propagate_shape_and_sharding( 606 input_src_spec.placements, 607 tuple(global_in_shape), 608 rules, 609 mesh.shape, 610 ) 611 612 # TODO: optimize this. we shouldn't simply blindly replicate 613 # unshardable dims ... 614 # FIXME: this can be wrong for situations where we have 615 # [Shard(0), Shard(0)] 616 input_tgt_spec = DTensorSpec( 617 placements=tuple(input_tgt_placements), 618 mesh=input_src_spec.mesh, 619 tensor_meta=input_src_spec.tensor_meta, 620 ) 621 redistribute_costs = [ 622 generate_redistribute_costs(input_strategy, input_tgt_spec) 623 ] 624 625 output_spec = DTensorSpec(mesh=mesh, placements=tuple(output_placements)) 626 output_strategy.strategies.append( 627 PlacementStrategy( 628 output_specs=output_spec, 629 input_specs=(input_tgt_spec,), 630 redistribute_cost=redistribute_costs, 631 ) 632 ) 633 634 return output_strategy 635 636 637register_op_strategy_map(aten.squeeze.default, torch.squeeze) 638register_op_strategy_map( 639 aten.squeeze.dim, torch.squeeze, schema_info=RuntimeSchemaInfo(1) 640) 641register_op_strategy_map( 642 aten.view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1) 643) 644register_op_strategy_map( 645 aten.reshape.default, torch.reshape, schema_info=RuntimeSchemaInfo(1) 646) 647register_op_strategy_map( 648 aten._unsafe_view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1) 649) 650register_op_strategy_map( 651 aten.unsqueeze.default, torch.unsqueeze, schema_info=RuntimeSchemaInfo(1) 652) 653register_op_strategy_map( 654 aten.expand.default, Tensor.expand, schema_info=RuntimeSchemaInfo(1) 655) 656register_op_strategy_map( 657 aten.permute.default, torch.permute, schema_info=RuntimeSchemaInfo(1) 658) 659register_op_strategy_map( 660 aten.repeat.default, Tensor.repeat, schema_info=RuntimeSchemaInfo(1) 661) 662register_op_strategy_map( 663 aten.transpose.int, torch.transpose, schema_info=RuntimeSchemaInfo(1) 664) 665register_op_strategy_map(aten.view_as_complex.default, torch.view_as_complex) 666register_op_strategy_map(aten.view_as_real.default, torch.view_as_real) 667