# mypy: allow-untyped-defs
import warnings
from collections import defaultdict
from typing import (
    Any,
    Callable,
    DefaultDict,
    Iterator,
    List,
    Optional,
    Sized,
    Type,
    TypeVar,
)

import torch.utils.data.datapipes.iter.sharding
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import DataChunk, IterDataPipe
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn


__all__ = [
    "BatcherIterDataPipe",
    "GrouperIterDataPipe",
    "UnBatcherIterDataPipe",
]


_T_co = TypeVar("_T_co", covariant=True)


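# Module-level ``__getattr__`` (PEP 562): names that moved to
# ``torch.utils.data.datapipes.iter.sharding`` remain resolvable here,
# but accessing them emits a ``FutureWarning`` pointing to the new location.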
def __getattr__(name: str):
    if name in ["SHARDING_PRIORITIES", "ShardingFilterIterDataPipe"]:
        warnings.warn(
            f"`{name}` from `torch.utils.data.datapipes.iter.grouping` is going to be removed in PyTorch 2.1. "
            f"Please use `{name}` from `torch.utils.data.datapipes.iter.sharding` instead.",
            category=FutureWarning,
            stacklevel=2,
        )

        return getattr(torch.utils.data.datapipes.iter.sharding, name)

    raise AttributeError(f"module {__name__} has no attribute {name}")


@functional_datapipe("batch")
class BatcherIterDataPipe(IterDataPipe[DataChunk]):
    r"""
    Creates mini-batches of data (functional name: ``batch``).

    An outer dimension of size ``batch_size`` is added to each yielded batch. If
    ``drop_last`` is set to ``False`` and the source length is not divisible by
    ``batch_size``, the last batch holds the remaining ``length % batch_size`` elements.

    Args:
        datapipe: Iterable DataPipe being batched
        batch_size: The size of each batch
        drop_last: Option to drop the last batch if it's not full
        wrapper_class: Wrapper to apply to each batch (type ``List``) before yielding,
            defaults to ``DataChunk``

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper(range(10))
        >>> dp = dp.batch(batch_size=3, drop_last=True)
        >>> list(dp)
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    datapipe: IterDataPipe
    batch_size: int
    drop_last: bool

    def __init__(
        self,
        datapipe: IterDataPipe,
        batch_size: int,
        drop_last: bool = False,
        wrapper_class: Type[DataChunk] = DataChunk,
    ) -> None:
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        super().__init__()
        self.datapipe = datapipe
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.wrapper_class = wrapper_class

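    # Accumulate items into ``batch``; each time it fills up, hand it to
    # ``wrapper_class`` and start a fresh list. A trailing partial batch is
    # yielded unless ``drop_last`` is set.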
    def __iter__(self) -> Iterator[DataChunk]:
        batch: List = []
        for x in self.datapipe:
            batch.append(x)
            if len(batch) == self.batch_size:
                yield self.wrapper_class(batch)
                batch = []
        if len(batch) > 0:
            if not self.drop_last:
                yield self.wrapper_class(batch)

    def __len__(self) -> int:
        if isinstance(self.datapipe, Sized):
            if self.drop_last:
                return len(self.datapipe) // self.batch_size
            else:
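                # Ceiling division: count the final partial batch as well.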
                return (len(self.datapipe) + self.batch_size - 1) // self.batch_size
        else:
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")


@functional_datapipe("unbatch")
class UnBatcherIterDataPipe(IterDataPipe):
    r"""
    Undoes batching of data (functional name: ``unbatch``).

    In other words, it flattens the data up to the specified level within a batched DataPipe.

    Args:
        datapipe: Iterable DataPipe being un-batched
        unbatch_level: Defaults to ``1`` (only flattening the top level). If set to ``2``,
            it will flatten the top two levels, and ``-1`` will flatten the entire DataPipe.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> source_dp = IterableWrapper([[[0, 1], [2]], [[3, 4], [5]], [[6]]])
        >>> dp1 = source_dp.unbatch()
        >>> list(dp1)
        [[0, 1], [2], [3, 4], [5], [6]]
        >>> dp2 = source_dp.unbatch(unbatch_level=2)
        >>> list(dp2)
        [0, 1, 2, 3, 4, 5, 6]
    """

    def __init__(self, datapipe: IterDataPipe, unbatch_level: int = 1):
        self.datapipe = datapipe
        self.unbatch_level = unbatch_level

    def __iter__(self):
        for element in self.datapipe:
            yield from self._dive(element, unbatch_level=self.unbatch_level)

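    # Recursively flatten ``element``: ``unbatch_level == -1`` flattens all the
    # way down, ``0`` yields the element as-is, and each positive level peels
    # off one ``list``/``DataChunk`` layer.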
    def _dive(self, element, unbatch_level):
        if unbatch_level < -1:
            raise ValueError("unbatch_level must be -1 or >= 0")
        if unbatch_level == -1:
            if isinstance(element, (list, DataChunk)):
                for item in element:
                    yield from self._dive(item, unbatch_level=-1)
            else:
                yield element
        elif unbatch_level == 0:
            yield element
        else:
            if isinstance(element, (list, DataChunk)):
                for item in element:
                    yield from self._dive(item, unbatch_level=unbatch_level - 1)
            else:
                raise IndexError(
                    f"unbatch_level {self.unbatch_level} exceeds the depth of the DataPipe"
                )


@functional_datapipe("groupby")
class GrouperIterDataPipe(IterDataPipe[DataChunk]):
    r"""
    Groups data from the IterDataPipe by keys from ``group_key_fn``, yielding a ``DataChunk``
    with batch size up to ``group_size`` (functional name: ``groupby``).

    The samples are read sequentially from the source ``datapipe``, and a batch of samples belonging to the same group
    will be yielded as soon as the size of the batch reaches ``group_size``. When the buffer is full,
    the DataPipe will yield the largest batch with the same key, provided that its size is at least
    ``guaranteed_group_size``. If its size is smaller, it will be dropped if ``drop_remaining=True``.

    After iterating through the entirety of source ``datapipe``, everything not dropped due to the buffer capacity
    will be yielded from the buffer, even if the group sizes are smaller than ``guaranteed_group_size``.

    Args:
        datapipe: Iterable datapipe to be grouped
        group_key_fn: Function used to generate the group key from the data of the source datapipe
        keep_key: Option to yield the matching key along with the items in a tuple,
            resulting in ``(key, [items])``; otherwise only ``[items]`` is returned
        buffer_size: The size of the buffer for ungrouped data
        group_size: The max size of each group; a batch is yielded as soon as it reaches this size
        guaranteed_group_size: The guaranteed minimum group size to be yielded in case the buffer is full
        drop_remaining: Specifies whether a group smaller than ``guaranteed_group_size`` will be dropped from the buffer
            when the buffer is full

    Example:
        >>> import os
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> def group_fn(file):
        ...     return os.path.basename(file).split(".")[0]
        >>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
        >>> dp0 = source_dp.groupby(group_key_fn=group_fn)
        >>> list(dp0)
        [['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
        >>> # A group is yielded as soon as its size equals ``group_size``
        >>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2)
        >>> list(dp1)
        [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
        >>> # Scenario where the buffer is full, and group 'a' needs to be yielded since its size >= ``guaranteed_group_size``
        >>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2)
        >>> list(dp2)
        [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
    """

    def __init__(
        self,
        datapipe: IterDataPipe[_T_co],
        group_key_fn: Callable[[_T_co], Any],
        *,
        keep_key: bool = False,
        buffer_size: int = 10000,
        group_size: Optional[int] = None,
        guaranteed_group_size: Optional[int] = None,
        drop_remaining: bool = False,
    ):
        _check_unpickable_fn(group_key_fn)
        self.datapipe = datapipe
        self.group_key_fn = group_key_fn

        self.keep_key = keep_key
        self.max_buffer_size = buffer_size
        self.buffer_elements: DefaultDict[Any, List] = defaultdict(list)
        self.curr_buffer_size = 0
        self.group_size = group_size
        self.guaranteed_group_size = None
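        # ``guaranteed_group_size`` defaults to ``group_size`` when unset; an
        # explicit value must satisfy 0 < guaranteed_group_size <= group_size.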
        if group_size is not None and buffer_size is not None:
            assert 0 < group_size <= buffer_size
            self.guaranteed_group_size = group_size
        if guaranteed_group_size is not None:
            assert group_size is not None and 0 < guaranteed_group_size <= group_size
            self.guaranteed_group_size = guaranteed_group_size
        self.drop_remaining = drop_remaining
        self.wrapper_class = DataChunk

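    # Evict the largest group from the buffer. Raises if that group is still
    # smaller than ``guaranteed_group_size`` while ``drop_remaining`` is False;
    # returns ``None`` as the result (silently dropping the group) when it is
    # smaller and ``drop_remaining`` is True.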
    def _remove_biggest_key(self):
        biggest_key = None
        biggest_size = 0
        result_to_yield = None
        for findkey in self.buffer_elements.keys():
            if len(self.buffer_elements[findkey]) > biggest_size:
                biggest_size = len(self.buffer_elements[findkey])
                biggest_key = findkey

        if (
            self.guaranteed_group_size is not None
            and biggest_size < self.guaranteed_group_size
            and not self.drop_remaining
        ):
            raise RuntimeError(
                "Failed to group items", str(self.buffer_elements[biggest_key])
            )

        if (
            self.guaranteed_group_size is None
            or biggest_size >= self.guaranteed_group_size
        ):
            result_to_yield = self.buffer_elements[biggest_key]

        self.curr_buffer_size -= biggest_size
        del self.buffer_elements[biggest_key]

        return biggest_key, result_to_yield

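    # Buffer items per key. A group is flushed as soon as it reaches
    # ``group_size``; when the buffer itself is full, the largest group is
    # evicted to make room. Whatever remains is flushed once the source is
    # exhausted.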
    def __iter__(self):
        for x in self.datapipe:
            key = self.group_key_fn(x)

            self.buffer_elements[key].append(x)
            self.curr_buffer_size += 1

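            # Flush this key's group as soon as it reaches ``group_size``.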
            if self.group_size is not None and self.group_size == len(
                self.buffer_elements[key]
            ):
                result: DataChunk[Any] = self.wrapper_class(self.buffer_elements[key])
                yield (key, result) if self.keep_key else result
                self.curr_buffer_size -= len(self.buffer_elements[key])
                del self.buffer_elements[key]

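            # Buffer full: evict the largest group to make room.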
            if self.curr_buffer_size == self.max_buffer_size:
                biggest_key, result_to_yield = self._remove_biggest_key()
                if result_to_yield is not None:
                    result = self.wrapper_class(result_to_yield)
                    # Pair the evicted group with its own key rather than the
                    # key of the element that triggered the eviction.
                    yield (biggest_key, result) if self.keep_key else result

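        # Source exhausted: flush every remaining group, regardless of size.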
        for key in tuple(self.buffer_elements.keys()):
            result = self.wrapper_class(self.buffer_elements.pop(key))
            self.curr_buffer_size -= len(result)
            yield (key, result) if self.keep_key else result

    def reset(self) -> None:
        self.curr_buffer_size = 0
        self.buffer_elements = defaultdict(list)

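    # Serialization: the in-flight buffer is deliberately excluded from the
    # pickled state; ``__setstate__`` restores an empty buffer, mirroring
    # ``reset``.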
    def __getstate__(self):
        state = (
            self.datapipe,
            self.group_key_fn,
            self.keep_key,
            self.max_buffer_size,
            self.group_size,
            self.guaranteed_group_size,
            self.drop_remaining,
            self.wrapper_class,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        (
            self.datapipe,
            self.group_key_fn,
            self.keep_key,
            self.max_buffer_size,
            self.group_size,
            self.guaranteed_group_size,
            self.drop_remaining,
            self.wrapper_class,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self.curr_buffer_size = 0
        self.buffer_elements = defaultdict(list)

    def __del__(self):
        self.buffer_elements.clear()