import math
from typing import Iterator, Optional, TypeVar

import torch
import torch.distributed as dist
from torch.utils.data.dataset import Dataset
from torch.utils.data.sampler import Sampler


__all__ = ["DistributedSampler"]


_T_co = TypeVar("_T_co", covariant=True)


class DistributedSampler(Sampler[_T_co]):
    r"""Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a
    :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
    original dataset that is exclusive to it.

    .. note::
        The dataset is assumed to be of constant size, and any instance of it
        is assumed to always return the same elements in the same order.

    Args:
        dataset: Dataset used for sampling.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, :attr:`world_size` is retrieved from the
            current distributed group.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
            By default, :attr:`rank` is retrieved from the current distributed
            group.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
            indices.
        seed (int, optional): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
        drop_last (bool, optional): if ``True``, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas. Default: ``False``.

    .. warning::
        In distributed mode, calling the :meth:`set_epoch` method at
        the beginning of each epoch **before** creating the :class:`DataLoader` iterator
        is necessary to make shuffling work properly across multiple epochs. Otherwise,
        the same ordering will always be used.

    Example::

        >>> # xdoctest: +SKIP
        >>> sampler = DistributedSampler(dataset) if is_distributed else None
        >>> loader = DataLoader(dataset, shuffle=(sampler is None),
        ...                     sampler=sampler)
        >>> for epoch in range(start_epoch, n_epochs):
        ...     if is_distributed:
        ...         sampler.set_epoch(epoch)
        ...     train(loader)
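
    A rough illustration of how :attr:`drop_last` changes the per-replica
    length (the numbers below assume a hypothetical 10-element ``dataset``
    split across 3 replicas)::

        >>> # xdoctest: +SKIP
        >>> len(DistributedSampler(dataset, num_replicas=3, rank=0, drop_last=True))
        3
        >>> len(DistributedSampler(dataset, num_replicas=3, rank=0, drop_last=False))
        4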
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
            )
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
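            # Illustration (hypothetical sizes): a 10-element dataset on 3
            # replicas gives ceil((10 - 3) / 3) == 3 samples per rank, so 9 of
            # the 10 samples are used and 1 is dropped.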
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self) -> Iterator[_T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
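            # every rank builds the same permutation (all ranks share seed + epoch),
            # so the per-rank slices taken below are disjoint across ranks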
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
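            # (e.g., with an assumed total_size of 12 and 10 indices, the first
            # 2 indices are repeated; if more padding than indices were needed,
            # the index list would be tiled and then truncated to padding_size)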
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[
                    :padding_size
                ]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
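        # each rank keeps every num_replicas-th index starting at its own rank,
        # e.g. with 2 replicas, rank 0 gets indices[0::2] and rank 1 gets indices[1::2]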
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
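        # per-replica length: the number of indices this rank will iterate over,
        # not the length of the underlying dataset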
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch