1# mypy: allow-untyped-defs 2# Copyright (c) Meta Platforms, Inc. and affiliates 3 4from dataclasses import dataclass 5from typing import cast, List, Optional, Tuple 6 7import torch 8import torch.distributed._functional_collectives as funcol 9from torch.distributed.device_mesh import DeviceMesh 10from torch.distributed.tensor._collective_utils import ( 11 fill_empty_tensor_to_shards, 12 mesh_broadcast, 13 mesh_scatter, 14 pad_tensor, 15 shard_dim_alltoall, 16 unpad_tensor, 17) 18 19 20__all__ = ["Placement", "Shard", "Replicate", "Partial"] 21 22 23class Placement: 24 """ 25 The base class for the Placement type, where it describes how a DTensor is placed onto the 26 ``DeviceMesh``. ``Placement`` and ``DeviceMesh`` together could describe the DTensor Layout. 27 It is the base class of the three main DTensor Placement types: ``Shard``, ``Replicate``, 28 and ``Partial``. 29 30 This class is not meant to be used directly, mainly served as a typing stub. 31 """ 32 33 # convenient utils to check for placement types 34 def is_shard(self, dim: Optional[int] = None) -> bool: 35 is_shard_instance = isinstance(self, Shard) 36 if dim is not None and is_shard_instance: 37 return cast(Shard, self).dim == dim 38 else: 39 return is_shard_instance 40 41 def is_replicate(self) -> bool: 42 return isinstance(self, Replicate) 43 44 def is_partial(self) -> bool: 45 return isinstance(self, Partial) 46 47 48@dataclass(frozen=True) 49class Shard(Placement): 50 """ 51 The ``Shard(dim)`` placement describes the DTensor sharding on tensor dimension 52 ``dim`` over a corresponding ``DeviceMesh`` dimension, where each rank on the 53 DeviceMesh dimension only holds a shard/piece of the global Tensor. The 54 ``Shard(dim)`` placement follows the ``torch.chunk(dim)`` semantic, where the 55 last few shards on the DeviceMesh dimension might be empty when the tensor dimension 56 is not evenly divisble on the DeviceMesh dimension. The ``Shard`` placement can be 57 used by all DTensor APIs (i.e. distribute_tensor, from_local, etc.) 58 59 Args: 60 dim (int): The tensor dimension that describes the DTensor is sharded over its 61 corresponding DeviceMesh dimension. 62 63 .. warning:: sharding on a tensor dimension where the tensor dimension size is not 64 evenly divisible on a DeviceMesh dimension is currently experimental and subject to change. 65 """ 66 67 dim: int 68 69 def _split_tensor( 70 self, 71 tensor: torch.Tensor, 72 num_chunks: int, 73 *, 74 with_padding: bool = True, 75 contiguous: bool = True, 76 ) -> Tuple[List[torch.Tensor], List[int]]: 77 """ 78 This function uses torch.chunk to split a tensor into num_chunks shards along 79 the Shard placement dimension, and return a list of shards with their pad sizes. 80 81 Keyword args: 82 with_padding (bool, optional): when True, we pad the tensor on the last 83 few ranks before calling the collectives (i.e. scatter/all_gather, etc.). 84 This is because collectives usually require equal size tensor inputs 85 """ 86 assert ( 87 self.dim <= tensor.ndim 88 ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" 89 90 # chunk tensor over dimension `dim` into n slices 91 tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim)) 92 num_empty_tensors = num_chunks - len(tensor_list) 93 94 # if no need to have padding or tensor dim size is evenly sharded already 95 # we can return early. 96 if not with_padding or tensor.size(self.dim) % num_chunks == 0: 97 if contiguous: 98 tensor_list = [t.contiguous() for t in tensor_list] 99 return ( 100 fill_empty_tensor_to_shards(tensor_list, self.dim, num_empty_tensors), 101 [], 102 ) 103 104 # compute the chunk size inline with ``torch.chunk`` to calculate padding 105 full_chunk_size = (tensor.size(self.dim) + num_chunks - 1) // num_chunks 106 107 # Compute chunk size for each chunk for ``self.dim`` 108 chunk_sizes = [ 109 tensor_list[idx].size(self.dim) if idx < len(tensor_list) else 0 110 for idx in range(num_chunks) 111 ] 112 # Compute pad size on each chunk 113 pad_sizes = [full_chunk_size - chunk_size for chunk_size in chunk_sizes] 114 115 # Reuse tensor to fill empty chunk with empty tensor 116 tensor_list = fill_empty_tensor_to_shards( 117 tensor_list, self.dim, num_empty_tensors 118 ) 119 shard_list = [] 120 for shard, pad_size in zip(tensor_list, pad_sizes): 121 # Fill the empty tensor with zeroes with padding. 122 if with_padding and pad_size > 0: 123 shard = pad_tensor(shard, self.dim, pad_size) 124 shard = shard.contiguous() if contiguous else shard 125 shard_list.append(shard) 126 return shard_list, pad_sizes 127 128 @staticmethod 129 def _local_shard_size_on_dim( 130 size_on_dim: int, 131 num_chunks: int, 132 rank: int, 133 return_offset: bool = False, 134 ) -> Tuple[int, int]: 135 """ 136 returns the local shard size and offset on a given tensor dim 137 """ 138 # Compute the chunk size inline with ``torch.chunk`` 139 if size_on_dim % num_chunks == 0: 140 full_chunk_size = size_on_dim // num_chunks 141 return full_chunk_size, full_chunk_size * rank if return_offset else -1 142 143 # uneven sharding case 144 full_chunk_size = (size_on_dim + num_chunks - 1) // num_chunks 145 shard_starting_idx = full_chunk_size * rank 146 147 if size_on_dim < shard_starting_idx: 148 return 0, size_on_dim if return_offset else -1 149 else: 150 local_shard_size = ( 151 min(size_on_dim, shard_starting_idx + full_chunk_size) 152 - shard_starting_idx 153 ) 154 return local_shard_size, shard_starting_idx if return_offset else -1 155 156 def _shard_tensor( 157 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int 158 ) -> torch.Tensor: 159 """ 160 shard and scatter a tensor on a mesh dimension (use coordinate 161 0 on the mesh dimension as source of truth) 162 """ 163 my_coordinate = mesh.get_coordinate() 164 num_chunks = mesh.size(mesh_dim=mesh_dim) 165 166 if my_coordinate is None: 167 # if rank is not part of mesh, we simply return an empty tensor 168 return tensor.new_empty(0, requires_grad=tensor.requires_grad) 169 170 scatter_list, pad_sizes = self._split_tensor( 171 tensor, num_chunks, with_padding=True, contiguous=True 172 ) 173 174 mesh_dim_local_rank = my_coordinate[mesh_dim] 175 output = torch.empty_like(scatter_list[mesh_dim_local_rank]) 176 mesh_scatter(output, scatter_list, mesh, mesh_dim=mesh_dim) 177 178 # Only unpad if the local_tensor was padded on the dimension. 179 if pad_sizes and pad_sizes[mesh_dim_local_rank] > 0: 180 output = unpad_tensor(output, self.dim, pad_sizes[mesh_dim_local_rank]) 181 return output 182 183 def _reduce_shard_tensor( 184 self, 185 tensor: torch.Tensor, 186 mesh: DeviceMesh, 187 reduce_op: str, 188 mesh_dim: int, 189 ) -> torch.Tensor: 190 """ 191 reduce and scatter a tensor on a mesh dimension 192 """ 193 my_coordinate = mesh.get_coordinate() 194 num_chunks = mesh.size(mesh_dim=mesh_dim) 195 196 if my_coordinate is None: 197 # if rank is not part of mesh, we simply return local_tensor, 198 # which should be an empty tensor 199 return tensor 200 201 is_padded = tensor.size(self.dim) % num_chunks != 0 202 if is_padded: 203 scattered_list, pad_sizes = self._split_tensor( 204 tensor, num_chunks, with_padding=True, contiguous=True 205 ) 206 tensor = torch.cat(scattered_list, dim=self.dim) 207 elif not tensor.is_contiguous(): 208 tensor = tensor.contiguous() 209 210 output = funcol.reduce_scatter_tensor( 211 tensor, reduce_op, scatter_dim=self.dim, group=(mesh, mesh_dim) 212 ) 213 214 if is_padded: 215 output = unpad_tensor(output, self.dim, pad_sizes[my_coordinate[mesh_dim]]) # type: ignore[possibly-undefined] 216 return output 217 218 def _to_replicate_tensor( 219 self, 220 local_tensor: torch.Tensor, 221 mesh: DeviceMesh, 222 mesh_dim: int, 223 current_logical_shape: List[int], 224 ) -> torch.Tensor: 225 """ 226 This function all_gather all shards and return a tensor that 227 is replicated on the previously sharded mesh dimension 228 """ 229 num_chunks = mesh.size(mesh_dim=mesh_dim) 230 # check if it's uneven, so we need to pad input tensor before all_gather 231 local_shape = list(local_tensor.size()) 232 233 logical_dim_size = current_logical_shape[self.dim] 234 is_padded = logical_dim_size % num_chunks != 0 235 236 if is_padded: 237 full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks 238 pad_size = full_chunk_size - local_shape[self.dim] 239 local_tensor = pad_tensor(local_tensor, self.dim, pad_size) 240 241 if not local_tensor.is_contiguous(): 242 local_tensor = local_tensor.contiguous() 243 244 result = funcol.all_gather_tensor( 245 local_tensor, 246 gather_dim=self.dim, 247 group=(mesh, mesh_dim), 248 ) 249 if is_padded: 250 unpad_size = full_chunk_size * num_chunks - logical_dim_size # type: ignore[possibly-undefined] 251 result = unpad_tensor(result, self.dim, unpad_size) 252 return result 253 254 def _replicate_to_shard( 255 self, 256 local_tensor: torch.Tensor, 257 mesh: DeviceMesh, 258 mesh_dim: int, 259 shard_index: int, 260 ) -> torch.Tensor: 261 """ 262 transform from replicated tensor to a sharded tensor on 263 the current rank, which would perform a local chunk 264 """ 265 num_chunks = mesh.size(mesh_dim=mesh_dim) 266 shards, _ = self._split_tensor( 267 local_tensor, 268 num_chunks, 269 with_padding=False, 270 contiguous=False, 271 ) 272 return shards[shard_index].clone() 273 274 def _to_new_shard_dim( 275 self, 276 local_tensor: torch.Tensor, 277 mesh: DeviceMesh, 278 mesh_dim: int, 279 current_logical_shape: List[int], 280 new_shard_dim: int, 281 ) -> torch.Tensor: 282 """ 283 transform from existing sharded tensor to a new sharded tensor on 284 that shard on a new dimension, which performs an alltoall 285 """ 286 my_coordinate = mesh.get_coordinate() 287 if my_coordinate is None: 288 # if rank is not part of mesh, we simply return local_tensor, 289 # which should be an empty tensor 290 return local_tensor 291 292 num_chunks = mesh.size(mesh_dim=mesh_dim) 293 294 old_dim_logical_size = current_logical_shape[self.dim] 295 new_dim_logical_size = current_logical_shape[new_shard_dim] 296 old_dim_padding = old_dim_logical_size % num_chunks != 0 297 new_dim_padding = new_dim_logical_size % num_chunks != 0 298 if old_dim_padding: 299 old_dim_full_chunk_size = ( 300 old_dim_logical_size + num_chunks - 1 301 ) // num_chunks 302 old_dim_pad_size = old_dim_full_chunk_size - local_tensor.size(self.dim) 303 local_tensor = pad_tensor(local_tensor, self.dim, old_dim_pad_size) 304 if new_dim_padding: 305 new_dim_full_chunk_size = ( 306 new_dim_logical_size + num_chunks - 1 307 ) // num_chunks 308 new_dim_pad_size = new_dim_full_chunk_size * num_chunks - local_tensor.size( 309 new_shard_dim 310 ) 311 local_tensor = pad_tensor(local_tensor, new_shard_dim, new_dim_pad_size) 312 313 if not local_tensor.is_contiguous(): 314 local_tensor = local_tensor.contiguous() 315 316 new_tensor = shard_dim_alltoall( 317 local_tensor, self.dim, new_shard_dim, mesh, mesh_dim 318 ) 319 320 if old_dim_padding: 321 old_dim_unpad_size = ( 322 old_dim_full_chunk_size * num_chunks - current_logical_shape[self.dim] # type: ignore[possibly-undefined] 323 ) 324 new_tensor = unpad_tensor(new_tensor, self.dim, old_dim_unpad_size) # type: ignore[possibly-undefined] 325 326 if new_dim_padding: 327 local_shard_size_on_new_dim = self._local_shard_size_on_dim( 328 new_dim_logical_size, num_chunks, my_coordinate[mesh_dim] 329 )[0] 330 new_dim_unpad_size = new_dim_full_chunk_size - local_shard_size_on_new_dim # type: ignore[possibly-undefined] 331 new_tensor = unpad_tensor(new_tensor, new_shard_dim, new_dim_unpad_size) # type: ignore[possibly-undefined] 332 333 return new_tensor 334 335 def __eq__(self, other: object) -> bool: 336 if not isinstance(other, Shard): 337 return False 338 return self.dim == other.dim 339 340 def __hash__(self) -> int: 341 return hash(self.dim) 342 343 def __repr__(self) -> str: 344 """ 345 machine readable representation of the Shard placement 346 """ 347 return f"Shard(dim={self.dim})" 348 349 def __str__(self) -> str: 350 """human readable representation of the Shard placement""" 351 return f"S({self.dim})" 352 353 354# kw_only is only available in python >= 3.10 355kw_only_dataclass = dict(kw_only=True) if "kw_only" in dataclass.__kwdefaults__ else {} 356 357 358@dataclass(frozen=True, **kw_only_dataclass) 359class _StridedShard(Shard): 360 """ 361 _StridedShard is only introduced to support 2D FSDP2 + TP sharding where the tensor 362 is sharded on the TP mesh dimension first, then sharded on the FSDP mesh dimension. 363 We call this right-to-left sharding which is the opposite of the default 364 left-to-right sharding. See the example below: 365 tensor shape: [8, 8] 366 mesh: [[0, 1], [2, 3]], names=("dp", "tp") 367 placements: [Shard(0), Shard(0)] 368 369 The default sharding behavior shards the tensor on "dp" mesh dimension first then 370 "tp" dimension. The sharding result will be: 371 Rank | Mesh Coordinate | Shard Index 372 ------------------------------------------------ 373 0 | (0, 0) | 0 (row 0-1) 374 1 | (0, 1) | 1 (row 2-3) 375 2 | (1, 0) | 2 (row 4-5) 376 3 | (1, 1) | 3 (row 6-7) 377 378 While the FSDP2 + TP sharding behavior does the opposite: it shards the tensor on 379 "tp" mesh dim first then "dp" dim. This right-to-left sharding will produce the 380 result: 381 Rank | Mesh Coordinate | Shard Index 382 ------------------------------------------------ 383 0 | (0, 0) | 0 (row 0-1) 384 1 | (0, 1) | 2 (row 4-5) 385 2 | (1, 0) | 1 (row 2-3) 386 3 | (1, 1) | 3 (row 6-7) 387 388 The consequence is, any attempt to redistribute this DTensor to a full replica will 389 produce a wrong result because the shard-to-replicate redistribution always happens 390 right-to-left, regardless it's left-to-right sharding or right-to-left. To address 391 this, we use _StridedShard placement to make this right-to-left sharding compatible 392 with our left-to-right convention on both tensor distribution and redistribution. 393 394 Now with _StridedShard, the right-to-left sharding above can be represented as: 395 tensor shape: [8, 8] 396 mesh: [[0, 1], [2, 3]], names=("dp", "tp") 397 placements: [_StridedShard(0, split_factor=2), Shard(0)] 398 399 And a left-to-right processing of `placements` will produce the same result, which is 400 different from using the `Shard` placement: 401 Rank | Mesh Coordinate | Shard Index 402 ------------------------------------------------ 403 0 | (0, 0) | 0 (row 0-1) 404 1 | (0, 1) | 2 (row 4-5) 405 2 | (1, 0) | 1 (row 2-3) 406 3 | (1, 1) | 3 (row 6-7) 407 408 The argument `split_factor` is the number of existing shards over the tensor sharding 409 dimension before processing the _StridedShard placement, as if the sharding happened 410 right-to-left. In the example above, the tensor should first be sharded on the "tp" 411 dimension into 2 shards before being sharded on the "dp" dimension. Therefore, the 412 `split_factor` of the _StridedShard placement on "dp" dim is 2. 413 414 TODO: strided sharding needs to work fine with uneven sharding. Now it forbids 415 resharding if the tensor is unevenly sharded. 416 TODO: we should remove _StridedShard placement once we can unify it with Shard 417 """ 418 419 split_factor: int 420 421 def __eq__(self, other: object) -> bool: 422 if isinstance(other, _StridedShard): 423 return self.dim == other.dim and self.split_factor == other.split_factor 424 elif isinstance(other, Shard): 425 # TODO: this is to avoid extra all-gather in dtensor op dispatch 426 # note that sharding prop would not produce _StridedShard and an 427 # placement inequality would introduce an all-gather for resharding 428 return self.dim == other.dim 429 return False 430 431 def __hash__(self) -> int: 432 return hash((self.dim, self.split_factor)) 433 434 def __repr__(self) -> str: 435 """ 436 machine readable representation of the _StridedShard placement 437 """ 438 return f"_StridedShard(dim={self.dim}, sf={self.split_factor})" 439 440 def __str__(self) -> str: 441 """human readable representation of the _StridedShard placement""" 442 return f"_S({self.dim}, {self.split_factor})" 443 444 def _split_tensor( 445 self, 446 tensor: torch.Tensor, 447 num_chunks: int, 448 *, 449 with_padding: bool = True, 450 contiguous: bool = True, 451 ) -> Tuple[List[torch.Tensor], List[int]]: 452 """ 453 TODO: currently _StridedShard does not support padding 454 """ 455 assert ( 456 self.dim <= tensor.ndim 457 ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" 458 459 total_split = num_chunks * self.split_factor 460 assert tensor.size(self.dim) % total_split == 0, ( 461 "_StridedShard currently only allows even sharding but got tensor size" 462 f" {tensor.size(self.dim)} on dim {self.dim} and total split" 463 f" {total_split}={num_chunks} * {self.split_factor}" 464 ) 465 466 group_size = self.split_factor 467 total_split_tensor_list = list(torch.chunk(tensor, total_split, dim=self.dim)) 468 tensor_list = [ 469 torch.cat( 470 [ 471 total_split_tensor_list[i + j * num_chunks] # stride is num_chunks 472 for j in range(group_size) 473 ], 474 dim=self.dim, 475 ) 476 for i in range(num_chunks) 477 ] 478 479 if contiguous: 480 tensor_list = [t.contiguous() for t in tensor_list] 481 482 return tensor_list, [] 483 484 def _to_replicate_tensor( 485 self, 486 local_tensor: torch.Tensor, 487 mesh: DeviceMesh, 488 mesh_dim: int, 489 current_logical_shape: List[int], 490 ) -> torch.Tensor: 491 """ 492 Note: currently _StridedShard does not support padding 493 """ 494 num_chunks = mesh.size(mesh_dim=mesh_dim) 495 total_split = num_chunks * self.split_factor 496 # NOTE: we require Strided Sharding to be even for now 497 assert current_logical_shape[self.dim] % total_split == 0, ( 498 "_StridedShard requires even sharding but got tensor size " 499 f"{current_logical_shape[self.dim]} on dim {self.dim} and " 500 f"total split {total_split}=num_chunks {num_chunks} " 501 f"* split_factor {self.split_factor}" 502 ) 503 504 result = funcol.all_gather_tensor( 505 local_tensor, 506 gather_dim=self.dim, 507 group=(mesh, mesh_dim), 508 ) 509 if isinstance(result, funcol.AsyncCollectiveTensor): 510 result = result.wait() 511 512 tensor_shard_list = torch.chunk(result, total_split, dim=self.dim) 513 # rearrange the order 514 new_tensor_shard_list = [] 515 for idx in range(len(tensor_shard_list)): 516 # the shard split of index `idx` is assigned a new index within 517 # _StridedShard._split_tensor: 518 # the original tensor was split into `total_split` chunks, 519 # all chunks with the same `idx % num_chunks` are merged into one 520 # new shard and placed on mesh's local rank `idx % num_chunks` 521 idx_after_split = idx % num_chunks * self.split_factor + idx // num_chunks 522 new_tensor_shard_list.append(tensor_shard_list[idx_after_split]) 523 524 return torch.cat(new_tensor_shard_list, dim=self.dim).contiguous() 525 526 527@dataclass(frozen=True) 528class Replicate(Placement): 529 """ 530 The ``Replicate()`` placement describes the DTensor replicating on a corresponding 531 ``DeviceMesh`` dimension, where each rank on the DeviceMesh dimension holds a 532 replica of the global Tensor. The ``Replicate`` placement can be used by all 533 DTensor APIs (i.e. ``distribute_tensor``, ``DTensor.from_local``, etc.) 534 """ 535 536 def __eq__(self, other: object) -> bool: 537 return isinstance(other, Replicate) 538 539 def __hash__(self) -> int: 540 # every replicate placement is the same 541 return -1 542 543 def __repr__(self) -> str: 544 """ 545 machine readable representation of the Replicate placement 546 """ 547 return "Replicate()" 548 549 def __str__(self) -> str: 550 """ 551 human readable representation of the Replicate placement 552 """ 553 return "R" 554 555 def _replicate_tensor( 556 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int 557 ) -> torch.Tensor: 558 """ 559 Replicate (broadcast) a torch.Tensor on a mesh dimension (use 560 the first coordinate on the mesh dimension as source of truth) 561 """ 562 my_coordinate = mesh.get_coordinate() 563 if my_coordinate is None: 564 # if rank is not part of mesh, we simply return an empty tensor 565 return tensor.new_empty(0, requires_grad=tensor.requires_grad) 566 567 tensor = tensor.contiguous() 568 mesh_broadcast(tensor, mesh, mesh_dim=mesh_dim) 569 return tensor 570 571 572@dataclass(frozen=True) 573class Partial(Placement): 574 """ 575 The ``Partial(reduce_op)`` placement describes the DTensor that is pending 576 reduction on a specified ``DeviceMesh`` dimension, where each rank on the 577 DeviceMesh dimension holds the partial value of the global Tensor. User can 578 redistribute the ``Partial`` DTensor to a ``Replicate`` or ``Shard(dim)`` 579 placement on the specified ``DeviceMesh`` dimension using ``redistribute``, 580 which would trigger necessary communication operations under the hood (i.e. 581 ``allreduce``, ``reduce_scatter``). 582 583 Args: 584 reduce_op (str, optional): The reduction op to be used for the partial DTensor 585 to produce Replicated/Sharded DTensor. Only element-wise reduction operations 586 are supported, including: "sum", "avg", "product", "max", "min", default: "sum". 587 588 .. note:: The ``Partial`` placement can be generated as a result of the DTensor operators, 589 and can only be used by the ``DTensor.from_local`` API. 590 """ 591 592 reduce_op: str = "sum" 593 594 def _reduce_value( 595 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int 596 ) -> torch.Tensor: 597 # Partial placement contract #1: 598 # _reduce_value: reduce the value of the tensor on the mesh dimension 599 return funcol.all_reduce( 600 tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) 601 ) 602 603 def _reduce_shard_value( 604 self, 605 tensor: torch.Tensor, 606 mesh: DeviceMesh, 607 mesh_dim: int, 608 shard_spec: Placement, 609 ) -> torch.Tensor: 610 # Partial placement contract #2: 611 # _reduce_shard_value: reduce_scatter the value of the tensor over the mesh dimension 612 shard_spec = cast(Shard, shard_spec) 613 return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) 614 615 def _partition_value( 616 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int 617 ) -> torch.Tensor: 618 # Partial placement contract #3: 619 # _partition_value: partition the value of a replicated tensor on the mesh dimension 620 621 # _partition_value is the conjugate operation of _reduce_value 622 # - i.e. _partition_value on a sum reduce op is just a divison operation 623 # - the _reduce_value on a sum reduce op would just be a sum(allreduce) operation 624 # TODO: if the reduce_op is min/max, etc. the _partition_value should be a 625 # different operation 626 assert self.reduce_op == "sum", "only support replicate to PartialSUM for now!" 627 num_chunks = mesh.size(mesh_dim=mesh_dim) 628 return tensor / num_chunks 629 630 def __eq__(self, other: object) -> bool: 631 if not isinstance(other, Partial): 632 return False 633 return self.reduce_op == other.reduce_op 634 635 def __hash__(self) -> int: 636 return 1 + hash(self.reduce_op) 637 638 def __repr__(self) -> str: 639 """ 640 machine readable representation of the Partial placement 641 """ 642 return f"Partial({self.reduce_op})" 643 644 def __str__(self) -> str: 645 """ 646 human readable representation of the Partial placement 647 """ 648 return "P" 649 650 651# We keep the old _Partial name for a while for BC reason 652_Partial = Partial 653