# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates

from dataclasses import dataclass
from typing import cast, List, Optional, Tuple

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._collective_utils import (
    fill_empty_tensor_to_shards,
    mesh_broadcast,
    mesh_scatter,
    pad_tensor,
    shard_dim_alltoall,
    unpad_tensor,
)


__all__ = ["Placement", "Shard", "Replicate", "Partial"]


class Placement:
    """
    The base class for the Placement type, which describes how a DTensor is placed onto the
    ``DeviceMesh``. ``Placement`` and ``DeviceMesh`` together describe the DTensor layout.
    It is the base class of the three main DTensor Placement types: ``Shard``, ``Replicate``,
    and ``Partial``.

    This class is not meant to be used directly; it mainly serves as a typing stub.
    """

    # convenient utils to check for placement types
    def is_shard(self, dim: Optional[int] = None) -> bool:
        is_shard_instance = isinstance(self, Shard)
        if dim is not None and is_shard_instance:
            return cast(Shard, self).dim == dim
        else:
            return is_shard_instance

    def is_replicate(self) -> bool:
        return isinstance(self, Replicate)

    def is_partial(self) -> bool:
        return isinstance(self, Partial)

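# Illustrative sketch (not part of the original source): these helpers are
# typically used to branch on a DTensor's placements, e.g. assuming ``dtensor``
# is an existing DTensor:
#
#   for placement in dtensor.placements:
#       if placement.is_shard(dim=0):
#           ...  # handle row-wise sharded mesh dimensions
#       elif placement.is_replicate():
#           ...  # handle replicated mesh dimensions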

@dataclass(frozen=True)
class Shard(Placement):
    """
    The ``Shard(dim)`` placement describes the DTensor sharding on tensor dimension
    ``dim`` over a corresponding ``DeviceMesh`` dimension, where each rank on the
    DeviceMesh dimension only holds a shard/piece of the global Tensor. The
    ``Shard(dim)`` placement follows the ``torch.chunk(dim)`` semantic, where the
    last few shards on the DeviceMesh dimension might be empty when the tensor dimension
    is not evenly divisible by the DeviceMesh dimension size. The ``Shard`` placement can
    be used by all DTensor APIs (e.g. ``distribute_tensor``, ``from_local``, etc.)

    Args:
        dim (int): The tensor dimension along which the DTensor is sharded over its
            corresponding DeviceMesh dimension.

    .. warning:: sharding on a tensor dimension whose size is not evenly divisible
        by the DeviceMesh dimension size is currently experimental and subject to change.
    """

    dim: int

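    # Illustrative sketch (not part of the original source): ``Shard(0)`` follows
    # ``torch.chunk`` semantics, so e.g. a tensor with 10 rows placed on a mesh
    # dimension of size 4 is split into local shards of 3/3/3/1 rows:
    #
    #   >>> chunks = torch.chunk(torch.arange(10).reshape(10, 1), 4, dim=0)
    #   >>> [c.size(0) for c in chunks]
    #   [3, 3, 3, 1]
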
    def _split_tensor(
        self,
        tensor: torch.Tensor,
        num_chunks: int,
        *,
        with_padding: bool = True,
        contiguous: bool = True,
    ) -> Tuple[List[torch.Tensor], List[int]]:
        """
        This function uses torch.chunk to split a tensor into num_chunks shards along
        the Shard placement dimension, and returns a list of shards with their pad sizes.

        Keyword args:
            with_padding (bool, optional): when True, we pad the tensor on the last
            few ranks before calling the collectives (i.e. scatter/all_gather, etc.).
            This is because collectives usually require equal size tensor inputs.
        """
        assert (
            self.dim <= tensor.ndim
        ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"

        # chunk tensor over dimension `dim` into n slices
        tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim))
        num_empty_tensors = num_chunks - len(tensor_list)

        # if there is no need for padding or the tensor dim size is already evenly
        # sharded, we can return early.
        if not with_padding or tensor.size(self.dim) % num_chunks == 0:
            if contiguous:
                tensor_list = [t.contiguous() for t in tensor_list]
            return (
                fill_empty_tensor_to_shards(tensor_list, self.dim, num_empty_tensors),
                [],
            )

        # compute the chunk size in line with ``torch.chunk`` to calculate padding
        full_chunk_size = (tensor.size(self.dim) + num_chunks - 1) // num_chunks

        # Compute chunk size for each chunk for ``self.dim``
        chunk_sizes = [
            tensor_list[idx].size(self.dim) if idx < len(tensor_list) else 0
            for idx in range(num_chunks)
        ]
        # Compute pad size on each chunk
        pad_sizes = [full_chunk_size - chunk_size for chunk_size in chunk_sizes]

        # Reuse tensor to fill empty chunk with empty tensor
        tensor_list = fill_empty_tensor_to_shards(
            tensor_list, self.dim, num_empty_tensors
        )
        shard_list = []
        for shard, pad_size in zip(tensor_list, pad_sizes):
            # Pad the shard with zeros up to the full chunk size.
            if with_padding and pad_size > 0:
                shard = pad_tensor(shard, self.dim, pad_size)
            shard = shard.contiguous() if contiguous else shard
            shard_list.append(shard)
        return shard_list, pad_sizes

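    # Illustrative sketch (not part of the original source): for a tensor with 10
    # rows split into num_chunks=4, ``full_chunk_size`` is ceil(10 / 4) = 3, the
    # chunk sizes are [3, 3, 3, 1], and the returned ``pad_sizes`` are
    # [0, 0, 0, 2], i.e. only the last shard gets padded up to 3 rows before the
    # collective.
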
    @staticmethod
    def _local_shard_size_on_dim(
        size_on_dim: int,
        num_chunks: int,
        rank: int,
        return_offset: bool = False,
    ) -> Tuple[int, int]:
        """
        returns the local shard size and offset on a given tensor dim
        """
        # Compute the chunk size in line with ``torch.chunk``
        if size_on_dim % num_chunks == 0:
            full_chunk_size = size_on_dim // num_chunks
            return full_chunk_size, full_chunk_size * rank if return_offset else -1

        # uneven sharding case
        full_chunk_size = (size_on_dim + num_chunks - 1) // num_chunks
        shard_starting_idx = full_chunk_size * rank

        if size_on_dim < shard_starting_idx:
            return 0, size_on_dim if return_offset else -1
        else:
            local_shard_size = (
                min(size_on_dim, shard_starting_idx + full_chunk_size)
                - shard_starting_idx
            )
            return local_shard_size, shard_starting_idx if return_offset else -1

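    # Illustrative sketch (not part of the original source): for size_on_dim=10
    # split into num_chunks=4, full_chunk_size is 3, so rank 3 starts at offset 9
    # and holds min(10, 12) - 9 = 1 element, i.e.
    # ``_local_shard_size_on_dim(10, 4, 3, return_offset=True)`` returns (1, 9).
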
    def _shard_tensor(
        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
    ) -> torch.Tensor:
        """
        shard and scatter a tensor on a mesh dimension (use coordinate
        0 on the mesh dimension as source of truth)
        """
        my_coordinate = mesh.get_coordinate()
        num_chunks = mesh.size(mesh_dim=mesh_dim)

        if my_coordinate is None:
            # if rank is not part of mesh, we simply return an empty tensor
            return tensor.new_empty(0, requires_grad=tensor.requires_grad)

        scatter_list, pad_sizes = self._split_tensor(
            tensor, num_chunks, with_padding=True, contiguous=True
        )

        mesh_dim_local_rank = my_coordinate[mesh_dim]
        output = torch.empty_like(scatter_list[mesh_dim_local_rank])
        mesh_scatter(output, scatter_list, mesh, mesh_dim=mesh_dim)

        # Only unpad if the local_tensor was padded on the dimension.
        if pad_sizes and pad_sizes[mesh_dim_local_rank] > 0:
            output = unpad_tensor(output, self.dim, pad_sizes[mesh_dim_local_rank])
        return output

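    # Illustrative sketch (not part of the original source): conceptually, the
    # rank at coordinate 0 on the mesh dimension chunks (and pads) the full
    # tensor and scatters one shard to every rank on that mesh dimension; each
    # rank then unpads its own shard if its pad size is non-zero.
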
    def _reduce_shard_tensor(
        self,
        tensor: torch.Tensor,
        mesh: DeviceMesh,
        reduce_op: str,
        mesh_dim: int,
    ) -> torch.Tensor:
        """
        reduce and scatter a tensor on a mesh dimension
        """
        my_coordinate = mesh.get_coordinate()
        num_chunks = mesh.size(mesh_dim=mesh_dim)

        if my_coordinate is None:
            # if rank is not part of mesh, we simply return local_tensor,
            # which should be an empty tensor
            return tensor

        is_padded = tensor.size(self.dim) % num_chunks != 0
        if is_padded:
            scattered_list, pad_sizes = self._split_tensor(
                tensor, num_chunks, with_padding=True, contiguous=True
            )
            tensor = torch.cat(scattered_list, dim=self.dim)
        elif not tensor.is_contiguous():
            tensor = tensor.contiguous()

        output = funcol.reduce_scatter_tensor(
            tensor, reduce_op, scatter_dim=self.dim, group=(mesh, mesh_dim)
        )

        if is_padded:
            output = unpad_tensor(output, self.dim, pad_sizes[my_coordinate[mesh_dim]])  # type: ignore[possibly-undefined]
        return output

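    # Illustrative sketch (not part of the original source): on a 1-D mesh of 2
    # ranks with reduce_op="sum" and Shard(0), each rank contributes its full
    # 4-row partial tensor to the reduce_scatter and keeps only its own half of
    # the element-wise sum: rank 0 ends up with reduced rows 0-1, rank 1 with
    # reduced rows 2-3.
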
    def _to_replicate_tensor(
        self,
        local_tensor: torch.Tensor,
        mesh: DeviceMesh,
        mesh_dim: int,
        current_logical_shape: List[int],
    ) -> torch.Tensor:
        """
        This function all_gathers all shards and returns a tensor that
        is replicated on the previously sharded mesh dimension
        """
        num_chunks = mesh.size(mesh_dim=mesh_dim)
        # check if the sharding is uneven; if so, we need to pad the input tensor
        # before all_gather
        local_shape = list(local_tensor.size())

        logical_dim_size = current_logical_shape[self.dim]
        is_padded = logical_dim_size % num_chunks != 0

        if is_padded:
            full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
            pad_size = full_chunk_size - local_shape[self.dim]
            local_tensor = pad_tensor(local_tensor, self.dim, pad_size)

        if not local_tensor.is_contiguous():
            local_tensor = local_tensor.contiguous()

        result = funcol.all_gather_tensor(
            local_tensor,
            gather_dim=self.dim,
            group=(mesh, mesh_dim),
        )
        if is_padded:
            unpad_size = full_chunk_size * num_chunks - logical_dim_size  # type: ignore[possibly-undefined]
            result = unpad_tensor(result, self.dim, unpad_size)
        return result

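    # Illustrative sketch (not part of the original source): for a logical dim
    # size of 10 gathered over 4 ranks, every local shard is padded to
    # full_chunk_size = 3, the all_gather produces 12 rows, and
    # unpad_size = 3 * 4 - 10 = 2 trailing padded rows are removed.
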
    def _replicate_to_shard(
        self,
        local_tensor: torch.Tensor,
        mesh: DeviceMesh,
        mesh_dim: int,
        shard_index: int,
    ) -> torch.Tensor:
        """
        transform a replicated tensor into a sharded tensor on
        the current rank by performing a local chunk
        """
        num_chunks = mesh.size(mesh_dim=mesh_dim)
        shards, _ = self._split_tensor(
            local_tensor,
            num_chunks,
            with_padding=False,
            contiguous=False,
        )
        return shards[shard_index].clone()

    def _to_new_shard_dim(
        self,
        local_tensor: torch.Tensor,
        mesh: DeviceMesh,
        mesh_dim: int,
        current_logical_shape: List[int],
        new_shard_dim: int,
    ) -> torch.Tensor:
        """
        transform an existing sharded tensor into a tensor sharded on a new
        dimension, which performs an alltoall
        """
        my_coordinate = mesh.get_coordinate()
        if my_coordinate is None:
            # if rank is not part of mesh, we simply return local_tensor,
            # which should be an empty tensor
            return local_tensor

        num_chunks = mesh.size(mesh_dim=mesh_dim)

        old_dim_logical_size = current_logical_shape[self.dim]
        new_dim_logical_size = current_logical_shape[new_shard_dim]
        old_dim_padding = old_dim_logical_size % num_chunks != 0
        new_dim_padding = new_dim_logical_size % num_chunks != 0
        if old_dim_padding:
            old_dim_full_chunk_size = (
                old_dim_logical_size + num_chunks - 1
            ) // num_chunks
            old_dim_pad_size = old_dim_full_chunk_size - local_tensor.size(self.dim)
            local_tensor = pad_tensor(local_tensor, self.dim, old_dim_pad_size)
        if new_dim_padding:
            new_dim_full_chunk_size = (
                new_dim_logical_size + num_chunks - 1
            ) // num_chunks
            new_dim_pad_size = new_dim_full_chunk_size * num_chunks - local_tensor.size(
                new_shard_dim
            )
            local_tensor = pad_tensor(local_tensor, new_shard_dim, new_dim_pad_size)

        if not local_tensor.is_contiguous():
            local_tensor = local_tensor.contiguous()

        new_tensor = shard_dim_alltoall(
            local_tensor, self.dim, new_shard_dim, mesh, mesh_dim
        )

        if old_dim_padding:
            old_dim_unpad_size = (
                old_dim_full_chunk_size * num_chunks - current_logical_shape[self.dim]  # type: ignore[possibly-undefined]
            )
            new_tensor = unpad_tensor(new_tensor, self.dim, old_dim_unpad_size)  # type: ignore[possibly-undefined]

        if new_dim_padding:
            local_shard_size_on_new_dim = self._local_shard_size_on_dim(
                new_dim_logical_size, num_chunks, my_coordinate[mesh_dim]
            )[0]
            new_dim_unpad_size = new_dim_full_chunk_size - local_shard_size_on_new_dim  # type: ignore[possibly-undefined]
            new_tensor = unpad_tensor(new_tensor, new_shard_dim, new_dim_unpad_size)  # type: ignore[possibly-undefined]

        return new_tensor

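    # Illustrative sketch (not part of the original source): redistributing an
    # 8x8 tensor from Shard(0) to Shard(1) on a mesh dimension of size 2, each
    # rank starts with a 4x8 local shard, the alltoall exchanges the 4x4
    # quarters, and each rank ends up with an 8x4 local shard (no padding is
    # needed since both dims divide evenly).
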
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Shard):
            return False
        return self.dim == other.dim

    def __hash__(self) -> int:
        return hash(self.dim)

    def __repr__(self) -> str:
        """
        machine readable representation of the Shard placement
        """
        return f"Shard(dim={self.dim})"

    def __str__(self) -> str:
        """human readable representation of the Shard placement"""
        return f"S({self.dim})"


# kw_only is only available in python >= 3.10
kw_only_dataclass = dict(kw_only=True) if "kw_only" in dataclass.__kwdefaults__ else {}


@dataclass(frozen=True, **kw_only_dataclass)
class _StridedShard(Shard):
    """
    _StridedShard is only introduced to support 2D FSDP2 + TP sharding where the tensor
    is sharded on the TP mesh dimension first, then sharded on the FSDP mesh dimension.
    We call this right-to-left sharding, which is the opposite of the default
    left-to-right sharding. See the example below:
        tensor shape: [8, 8]
        mesh: [[0, 1], [2, 3]], names=("dp", "tp")
        placements: [Shard(0), Shard(0)]

    The default sharding behavior shards the tensor on the "dp" mesh dimension first,
    then on the "tp" dimension. The sharding result will be:
        Rank    |   Mesh Coordinate |   Shard Index
        ------------------------------------------------
        0       |   (0, 0)          |   0 (row 0-1)
        1       |   (0, 1)          |   1 (row 2-3)
        2       |   (1, 0)          |   2 (row 4-5)
        3       |   (1, 1)          |   3 (row 6-7)

    While the FSDP2 + TP sharding behavior does the opposite: it shards the tensor on
    "tp" mesh dim first then "dp" dim. This right-to-left sharding will produce the
    result:
        Rank    |   Mesh Coordinate |   Shard Index
        ------------------------------------------------
        0       |   (0, 0)          |   0 (row 0-1)
        1       |   (0, 1)          |   2 (row 4-5)
        2       |   (1, 0)          |   1 (row 2-3)
        3       |   (1, 1)          |   3 (row 6-7)

    The consequence is that any attempt to redistribute this DTensor to a full replica
    will produce a wrong result, because the shard-to-replicate redistribution always
    happens right-to-left, regardless of whether the sharding was left-to-right or
    right-to-left. To address this, we use the _StridedShard placement to make this
    right-to-left sharding compatible with our left-to-right convention on both tensor
    distribution and redistribution.

    Now with _StridedShard, the right-to-left sharding above can be represented as:
        tensor shape: [8, 8]
        mesh: [[0, 1], [2, 3]], names=("dp", "tp")
        placements: [_StridedShard(0, split_factor=2), Shard(0)]

    And a left-to-right processing of `placements` will produce the same result, which is
    different from using the `Shard` placement:
        Rank    |   Mesh Coordinate |   Shard Index
        ------------------------------------------------
        0       |   (0, 0)          |   0 (row 0-1)
        1       |   (0, 1)          |   2 (row 4-5)
        2       |   (1, 0)          |   1 (row 2-3)
        3       |   (1, 1)          |   3 (row 6-7)

    The argument `split_factor` is the number of existing shards over the tensor sharding
    dimension before processing the _StridedShard placement, as if the sharding happened
    right-to-left. In the example above, the tensor should first be sharded on the "tp"
    dimension into 2 shards before being sharded on the "dp" dimension. Therefore, the
    `split_factor` of the _StridedShard placement on the "dp" dim is 2.

    TODO: strided sharding needs to work fine with uneven sharding. Now it forbids
    resharding if the tensor is unevenly sharded.
    TODO: we should remove _StridedShard placement once we can unify it with Shard
    """

    split_factor: int

    def __eq__(self, other: object) -> bool:
        if isinstance(other, _StridedShard):
            return self.dim == other.dim and self.split_factor == other.split_factor
        elif isinstance(other, Shard):
            # TODO: this is to avoid extra all-gather in dtensor op dispatch
            # note that sharding prop would not produce _StridedShard, and a
            # placement inequality would introduce an all-gather for resharding
            return self.dim == other.dim
        return False

    def __hash__(self) -> int:
        return hash((self.dim, self.split_factor))

    def __repr__(self) -> str:
        """
        machine readable representation of the _StridedShard placement
        """
        return f"_StridedShard(dim={self.dim}, sf={self.split_factor})"

    def __str__(self) -> str:
        """human readable representation of the _StridedShard placement"""
        return f"_S({self.dim}, {self.split_factor})"

    def _split_tensor(
        self,
        tensor: torch.Tensor,
        num_chunks: int,
        *,
        with_padding: bool = True,
        contiguous: bool = True,
    ) -> Tuple[List[torch.Tensor], List[int]]:
        """
        TODO: currently _StridedShard does not support padding
        """
        assert (
            self.dim <= tensor.ndim
        ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"

        total_split = num_chunks * self.split_factor
        assert tensor.size(self.dim) % total_split == 0, (
            "_StridedShard currently only allows even sharding but got tensor size"
            f" {tensor.size(self.dim)} on dim {self.dim} and total split"
            f" {total_split}={num_chunks} * {self.split_factor}"
        )

        group_size = self.split_factor
        total_split_tensor_list = list(torch.chunk(tensor, total_split, dim=self.dim))
        tensor_list = [
            torch.cat(
                [
                    total_split_tensor_list[i + j * num_chunks]  # stride is num_chunks
                    for j in range(group_size)
                ],
                dim=self.dim,
            )
            for i in range(num_chunks)
        ]

        if contiguous:
            tensor_list = [t.contiguous() for t in tensor_list]

        return tensor_list, []

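    # Illustrative sketch (not part of the original source): with num_chunks=2
    # and split_factor=2, an 8-row tensor is first chunked into 4 pieces of 2
    # rows each (rows 0-1, 2-3, 4-5, 6-7); local shard i concatenates pieces i
    # and i + num_chunks, so shard 0 holds rows 0-1 and 4-5 while shard 1 holds
    # rows 2-3 and 6-7.
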
    def _to_replicate_tensor(
        self,
        local_tensor: torch.Tensor,
        mesh: DeviceMesh,
        mesh_dim: int,
        current_logical_shape: List[int],
    ) -> torch.Tensor:
        """
        Note: currently _StridedShard does not support padding
        """
        num_chunks = mesh.size(mesh_dim=mesh_dim)
        total_split = num_chunks * self.split_factor
        # NOTE: we require Strided Sharding to be even for now
        assert current_logical_shape[self.dim] % total_split == 0, (
            "_StridedShard requires even sharding but got tensor size "
            f"{current_logical_shape[self.dim]} on dim {self.dim} and "
            f"total split {total_split}=num_chunks {num_chunks} "
            f"* split_factor {self.split_factor}"
        )

        result = funcol.all_gather_tensor(
            local_tensor,
            gather_dim=self.dim,
            group=(mesh, mesh_dim),
        )
        if isinstance(result, funcol.AsyncCollectiveTensor):
            result = result.wait()

        tensor_shard_list = torch.chunk(result, total_split, dim=self.dim)
        # rearrange the order
        new_tensor_shard_list = []
        for idx in range(len(tensor_shard_list)):
            # the shard split of index `idx` was assigned a new index within
            # _StridedShard._split_tensor: the original tensor was split into
            # `total_split` chunks, and all chunks with the same `idx % num_chunks`
            # were merged into one new shard and placed on the mesh's local rank
            # `idx % num_chunks`
            idx_after_split = idx % num_chunks * self.split_factor + idx // num_chunks
            new_tensor_shard_list.append(tensor_shard_list[idx_after_split])

        return torch.cat(new_tensor_shard_list, dim=self.dim).contiguous()

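    # Illustrative sketch (not part of the original source): continuing the
    # num_chunks=2, split_factor=2 example, the all_gather returns the pieces in
    # order [0, 2, 1, 3]; the remap idx % num_chunks * split_factor +
    # idx // num_chunks maps logical indices 0..3 to gathered positions
    # [0, 2, 1, 3], so the final concat restores the original row order.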

@dataclass(frozen=True)
class Replicate(Placement):
    """
    The ``Replicate()`` placement describes the DTensor being replicated on a corresponding
    ``DeviceMesh`` dimension, where each rank on the DeviceMesh dimension holds a
    replica of the global Tensor. The ``Replicate`` placement can be used by all
    DTensor APIs (e.g. ``distribute_tensor``, ``DTensor.from_local``, etc.)
    """

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Replicate)

    def __hash__(self) -> int:
        # every replicate placement is the same
        return -1

    def __repr__(self) -> str:
        """
        machine readable representation of the Replicate placement
        """
        return "Replicate()"

    def __str__(self) -> str:
        """
        human readable representation of the Replicate placement
        """
        return "R"

    def _replicate_tensor(
        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
    ) -> torch.Tensor:
        """
        Replicate (broadcast) a torch.Tensor on a mesh dimension (use
        the first coordinate on the mesh dimension as source of truth)
        """
        my_coordinate = mesh.get_coordinate()
        if my_coordinate is None:
            # if rank is not part of mesh, we simply return an empty tensor
            return tensor.new_empty(0, requires_grad=tensor.requires_grad)

        tensor = tensor.contiguous()
        mesh_broadcast(tensor, mesh, mesh_dim=mesh_dim)
        return tensor

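    # Illustrative sketch (not part of the original source): this is effectively
    # a broadcast over the mesh dimension with the rank at coordinate 0 as the
    # source; after the call, every rank on that mesh dimension holds an
    # identical copy of the tensor.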

@dataclass(frozen=True)
class Partial(Placement):
    """
    The ``Partial(reduce_op)`` placement describes the DTensor that is pending
    reduction on a specified ``DeviceMesh`` dimension, where each rank on the
    DeviceMesh dimension holds the partial value of the global Tensor. Users can
    redistribute the ``Partial`` DTensor to a ``Replicate`` or ``Shard(dim)``
    placement on the specified ``DeviceMesh`` dimension using ``redistribute``,
    which would trigger the necessary communication operations under the hood
    (e.g. ``allreduce``, ``reduce_scatter``).

    Args:
        reduce_op (str, optional): The reduction op to be used for the partial DTensor
            to produce Replicated/Sharded DTensor. Only element-wise reduction operations
            are supported, including: "sum", "avg", "product", "max", "min", default: "sum".

    .. note:: The ``Partial`` placement can be generated as a result of DTensor operators,
        and can only be used by the ``DTensor.from_local`` API.
    """

    reduce_op: str = "sum"

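    # Illustrative sketch (not part of the original source): a typical use is to
    # wrap a per-rank unreduced result and reduce it lazily, assuming ``mesh`` is
    # a 1-D DeviceMesh and ``local_out`` is this rank's plain torch.Tensor
    # contribution:
    #
    #   >>> dt = DTensor.from_local(local_out, mesh, [Partial()])
    #   >>> replicated = dt.redistribute(mesh, [Replicate()])  # triggers allreduce
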
    def _reduce_value(
        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
    ) -> torch.Tensor:
        # Partial placement contract #1:
        # _reduce_value: reduce the value of the tensor on the mesh dimension
        return funcol.all_reduce(
            tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim)
        )

    def _reduce_shard_value(
        self,
        tensor: torch.Tensor,
        mesh: DeviceMesh,
        mesh_dim: int,
        shard_spec: Placement,
    ) -> torch.Tensor:
        # Partial placement contract #2:
        # _reduce_shard_value: reduce_scatter the value of the tensor over the mesh dimension
        shard_spec = cast(Shard, shard_spec)
        return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim)

    def _partition_value(
        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
    ) -> torch.Tensor:
        # Partial placement contract #3:
        # _partition_value: partition the value of a replicated tensor on the mesh dimension

        # _partition_value is the conjugate operation of _reduce_value
        # - i.e. _partition_value on a sum reduce op is just a division operation
        # - the _reduce_value on a sum reduce op would just be a sum(allreduce) operation
        # TODO: if the reduce_op is min/max, etc. the _partition_value should be a
        # different operation
        assert self.reduce_op == "sum", "only support replicate to PartialSUM for now!"
        num_chunks = mesh.size(mesh_dim=mesh_dim)
        return tensor / num_chunks

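    # Illustrative sketch (not part of the original source): with reduce_op="sum"
    # on a mesh dimension of size 4, a replicated value of 12.0 partitions to 3.0
    # on every rank, so the conjugate _reduce_value (an all_reduce sum) recovers
    # 3.0 * 4 = 12.0.
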
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Partial):
            return False
        return self.reduce_op == other.reduce_op

    def __hash__(self) -> int:
        return 1 + hash(self.reduce_op)

    def __repr__(self) -> str:
        """
        machine readable representation of the Partial placement
        """
        return f"Partial({self.reduce_op})"

    def __str__(self) -> str:
        """
        human readable representation of the Partial placement
        """
        return "P"


# We keep the old _Partial name for a while for BC reasons
_Partial = Partial