# mypy: allow-untyped-defs
from enum import IntEnum
from typing import Dict, Sized, Tuple

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe


__all__ = [
    "SHARDING_PRIORITIES",
    "ShardingFilterIterDataPipe",
]


class SHARDING_PRIORITIES(IntEnum):
    DEFAULT = 1
    DISTRIBUTED = 2
    MULTIPROCESSING = 3


class _ShardingIterDataPipe(IterDataPipe):
    def apply_sharding(
        self,
        num_of_instances: int,
        instance_id: int,
        sharding_group: SHARDING_PRIORITIES,
    ):
        raise NotImplementedError


@functional_datapipe("sharding_filter")
class ShardingFilterIterDataPipe(_ShardingIterDataPipe):
    r"""
    Wrapper that allows DataPipe to be sharded (functional name: ``sharding_filter``).

    After ``apply_sharding`` is called, each instance of the DataPipe (on different workers) will have every `n`-th element of the
    original DataPipe, where `n` equals the number of instances.

    Args:
        source_datapipe: Iterable DataPipe that will be sharded
        sharding_group_filter: optional ``SHARDING_PRIORITIES`` value; when set,
            only the sharding applied for that group is taken into account
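
    Example (an illustrative sketch rather than a normative recipe; it assumes
    ``IterableWrapper`` from ``torch.utils.data.datapipes.iter`` is available):
        With two instances, each one keeps every other element of the source.

        >>> # xdoctest: +SKIP
        >>> from torch.utils.data.datapipes.iter import IterableWrapper
        >>> dp0 = IterableWrapper(range(10)).sharding_filter()
        >>> dp0.apply_sharding(2, 0)
        >>> list(dp0)
        [0, 2, 4, 6, 8]
        >>> dp1 = IterableWrapper(range(10)).sharding_filter()
        >>> dp1.apply_sharding(2, 1)
        >>> list(dp1)
        [1, 3, 5, 7, 9]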
    """

    def __init__(self, source_datapipe: IterDataPipe, sharding_group_filter=None):
        self.source_datapipe = source_datapipe
        self.sharding_group_filter = sharding_group_filter
        # Maps each sharding group to its (num_of_instances, instance_id) pair
        self.groups: Dict[int, Tuple[int, int]] = {}
        self.num_of_instances = 1
        self.instance_id = 0
        self._update_num_of_instances()

    def apply_sharding(
        self, num_of_instances, instance_id, sharding_group=SHARDING_PRIORITIES.DEFAULT
    ):
        if instance_id >= num_of_instances:
            raise ValueError(
                f"instance_id({instance_id}) should be smaller than num_of_instances({num_of_instances})"
            )
        # DEFAULT is mutually exclusive with the more specific groups
        # (DISTRIBUTED, MULTIPROCESSING): once one kind has been registered,
        # the other kind is rejected.
        if sharding_group == SHARDING_PRIORITIES.DEFAULT:
            if len(self.groups) and SHARDING_PRIORITIES.DEFAULT not in self.groups:
                raise RuntimeError(
                    "ShardingFilter cannot mix DEFAULT and non DEFAULT groups"
                )
        else:
            if SHARDING_PRIORITIES.DEFAULT in self.groups:
                raise RuntimeError(
                    "ShardingFilter cannot mix DEFAULT and non DEFAULT groups"
                )
        self.groups[sharding_group] = (num_of_instances, instance_id)
        self._update_num_of_instances()

    def _update_num_of_instances(self):
        # Select the registered groups, optionally restricted to
        # ``sharding_group_filter``, and order them by descending priority value.
        sorted_sharding_groups = []
        for key in sorted(self.groups.keys()):
            if self.sharding_group_filter is None or key == self.sharding_group_filter:
                sorted_sharding_groups.append(self.groups[key])

        sorted_sharding_groups.reverse()

        self.num_of_instances = 1
        self.instance_id = 0

        # Fold the selected groups into a single flat (num_of_instances, instance_id)
        # pair; the total number of instances is the product of the group sizes
        # (e.g. 4 distributed ranks x 2 dataloader workers -> 8 shards).
        for group_num_of_instances, group_instance_id in sorted_sharding_groups:
            self.instance_id += self.num_of_instances * group_instance_id
            self.num_of_instances *= group_num_of_instances

    def __iter__(self):
        # Keep every ``num_of_instances``-th element, offset by ``instance_id``.
        for i, item in enumerate(self.source_datapipe):
            if i % self.num_of_instances == self.instance_id:
                yield item

    def __len__(self):
        if isinstance(self.source_datapipe, Sized):
            # Elements are distributed round-robin, so instances whose id is
            # smaller than the remainder receive one extra element.
            return len(self.source_datapipe) // self.num_of_instances + (
                1
                if (
                    self.instance_id < len(self.source_datapipe) % self.num_of_instances
                )
                else 0
            )
        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")