xref: /aosp_15_r20/external/pytorch/test/test_datapipe.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3# Owner(s): ["module: dataloader"]
4
5import copy
6import itertools
7import os
8import os.path
9import pickle
10import pydoc
11import random
12import sys
13import tempfile
14import warnings
15from functools import partial
16from typing import (
17    Any,
18    Awaitable,
19    Dict,
20    Generic,
21    Iterator,
22    List,
23    Optional,
24    Set,
25    Tuple,
26    Type,
27    TYPE_CHECKING,
28    TypeVar,
29    Union,
30)
31
32if not TYPE_CHECKING:
33    # pyre isn't treating this the same as a typing.NamedTuple
34    from typing_extensions import NamedTuple
35else:
36    from typing import NamedTuple
37
38import operator
39from unittest import skipIf
40
41import numpy as np
42
43import torch
44import torch.nn as nn
45import torch.utils.data.datapipes as dp
46import torch.utils.data.graph
47import torch.utils.data.graph_settings
48from torch.testing._internal.common_utils import (
49    run_tests,
50    skipIfNoDill,
51    skipIfTorchDynamo,
52    suppress_warnings,
53    TEST_DILL,
54    TestCase,
55)
56from torch.utils._import_utils import import_dill
57from torch.utils.data import (
58    argument_validation,
59    DataChunk,
60    DataLoader,
61    IterDataPipe,
62    MapDataPipe,
63    RandomSampler,
64    runtime_validation,
65    runtime_validation_disabled,
66)
67from torch.utils.data.datapipes.dataframe import (
68    CaptureDataFrame,
69    dataframe_wrapper as df_wrapper,
70)
71from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
72from torch.utils.data.datapipes.utils.common import StreamWrapper
73from torch.utils.data.datapipes.utils.decoder import (
74    basichandlers as decoder_basichandlers,
75)
76from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration
77from torch.utils.data.graph import traverse_dps
78
79dill = import_dill()
80HAS_DILL = TEST_DILL
81
82try:
83    import pandas  # type: ignore[import]  # noqa: F401 F403
84
85    HAS_PANDAS = True
86except ImportError:
87    HAS_PANDAS = False
88skipIfNoDataFrames = skipIf(not HAS_PANDAS, "no dataframes (pandas)")
89
90skipTyping = skipIf(True, "TODO: Fix typing bug")
91T_co = TypeVar("T_co", covariant=True)
92
93
94def create_temp_dir_and_files():
95    # The temp dir and files within it will be released and deleted in tearDown().
96    # Adding `noqa: P201` to avoid the linter warning about not releasing the dir handle within this function.
97    temp_dir = tempfile.TemporaryDirectory()  # noqa: P201
98    temp_dir_path = temp_dir.name
99    with tempfile.NamedTemporaryFile(
100        dir=temp_dir_path, delete=False, suffix=".txt"
101    ) as f:
102        temp_file1_name = f.name
103    with tempfile.NamedTemporaryFile(
104        dir=temp_dir_path, delete=False, suffix=".byte"
105    ) as f:
106        temp_file2_name = f.name
107    with tempfile.NamedTemporaryFile(
108        dir=temp_dir_path, delete=False, suffix=".empty"
109    ) as f:
110        temp_file3_name = f.name
111
112    with open(temp_file1_name, "w") as f1:
113        f1.write("0123456789abcdef")
114    with open(temp_file2_name, "wb") as f2:
115        f2.write(b"0123456789abcdef")
116
117    temp_sub_dir = tempfile.TemporaryDirectory(dir=temp_dir_path)  # noqa: P201
118    temp_sub_dir_path = temp_sub_dir.name
119    with tempfile.NamedTemporaryFile(
120        dir=temp_sub_dir_path, delete=False, suffix=".txt"
121    ) as f:
122        temp_sub_file1_name = f.name
123    with tempfile.NamedTemporaryFile(
124        dir=temp_sub_dir_path, delete=False, suffix=".byte"
125    ) as f:
126        temp_sub_file2_name = f.name
127
128    with open(temp_sub_file1_name, "w") as f1:
129        f1.write("0123456789abcdef")
130    with open(temp_sub_file2_name, "wb") as f2:
131        f2.write(b"0123456789abcdef")
132
133    return [
134        (temp_dir, temp_file1_name, temp_file2_name, temp_file3_name),
135        (temp_sub_dir, temp_sub_file1_name, temp_sub_file2_name),
136    ]
137
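# Layout of the value returned by create_temp_dir_and_files() (a sketch; the
# actual file names are whatever tempfile generates):
#   [(top-level TemporaryDirectory, <*.txt path>, <*.byte path>, <*.empty path>),
#    (nested TemporaryDirectory, <*.txt path>, <*.byte path>)]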
138
139def reset_after_n_next_calls(
140    datapipe: Union[IterDataPipe[T_co], MapDataPipe[T_co]], n: int
141) -> Tuple[List[T_co], List[T_co]]:
142    """
143    Given a DataPipe and an integer n, iterate the DataPipe for n elements and store the elements into a list.
144    Then, reset the DataPipe and return a tuple of two lists:
145        1. A list of elements yielded before the reset
146        2. A list of all elements of the DataPipe after the reset
147    """
148    it = iter(datapipe)
149    res_before_reset = []
150    for _ in range(n):
151        res_before_reset.append(next(it))
152    return res_before_reset, list(datapipe)
153
154
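# A minimal illustrative sketch (not exercised by the tests below) of the helper
# above, assuming a plain IterableWrapper source: the first list holds the n
# elements read before the reset, the second holds a full pass after the reset.
def _example_reset_after_n_next_calls():
    example_dp = dp.iter.IterableWrapper(range(10))  # hypothetical source
    before, after = reset_after_n_next_calls(example_dp, 3)
    assert before == [0, 1, 2]
    assert after == list(range(10))

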
155def odd_or_even(x: int) -> int:
156    return x % 2
157
158
159class TestDataChunk(TestCase):
160    def setUp(self):
161        self.elements = list(range(10))
162        random.shuffle(self.elements)
163        self.chunk: DataChunk[int] = DataChunk(self.elements)
164
165    def test_getitem(self):
166        for i in range(10):
167            self.assertEqual(self.elements[i], self.chunk[i])
168
169    def test_iter(self):
170        for ele, dc in zip(self.elements, iter(self.chunk)):
171            self.assertEqual(ele, dc)
172
173    def test_len(self):
174        self.assertEqual(len(self.elements), len(self.chunk))
175
176    def test_as_string(self):
177        self.assertEqual(str(self.chunk), str(self.elements))
178
179        batch = [self.elements] * 3
180        chunks: List[DataChunk[int]] = [DataChunk(self.elements)] * 3
181        self.assertEqual(str(batch), str(chunks))
182
183    def test_sort(self):
184        chunk: DataChunk[int] = DataChunk(self.elements)
185        chunk.sort()
186        self.assertTrue(isinstance(chunk, DataChunk))
187        for i, d in enumerate(chunk):
188            self.assertEqual(i, d)
189
190    def test_reverse(self):
191        chunk: DataChunk[int] = DataChunk(self.elements)
192        chunk.reverse()
193        self.assertTrue(isinstance(chunk, DataChunk))
194        for i in range(10):
195            self.assertEqual(chunk[i], self.elements[9 - i])
196
197    def test_random_shuffle(self):
198        elements = list(range(10))
199        chunk: DataChunk[int] = DataChunk(elements)
200
201        rng = random.Random(0)
202        rng.shuffle(chunk)
203
204        rng = random.Random(0)
205        rng.shuffle(elements)
206
207        self.assertEqual(chunk, elements)
208
209
210class TestStreamWrapper(TestCase):
211    class _FakeFD:
212        def __init__(self, filepath):
213            self.filepath = filepath
214            self.opened = False
215            self.closed = False
216
217        def open(self):
218            self.opened = True
219
220        def read(self):
221            if self.opened:
222                return "".join(self)
223            else:
224                raise OSError("Cannot read from un-opened file descriptor")
225
226        def __iter__(self):
227            for i in range(5):
228                yield str(i)
229
230        def close(self):
231            if self.opened:
232                self.opened = False
233                self.closed = True
234
235        def __repr__(self):
236            return "FakeFD"
237
238    def test_dir(self):
239        fd = TestStreamWrapper._FakeFD("")
240        wrap_fd = StreamWrapper(fd)
241
242        s = set(dir(wrap_fd))
243        for api in ["open", "read", "close"]:
244            self.assertTrue(api in s)
245
246    @skipIfTorchDynamo()
247    def test_api(self):
248        fd = TestStreamWrapper._FakeFD("")
249        wrap_fd = StreamWrapper(fd)
250
251        self.assertFalse(fd.opened)
252        self.assertFalse(fd.closed)
253        with self.assertRaisesRegex(IOError, "Cannot read from"):
254            wrap_fd.read()
255
256        wrap_fd.open()
257        self.assertTrue(fd.opened)
258        self.assertEqual("01234", wrap_fd.read())
259
260        del wrap_fd
261        self.assertFalse(fd.opened)
262        self.assertTrue(fd.closed)
263
264    def test_pickle(self):
265        with tempfile.TemporaryFile() as f:
266            with self.assertRaises(TypeError) as ctx1:
267                pickle.dumps(f)
268
269            wrap_f = StreamWrapper(f)
270            with self.assertRaises(TypeError) as ctx2:
271                pickle.dumps(wrap_f)
272
273            # Same exception message when pickling
274            self.assertEqual(str(ctx1.exception), str(ctx2.exception))
275
276        fd = TestStreamWrapper._FakeFD("")
277        wrap_fd = StreamWrapper(fd)
278        _ = pickle.loads(pickle.dumps(wrap_fd))
279
280    def test_repr(self):
281        fd = TestStreamWrapper._FakeFD("")
282        wrap_fd = StreamWrapper(fd)
283        self.assertEqual(str(wrap_fd), "StreamWrapper<FakeFD>")
284
285        with tempfile.TemporaryFile() as f:
286            wrap_f = StreamWrapper(f)
287            self.assertEqual(str(wrap_f), "StreamWrapper<" + str(f) + ">")
288
289
290class TestIterableDataPipeBasic(TestCase):
291    def setUp(self):
292        ret = create_temp_dir_and_files()
293        self.temp_dir = ret[0][0]
294        self.temp_files = ret[0][1:]
295        self.temp_sub_dir = ret[1][0]
296        self.temp_sub_files = ret[1][1:]
297
298    def tearDown(self):
299        try:
300            self.temp_sub_dir.cleanup()
301            self.temp_dir.cleanup()
302        except Exception as e:
303            warnings.warn(
304                f"TestIterableDataPipeBasic was not able to clean up temp dir due to {str(e)}"
305            )
306
307    def test_listdirfiles_iterable_datapipe(self):
308        temp_dir = self.temp_dir.name
309        datapipe: IterDataPipe = dp.iter.FileLister(temp_dir, "")
310
311        count = 0
312        for pathname in datapipe:
313            count = count + 1
314            self.assertTrue(pathname in self.temp_files)
315        self.assertEqual(count, len(self.temp_files))
316
317        count = 0
318        datapipe = dp.iter.FileLister(temp_dir, "", recursive=True)
319        for pathname in datapipe:
320            count = count + 1
321            self.assertTrue(
322                (pathname in self.temp_files) or (pathname in self.temp_sub_files)
323            )
324        self.assertEqual(count, len(self.temp_files) + len(self.temp_sub_files))
325
326        temp_files = self.temp_files
327        datapipe = dp.iter.FileLister([temp_dir, *temp_files])
328        count = 0
329        for pathname in datapipe:
330            count += 1
331            self.assertTrue(pathname in self.temp_files)
332        self.assertEqual(count, 2 * len(self.temp_files))
333
334        # test functional API
335        datapipe = datapipe.list_files()
336        count = 0
337        for pathname in datapipe:
338            count += 1
339            self.assertTrue(pathname in self.temp_files)
340        self.assertEqual(count, 2 * len(self.temp_files))
341
342    def test_listdirfilesdeterministic_iterable_datapipe(self):
343        temp_dir = self.temp_dir.name
344
345        datapipe = dp.iter.FileLister(temp_dir, "")
346        # The output order should always be the same.
347        self.assertEqual(list(datapipe), list(datapipe))
348
349        datapipe = dp.iter.FileLister(temp_dir, "", recursive=True)
350        # The output order should always be the same.
351        self.assertEqual(list(datapipe), list(datapipe))
352
353    def test_openfilesfromdisk_iterable_datapipe(self):
354        # test import datapipe class directly
355        from torch.utils.data.datapipes.iter import FileLister, FileOpener
356
357        temp_dir = self.temp_dir.name
358        datapipe1 = FileLister(temp_dir, "")
359        datapipe2 = FileOpener(datapipe1, mode="b")
360
361        count = 0
362        for rec in datapipe2:
363            count = count + 1
364            self.assertTrue(rec[0] in self.temp_files)
365            with open(rec[0], "rb") as f:
366                self.assertEqual(rec[1].read(), f.read())
367                rec[1].close()
368        self.assertEqual(count, len(self.temp_files))
369
370        # functional API
371        datapipe3 = datapipe1.open_files(mode="b")
372
373        count = 0
374        for rec in datapipe3:
375            count = count + 1
376            self.assertTrue(rec[0] in self.temp_files)
377            with open(rec[0], "rb") as f:
378                self.assertEqual(rec[1].read(), f.read())
379                rec[1].close()
380        self.assertEqual(count, len(self.temp_files))
381
382        # __len__ Test
383        with self.assertRaises(TypeError):
384            len(datapipe3)
385
386    def test_routeddecoder_iterable_datapipe(self):
387        temp_dir = self.temp_dir.name
388        temp_pngfile_pathname = os.path.join(temp_dir, "test_png.png")
389        png_data = np.array(
390            [[[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]], [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]],
391            dtype=np.single,
392        )
393        np.save(temp_pngfile_pathname, png_data)
394        datapipe1 = dp.iter.FileLister(temp_dir, ["*.png", "*.txt"])
395        datapipe2 = dp.iter.FileOpener(datapipe1, mode="b")
396
397        def _png_decoder(extension, data):
398            if extension != "png":
399                return None
400            return np.load(data)
401
402        def _helper(prior_dp, dp, channel_first=False):
403            # Byte stream is not closed
404            for inp in prior_dp:
405                self.assertFalse(inp[1].closed)
406            for inp, rec in zip(prior_dp, dp):
407                ext = os.path.splitext(rec[0])[1]
408                if ext == ".png":
409                    expected = np.array(
410                        [
411                            [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
412                            [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
413                        ],
414                        dtype=np.single,
415                    )
416                    if channel_first:
417                        expected = expected.transpose(2, 0, 1)
418                    self.assertEqual(rec[1], expected)
419                else:
420                    with open(rec[0], "rb") as f:
421                        self.assertEqual(rec[1], f.read().decode("utf-8"))
422                # Corresponding byte stream is closed by Decoder
423                self.assertTrue(inp[1].closed)
424
425        cached = list(datapipe2)
426        with warnings.catch_warnings(record=True) as wa:
427            datapipe3 = dp.iter.RoutedDecoder(cached, _png_decoder)
428        datapipe3.add_handler(decoder_basichandlers)
429        _helper(cached, datapipe3)
430
431        cached = list(datapipe2)
432        with warnings.catch_warnings(record=True) as wa:
433            datapipe4 = dp.iter.RoutedDecoder(cached, decoder_basichandlers)
434        datapipe4.add_handler(_png_decoder)
435        _helper(cached, datapipe4, channel_first=True)
436
437    def test_groupby_iterable_datapipe(self):
438        file_list = [
439            "a.png",
440            "b.png",
441            "c.json",
442            "a.json",
443            "c.png",
444            "b.json",
445            "d.png",
446            "d.json",
447            "e.png",
448            "f.json",
449            "g.png",
450            "f.png",
451            "g.json",
452            "e.json",
453            "h.txt",
454            "h.json",
455        ]
456
457        import io
458
459        datapipe1 = dp.iter.IterableWrapper(
460            [(filename, io.BytesIO(b"12345abcde")) for filename in file_list]
461        )
462
463        def group_fn(data):
464            filepath, _ = data
465            return os.path.basename(filepath).split(".")[0]
466
467        datapipe2 = dp.iter.Grouper(datapipe1, group_key_fn=group_fn, group_size=2)
468
469        def order_fn(data):
470            data.sort(key=lambda f: f[0], reverse=True)
471            return data
472
473        datapipe3 = dp.iter.Mapper(datapipe2, fn=order_fn)  # type: ignore[var-annotated]
474
475        expected_result = [
476            ("a.png", "a.json"),
477            ("c.png", "c.json"),
478            ("b.png", "b.json"),
479            ("d.png", "d.json"),
480            ("f.png", "f.json"),
481            ("g.png", "g.json"),
482            ("e.png", "e.json"),
483            ("h.txt", "h.json"),
484        ]
485
486        count = 0
487        for rec, expected in zip(datapipe3, expected_result):
488            count = count + 1
489            self.assertEqual(os.path.basename(rec[0][0]), expected[0])
490            self.assertEqual(os.path.basename(rec[1][0]), expected[1])
491            for i in [0, 1]:
492                self.assertEqual(rec[i][1].read(), b"12345abcde")
493                rec[i][1].close()
494        self.assertEqual(count, 8)
495
496        # testing the keep_key option
497        datapipe4 = dp.iter.Grouper(
498            datapipe1, group_key_fn=group_fn, keep_key=True, group_size=2
499        )
500
501        def order_fn(data):
502            data[1].sort(key=lambda f: f[0], reverse=True)
503            return data
504
505        datapipe5 = dp.iter.Mapper(datapipe4, fn=order_fn)  # type: ignore[var-annotated]
506
507        expected_result = [
508            ("a", ("a.png", "a.json")),
509            ("c", ("c.png", "c.json")),
510            ("b", ("b.png", "b.json")),
511            ("d", ("d.png", "d.json")),
512            ("f", ("f.png", "f.json")),
513            ("g", ("g.png", "g.json")),
514            ("e", ("e.png", "e.json")),
515            ("h", ("h.txt", "h.json")),
516        ]
517
518        count = 0
519        for rec, expected in zip(datapipe5, expected_result):
520            count = count + 1
521            self.assertEqual(rec[0], expected[0])
522            self.assertEqual(rec[1][0][0], expected[1][0])
523            self.assertEqual(rec[1][1][0], expected[1][1])
524            for i in [0, 1]:
525                self.assertEqual(rec[1][i][1].read(), b"12345abcde")
526                rec[1][i][1].close()
527        self.assertEqual(count, 8)
528
529    def test_demux_mux_datapipe(self):
530        numbers = NumbersDataset(10)
531        n1, n2 = numbers.demux(2, lambda x: x % 2)
532        self.assertEqual([0, 2, 4, 6, 8], list(n1))
533        self.assertEqual([1, 3, 5, 7, 9], list(n2))
534
535        # Functional Test: demux and mux work sequentially as expected
536        numbers = NumbersDataset(10)
537        n1, n2, n3 = numbers.demux(3, lambda x: x % 3)
538        n = n1.mux(n2, n3)
539        self.assertEqual(list(range(9)), list(n))
540
541        # Functional Test: Uneven DataPipes
542        source_numbers = list(range(0, 10)) + [10, 12]
543        numbers_dp = dp.iter.IterableWrapper(source_numbers)
544        n1, n2 = numbers_dp.demux(2, lambda x: x % 2)
545        self.assertEqual([0, 2, 4, 6, 8, 10, 12], list(n1))
546        self.assertEqual([1, 3, 5, 7, 9], list(n2))
547        n = n1.mux(n2)
548        self.assertEqual(list(range(10)), list(n))
549
550    @suppress_warnings  # Suppress warning for lambda fn
551    def test_map_with_col_file_handle_datapipe(self):
552        temp_dir = self.temp_dir.name
553        datapipe1 = dp.iter.FileLister(temp_dir, "")
554        datapipe2 = dp.iter.FileOpener(datapipe1)
555
556        def _helper(datapipe):
557            dp1 = datapipe.map(lambda x: x.read(), input_col=1)
558            dp2 = datapipe.map(lambda x: (x[0], x[1].read()))
559            self.assertEqual(list(dp1), list(dp2))
560
561        # tuple
562        _helper(datapipe2)
563        # list
564        datapipe3 = datapipe2.map(lambda x: list(x))
565        _helper(datapipe3)
566
567
568@skipIfNoDataFrames
569class TestCaptureDataFrame(TestCase):
570    def get_new_df(self):
571        return df_wrapper.create_dataframe([[1, 2]], columns=["a", "b"])
572
573    def compare_capture_and_eager(self, operations):
574        cdf = CaptureDataFrame()
575        cdf = operations(cdf)
576        df = self.get_new_df()
577        cdf = cdf.apply_ops(df)
578
579        df = self.get_new_df()
580        df = operations(df)
581
582        self.assertTrue(df.equals(cdf))
583
584    def test_basic_capture(self):
585        def operations(df):
586            df["c"] = df.b + df["a"] * 7
587            # somehow swallows pandas UserWarning when `df.c = df.b + df['a'] * 7`
588            return df
589
590        self.compare_capture_and_eager(operations)
591
592
593class TestDataFramesPipes(TestCase):
594    """
595    Most of these tests will fail if pandas is installed but dill is not available.
596    Need to rework them to avoid multiple skips.
597    """
598
599    def _get_datapipe(self, range=10, dataframe_size=7):
600        return NumbersDataset(range).map(lambda i: (i, i % 3))
601
602    def _get_dataframes_pipe(self, range=10, dataframe_size=7):
603        return (
604            NumbersDataset(range)
605            .map(lambda i: (i, i % 3))
606            ._to_dataframes_pipe(columns=["i", "j"], dataframe_size=dataframe_size)
607        )
608
609    @skipIfNoDataFrames
610    @skipIfNoDill  # TODO(VitalyFedyunin): Decouple tests from dill by avoiding lambdas in map
611    def test_capture(self):
612        dp_numbers = self._get_datapipe().map(lambda x: (x[0], x[1], x[1] + 3 * x[0]))
613        df_numbers = self._get_dataframes_pipe()
614        df_numbers["k"] = df_numbers["j"] + df_numbers.i * 3
615        expected = list(dp_numbers)
616        actual = list(df_numbers)
617        self.assertEqual(expected, actual)
618
619    @skipIfNoDataFrames
620    @skipIfNoDill
621    def test_shuffle(self):
622        #  With non-zero (but extremely low) probability (when the shuffle happens
623        #  to leave the order unchanged), this test fails, so feel free to restart it.
624        df_numbers = self._get_dataframes_pipe(range=1000).shuffle()
625        dp_numbers = self._get_datapipe(range=1000)
626        df_result = [tuple(item) for item in df_numbers]
627        self.assertNotEqual(list(dp_numbers), df_result)
628        self.assertEqual(list(dp_numbers), sorted(df_result))
629
630    @skipIfNoDataFrames
631    @skipIfNoDill
632    def test_batch(self):
633        df_numbers = self._get_dataframes_pipe(range=100).batch(8)
634        df_numbers_list = list(df_numbers)
635        last_batch = df_numbers_list[-1]
636        self.assertEqual(4, len(last_batch))
637        unpacked_batch = [tuple(row) for row in last_batch]
638        self.assertEqual([(96, 0), (97, 1), (98, 2), (99, 0)], unpacked_batch)
639
640    @skipIfNoDataFrames
641    @skipIfNoDill
642    def test_unbatch(self):
643        df_numbers = self._get_dataframes_pipe(range=100).batch(8).batch(3)
644        dp_numbers = self._get_datapipe(range=100)
645        self.assertEqual(list(dp_numbers), list(df_numbers.unbatch(2)))
646
647    @skipIfNoDataFrames
648    @skipIfNoDill
649    def test_filter(self):
650        df_numbers = self._get_dataframes_pipe(range=10).filter(lambda x: x.i > 5)
651        actual = list(df_numbers)
652        self.assertEqual([(6, 0), (7, 1), (8, 2), (9, 0)], actual)
653
654    @skipIfNoDataFrames
655    @skipIfNoDill
656    def test_collate(self):
657        def collate_i(column):
658            return column.sum()
659
660        def collate_j(column):
661            return column.prod()
662
663        df_numbers = self._get_dataframes_pipe(range=30).batch(3)
664        df_numbers = df_numbers.collate({"j": collate_j, "i": collate_i})
665
666        expected_i = [
667            3,
668            12,
669            21,
670            30,
671            39,
672            48,
673            57,
674            66,
675            75,
676            84,
677        ]
678
679        actual_i = []
680        for i, j in df_numbers:
681            actual_i.append(i)
682        self.assertEqual(expected_i, actual_i)
683
684        actual_i = []
685        for item in df_numbers:
686            actual_i.append(item.i)
687        self.assertEqual(expected_i, actual_i)
688
689
690class IDP_NoLen(IterDataPipe):
691    def __init__(self, input_dp):
692        super().__init__()
693        self.input_dp = input_dp
694
695    # Prevent in-place modification
696    def __iter__(self):
697        input_dp = (
698            self.input_dp
699            if isinstance(self.input_dp, IterDataPipe)
700            else copy.deepcopy(self.input_dp)
701        )
702        yield from input_dp
703
704
705def _fake_fn(data):
706    return data
707
708
709def _fake_add(constant, data):
710    return constant + data
711
712
713def _fake_filter_fn(data):
714    return True
715
716
717def _simple_filter_fn(data):
718    return data >= 5
719
720
721def _fake_filter_fn_constant(constant, data):
722    return data >= constant
723
724
725def _mul_10(x):
726    return x * 10
727
728
729def _mod_3_test(x):
730    return x % 3 == 1
731
732
733def _to_list(x):
734    return [x]
735
736
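# Module-level lambdas used by the serialization tests below: with dill they are
# expected to serialize, and without dill they trigger the "Lambda function is
# not supported by pickle" warning path.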
737lambda_fn1 = lambda x: x  # noqa: E731
738lambda_fn2 = lambda x: x % 2  # noqa: E731
739lambda_fn3 = lambda x: x >= 5  # noqa: E731
740
741
742class Add1Module(nn.Module):
743    def forward(self, x):
744        return x + 1
745
746
747class Add1Callable:
748    def __call__(self, x):
749        return x + 1
750
751
752class TestFunctionalIterDataPipe(TestCase):
753    def _serialization_test_helper(self, datapipe, use_dill):
754        if use_dill:
755            serialized_dp = dill.dumps(datapipe)
756            deserialized_dp = dill.loads(serialized_dp)
757        else:
758            serialized_dp = pickle.dumps(datapipe)
759            deserialized_dp = pickle.loads(serialized_dp)
760        try:
761            self.assertEqual(list(datapipe), list(deserialized_dp))
762        except AssertionError as e:
763            print(f"{datapipe} is failing.")
764            raise e
765
766    def _serialization_test_for_single_dp(self, dp, use_dill=False):
767        # 1. Testing for serialization before any iteration starts
768        self._serialization_test_helper(dp, use_dill)
769        # 2. Testing for serialization after DataPipe is partially read
770        it = iter(dp)
771        _ = next(it)
772        self._serialization_test_helper(dp, use_dill)
773        # 3. Testing for serialization after DataPipe is fully read
774        it = iter(dp)
775        _ = list(it)
776        self._serialization_test_helper(dp, use_dill)
777
778    def _serialization_test_for_dp_with_children(self, dp1, dp2, use_dill=False):
779        # 1. Testing for serialization before any iteration starts
780        self._serialization_test_helper(dp1, use_dill)
781        self._serialization_test_helper(dp2, use_dill)
782
783        # 2. Testing for serialization after DataPipe is partially read
784        it1, it2 = iter(dp1), iter(dp2)
785        _, _ = next(it1), next(it2)
786        # Catch `fork`, `demux` "some child DataPipes are not exhausted" warning
787        with warnings.catch_warnings(record=True) as wa:
788            self._serialization_test_helper(dp1, use_dill)
789            self._serialization_test_helper(dp2, use_dill)
790
791        # 2.5. Testing for serialization after one child DataPipe is fully read
792        #      (Only for DataPipes with children DataPipes)
793        it1 = iter(dp1)
794        _ = list(it1)  # fully read one child
795        # Catch `fork`, `demux` "some child DataPipes are not exhausted" warning
796        with warnings.catch_warnings(record=True) as wa:
797            self._serialization_test_helper(dp1, use_dill)
798            self._serialization_test_helper(dp2, use_dill)
799
800        # 3. Testing for serialization after DataPipe is fully read
801        it2 = iter(dp2)
802        _ = list(it2)  # fully read the other child
803        self._serialization_test_helper(dp1, use_dill)
804        self._serialization_test_helper(dp2, use_dill)
805
806    def test_serializable(self):
807        picklable_datapipes: List = [
808            (
809                dp.iter.Batcher,
810                None,
811                (
812                    3,
813                    True,
814                ),
815                {},
816            ),
817            (dp.iter.Collator, None, (_fake_fn,), {}),
818            (dp.iter.Concater, None, (dp.iter.IterableWrapper(range(5)),), {}),
819            (dp.iter.Demultiplexer, None, (2, _simple_filter_fn), {}),
820            (dp.iter.FileLister, ".", (), {}),
821            (dp.iter.FileOpener, None, (), {}),
822            (dp.iter.Filter, None, (_fake_filter_fn,), {}),
823            (dp.iter.Filter, None, (partial(_fake_filter_fn_constant, 5),), {}),
824            (dp.iter.Forker, None, (2,), {}),
825            (dp.iter.Forker, None, (2,), {"copy": "shallow"}),
826            (dp.iter.Grouper, None, (_fake_filter_fn,), {"group_size": 2}),
827            (dp.iter.IterableWrapper, range(10), (), {}),
828            (dp.iter.Mapper, None, (_fake_fn,), {}),
829            (dp.iter.Mapper, None, (partial(_fake_add, 1),), {}),
830            (dp.iter.Multiplexer, None, (dp.iter.IterableWrapper(range(10)),), {}),
831            (dp.iter.Sampler, None, (), {}),
832            (dp.iter.Shuffler, dp.iter.IterableWrapper([0] * 10), (), {}),
833            (dp.iter.StreamReader, None, (), {}),
834            (dp.iter.UnBatcher, None, (0,), {}),
835            (dp.iter.Zipper, None, (dp.iter.IterableWrapper(range(10)),), {}),
836        ]
837        # Skipping comparison for these DataPipes
838        dp_skip_comparison = {dp.iter.FileOpener, dp.iter.StreamReader}
839        # These DataPipes produce multiple DataPipes as outputs and those should be compared
840        dp_compare_children = {dp.iter.Demultiplexer, dp.iter.Forker}
841
842        for dpipe, custom_input, dp_args, dp_kwargs in picklable_datapipes:
843            if custom_input is None:
844                custom_input = dp.iter.IterableWrapper(range(10))
845            if (
846                dpipe in dp_skip_comparison
847            ):  # Merely make sure they are picklable and loadable (no value comparison)
848                datapipe = dpipe(custom_input, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
849                serialized_dp = pickle.dumps(datapipe)
850                _ = pickle.loads(serialized_dp)
851            elif dpipe in dp_compare_children:  # DataPipes that have children
852                dp1, dp2 = dpipe(custom_input, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
853                self._serialization_test_for_dp_with_children(dp1, dp2)
854            else:  # Single DataPipe that requires comparison
855                datapipe = dpipe(custom_input, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
856                self._serialization_test_for_single_dp(datapipe)
857
858    @skipIfTorchDynamo("Dict with function as keys")
859    def test_serializable_with_dill(self):
860        """Only for DataPipes that take in a function as an argument"""
861        input_dp = dp.iter.IterableWrapper(range(10))
862
863        datapipes_with_lambda_fn: List[
864            Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]
865        ] = [
866            (dp.iter.Collator, (lambda_fn1,), {}),
867            (
868                dp.iter.Demultiplexer,
869                (
870                    2,
871                    lambda_fn2,
872                ),
873                {},
874            ),
875            (dp.iter.Filter, (lambda_fn3,), {}),
876            (dp.iter.Grouper, (lambda_fn3,), {}),
877            (dp.iter.Mapper, (lambda_fn1,), {}),
878        ]
879
880        def _local_fns():
881            def _fn1(x):
882                return x
883
884            def _fn2(x):
885                return x % 2
886
887            def _fn3(x):
888                return x >= 5
889
890            return _fn1, _fn2, _fn3
891
892        fn1, fn2, fn3 = _local_fns()
893
894        datapipes_with_local_fn: List[
895            Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]
896        ] = [
897            (dp.iter.Collator, (fn1,), {}),
898            (
899                dp.iter.Demultiplexer,
900                (
901                    2,
902                    fn2,
903                ),
904                {},
905            ),
906            (dp.iter.Filter, (fn3,), {}),
907            (dp.iter.Grouper, (fn3,), {}),
908            (dp.iter.Mapper, (fn1,), {}),
909        ]
910
911        dp_compare_children = {dp.iter.Demultiplexer}
912
913        if HAS_DILL:
914            for dpipe, dp_args, dp_kwargs in (
915                datapipes_with_lambda_fn + datapipes_with_local_fn
916            ):
917                if dpipe in dp_compare_children:
918                    dp1, dp2 = dpipe(input_dp, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
919                    self._serialization_test_for_dp_with_children(
920                        dp1, dp2, use_dill=True
921                    )
922                else:
923                    datapipe = dpipe(input_dp, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
924                    self._serialization_test_for_single_dp(datapipe, use_dill=True)
925        else:
926            msgs = (
927                r"^Lambda function is not supported by pickle",
928                r"^Local function is not supported by pickle",
929            )
930            for dps, msg in zip(
931                (datapipes_with_lambda_fn, datapipes_with_local_fn), msgs
932            ):
933                for dpipe, dp_args, dp_kwargs in dps:
934                    with self.assertWarnsRegex(UserWarning, msg):
935                        datapipe = dpipe(input_dp, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
936                    with self.assertRaises((pickle.PicklingError, AttributeError)):
937                        pickle.dumps(datapipe)
938
939    def test_docstring(self):
940        """
941        Ensure functional form of IterDataPipe has the correct docstring from
942        the class form.
943
944        Regression test for https://github.com/pytorch/data/issues/792.
945        """
946        input_dp = dp.iter.IterableWrapper(range(10))
947
948        for dp_funcname in [
949            "batch",
950            "collate",
951            "concat",
952            "demux",
953            "filter",
954            "fork",
955            "map",
956            "mux",
957            "read_from_stream",
958            # "sampler",
959            "shuffle",
960            "unbatch",
961            "zip",
962        ]:
963            if sys.version_info >= (3, 9):
964                docstring = pydoc.render_doc(
965                    thing=getattr(input_dp, dp_funcname), forceload=True
966                )
967            elif sys.version_info < (3, 9):
968                # pydoc works differently on Python 3.8, see
969                # https://docs.python.org/3/whatsnew/3.9.html#pydoc
970                docstring = getattr(input_dp, dp_funcname).__doc__
971
972            assert f"(functional name: ``{dp_funcname}``)" in docstring
973            assert "Args:" in docstring
974            assert "Example:" in docstring or "Examples:" in docstring
975
976    def test_iterable_wrapper_datapipe(self):
977        input_ls = list(range(10))
978        input_dp = dp.iter.IterableWrapper(input_ls)
979
980        # Functional Test: values are unchanged and in the same order
981        self.assertEqual(input_ls, list(input_dp))
982
983        # Functional Test: deep copy by default when an iterator is initialized (first element is read)
984        it = iter(input_dp)
985        self.assertEqual(
986            0, next(it)
987        )  # The deep copy only happens when the first element is read
988        input_ls.append(50)
989        self.assertEqual(list(range(1, 10)), list(it))
990
991        # Functional Test: shallow copy
992        input_ls2 = [1, 2, 3]
993        input_dp_shallow = dp.iter.IterableWrapper(input_ls2, deepcopy=False)
994        input_ls2.append(10)
995        self.assertEqual([1, 2, 3, 10], list(input_dp_shallow))
996
997        # Reset Test: reset the DataPipe
998        input_ls = list(range(10))
999        input_dp = dp.iter.IterableWrapper(input_ls)
1000        n_elements_before_reset = 5
1001        res_before_reset, res_after_reset = reset_after_n_next_calls(
1002            input_dp, n_elements_before_reset
1003        )
1004        self.assertEqual(input_ls[:n_elements_before_reset], res_before_reset)
1005        self.assertEqual(input_ls, res_after_reset)
1006
1007        # __len__ Test: inherits length from sequence
1008        self.assertEqual(len(input_ls), len(input_dp))
1009
1010    def test_concat_iterdatapipe(self):
1011        input_dp1 = dp.iter.IterableWrapper(range(10))
1012        input_dp2 = dp.iter.IterableWrapper(range(5))
1013
1014        # Functional Test: Raises exception for empty input
1015        with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
1016            dp.iter.Concater()
1017
1018        # Functional Test: Raises exception for non-IterDataPipe input
1019        with self.assertRaisesRegex(
1020            TypeError, r"Expected all inputs to be `IterDataPipe`"
1021        ):
1022            dp.iter.Concater(input_dp1, ())  # type: ignore[arg-type]
1023
1024        # Functional Test: Concatenate DataPipes as expected
1025        concat_dp = input_dp1.concat(input_dp2)
1026        self.assertEqual(len(concat_dp), 15)
1027        self.assertEqual(list(concat_dp), list(range(10)) + list(range(5)))
1028
1029        # Reset Test: reset the DataPipe
1030        n_elements_before_reset = 5
1031        res_before_reset, res_after_reset = reset_after_n_next_calls(
1032            concat_dp, n_elements_before_reset
1033        )
1034        self.assertEqual(list(range(5)), res_before_reset)
1035        self.assertEqual(list(range(10)) + list(range(5)), res_after_reset)
1036
1037        # __len__ Test: inherits length from source DataPipe
1038        input_dp_nl = IDP_NoLen(range(5))
1039        concat_dp = input_dp1.concat(input_dp_nl)
1040        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
1041            len(concat_dp)
1042
1043        self.assertEqual(list(concat_dp), list(range(10)) + list(range(5)))
1044
1045    def test_fork_iterdatapipe(self):
1046        input_dp = dp.iter.IterableWrapper(range(10))
1047
1048        with self.assertRaises(ValueError):
1049            input_dp.fork(num_instances=0)
1050
1051        dp0 = input_dp.fork(num_instances=1, buffer_size=0)
1052        self.assertEqual(dp0, input_dp)
1053
1054        # Functional Test: making sure all child DataPipes share the same reference
1055        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
1056        self.assertTrue(all(n1 is n2 and n1 is n3 for n1, n2, n3 in zip(dp1, dp2, dp3)))
1057
1058        # Functional Test: each child DataPipe yields all of its values before the next one is read
1059        output1, output2, output3 = list(dp1), list(dp2), list(dp3)
1060        self.assertEqual(list(range(10)), output1)
1061        self.assertEqual(list(range(10)), output2)
1062        self.assertEqual(list(range(10)), output3)
1063
1064        # Functional Test: two child DataPipes yield values together
1065        dp1, dp2 = input_dp.fork(num_instances=2)
1066        output = []
1067        for n1, n2 in zip(dp1, dp2):
1068            output.append((n1, n2))
1069        self.assertEqual([(i, i) for i in range(10)], output)
1070
1071        # Functional Test: one child DataPipe tries to yield all values first, but the buffer size is too small
1072        dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=4)
1073        it1 = iter(dp1)
1074        for _ in range(4):
1075            next(it1)
1076        with self.assertRaises(BufferError):
1077            next(it1)
1078        with self.assertRaises(BufferError):
1079            list(dp2)
1080
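        # Even buffer_size=5 is too small here: dp1 consumes nothing, so fully
        # reading dp2 would require buffering all 10 elements at once.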
1081        dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=5)
1082        with self.assertRaises(BufferError):
1083            list(dp2)
1084
1085        # Functional Test: one child DataPipe yields all values first with an unlimited buffer
1086        with warnings.catch_warnings(record=True) as wa:
1087            dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=-1)
1088            self.assertEqual(len(wa), 1)
1089            self.assertRegex(str(wa[0].message), r"Unlimited buffer size is set")
1090        l1, l2 = list(dp1), list(dp2)
1091        for d1, d2 in zip(l1, l2):
1092            self.assertEqual(d1, d2)
1093
1094        # Functional Test: two child DataPipes yield value together with buffer size 1
1095        dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=1)
1096        output = []
1097        for n1, n2 in zip(dp1, dp2):
1098            output.append((n1, n2))
1099        self.assertEqual([(i, i) for i in range(10)], output)
1100
1101        # Functional Test: two child DataPipes yield shallow copies with copy equals shallow
1102        dp1, dp2 = input_dp.map(_to_list).fork(num_instances=2, copy="shallow")
1103        for n1, n2 in zip(dp1, dp2):
1104            self.assertIsNot(n1, n2)
1105            self.assertEqual(n1, n2)
1106
1107        # Functional Test: two child DataPipes yield deep copies with copy equals deep
1108        dp1, dp2 = (
1109            input_dp.map(_to_list).map(_to_list).fork(num_instances=2, copy="deep")
1110        )
1111        for n1, n2 in zip(dp1, dp2):
1112            self.assertIsNot(n1[0], n2[0])
1113            self.assertEqual(n1, n2)
1114
1115        # Functional Test: fork DataPipe raises error for unknown copy method
1116        with self.assertRaises(ValueError):
1117            input_dp.fork(num_instances=2, copy="unknown")
1118
1119        # Functional Test: make sure logic related to slowest_ptr is working properly
1120        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
1121        output1, output2, output3 = [], [], []
1122        for i, (n1, n2) in enumerate(zip(dp1, dp2)):
1123            output1.append(n1)
1124            output2.append(n2)
1125            if i == 4:  # yield all of dp3 when halfway through dp1, dp2
1126                output3 = list(dp3)
1127                break
1128        self.assertEqual(list(range(5)), output1)
1129        self.assertEqual(list(range(5)), output2)
1130        self.assertEqual(list(range(10)), output3)
1131
1132        # Reset Test: DataPipe resets when a new iterator is created, even if this datapipe hasn't been read
1133        dp1, dp2 = input_dp.fork(num_instances=2)
1134        _ = iter(dp1)
1135        output2 = []
1136        with self.assertRaisesRegex(RuntimeError, r"iterator has been invalidated"):
1137            for i, n2 in enumerate(dp2):
1138                output2.append(n2)
1139                if i == 4:
1140                    with warnings.catch_warnings(record=True) as wa:
1141                        _ = iter(dp1)  # This will reset all child DataPipes
1142                        self.assertEqual(len(wa), 1)
1143                        self.assertRegex(
1144                            str(wa[0].message), r"child DataPipes are not exhausted"
1145                        )
1146        self.assertEqual(list(range(5)), output2)
1147
1148        # Reset Test: DataPipe resets when some of it has been read
1149        dp1, dp2 = input_dp.fork(num_instances=2)
1150        output1, output2 = [], []
1151        for i, (n1, n2) in enumerate(zip(dp1, dp2)):
1152            output1.append(n1)
1153            output2.append(n2)
1154            if i == 4:
1155                with warnings.catch_warnings(record=True) as wa:
1156                    _ = iter(dp1)  # Reset all child DataPipes
1157                    self.assertEqual(len(wa), 1)
1158                    self.assertRegex(
1159                        str(wa[0].message), r"Some child DataPipes are not exhausted"
1160                    )
1161                break
1162        with warnings.catch_warnings(record=True) as wa:
1163            for i, (n1, n2) in enumerate(zip(dp1, dp2)):
1164                output1.append(n1)
1165                output2.append(n2)
1166            self.assertEqual(len(wa), 1)
1167            self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")
1168        self.assertEqual(list(range(5)) + list(range(10)), output1)
1169        self.assertEqual(list(range(5)) + list(range(10)), output2)
1170
1171        # Reset Test: DataPipe resets, even when some other child DataPipes have not been read
1172        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
1173        output1, output2 = list(dp1), list(dp2)
1174        self.assertEqual(list(range(10)), output1)
1175        self.assertEqual(list(range(10)), output2)
1176        with warnings.catch_warnings(record=True) as wa:
1177            self.assertEqual(
1178                list(range(10)), list(dp1)
1179            )  # Resets even though dp3 has not been read
1180            self.assertEqual(len(wa), 1)
1181            self.assertRegex(
1182                str(wa[0].message), r"Some child DataPipes are not exhausted"
1183            )
1184        output3 = []
1185        for i, n3 in enumerate(dp3):
1186            output3.append(n3)
1187            if i == 4:
1188                with warnings.catch_warnings(record=True) as wa:
1189                    output1 = list(dp1)  # Resets even though dp3 is only partially read
1190                    self.assertEqual(len(wa), 1)
1191                    self.assertRegex(
1192                        str(wa[0].message), r"Some child DataPipes are not exhausted"
1193                    )
1194                self.assertEqual(list(range(5)), output3)
1195                self.assertEqual(list(range(10)), output1)
1196                break
1197        self.assertEqual(
1198            list(range(10)), list(dp3)
1199        )  # dp3 has to read from the start again
1200
1201        # __len__ Test: Each DataPipe inherits the source datapipe's length
1202        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
1203        self.assertEqual(len(input_dp), len(dp1))
1204        self.assertEqual(len(input_dp), len(dp2))
1205        self.assertEqual(len(input_dp), len(dp3))
1206
1207        # Pickle Test:
1208        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
1209        traverse_dps(dp1)  # This should not raise any error
1210        for _ in zip(dp1, dp2, dp3):
1211            pass
1212        traverse_dps(dp2)  # This should not raise any error either
1213
1214    def test_mux_iterdatapipe(self):
1215        # Functional Test: Elements are yielded one at a time from each DataPipe, until they are all exhausted
1216        input_dp1 = dp.iter.IterableWrapper(range(4))
1217        input_dp2 = dp.iter.IterableWrapper(range(4, 8))
1218        input_dp3 = dp.iter.IterableWrapper(range(8, 12))
1219        output_dp = input_dp1.mux(input_dp2, input_dp3)
1220        expected_output = [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
1221        self.assertEqual(len(expected_output), len(output_dp))
1222        self.assertEqual(expected_output, list(output_dp))
1223
1224        # Functional Test: Uneven input Data Pipes
1225        input_dp1 = dp.iter.IterableWrapper([1, 2, 3, 4])
1226        input_dp2 = dp.iter.IterableWrapper([10])
1227        input_dp3 = dp.iter.IterableWrapper([100, 200, 300])
1228        output_dp = input_dp1.mux(input_dp2, input_dp3)
1229        expected_output = [1, 10, 100]
1230        self.assertEqual(len(expected_output), len(output_dp))
1231        self.assertEqual(expected_output, list(output_dp))
1232
1233        # Functional Test: Empty Data Pipe
1234        input_dp1 = dp.iter.IterableWrapper([0, 1, 2, 3])
1235        input_dp2 = dp.iter.IterableWrapper([])
1236        output_dp = input_dp1.mux(input_dp2)
1237        self.assertEqual(len(input_dp2), len(output_dp))
1238        self.assertEqual(list(input_dp2), list(output_dp))
1239
1240        # __len__ Test: raises TypeError when __len__ is called and an input doesn't have __len__
1241        input_dp1 = dp.iter.IterableWrapper(range(10))
1242        input_dp_no_len = IDP_NoLen(range(10))
1243        output_dp = input_dp1.mux(input_dp_no_len)
1244        with self.assertRaises(TypeError):
1245            len(output_dp)
1246
1247    def test_demux_iterdatapipe(self):
1248        input_dp = dp.iter.IterableWrapper(range(10))
1249
1250        with self.assertRaises(ValueError):
1251            input_dp.demux(num_instances=0, classifier_fn=lambda x: 0)
1252
1253        # Functional Test: split into 2 DataPipes and output them one at a time
1254        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
1255        output1, output2 = list(dp1), list(dp2)
1256        self.assertEqual(list(range(0, 10, 2)), output1)
1257        self.assertEqual(list(range(1, 10, 2)), output2)
1258
1259        # Functional Test: split into 2 DataPipes and output them together
1260        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
1261        output = []
1262        for n1, n2 in zip(dp1, dp2):
1263            output.append((n1, n2))
1264        self.assertEqual([(i, i + 1) for i in range(0, 10, 2)], output)
1265
1266        # Functional Test: values of the same classification are lumped together, and buffer_size = 4 is too small
1267        dp1, dp2 = input_dp.demux(
1268            num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=4
1269        )
1270        it1 = iter(dp1)
1271        with self.assertRaises(BufferError):
1272            next(
1273                it1
1274            )  # Buffer raises because the first 5 elements all belong to the other child
1275        with self.assertRaises(BufferError):
1276            list(dp2)
1277
1278        # Functional Test: values of the same classification are lumped together, and buffer_size = 5 is just enough
1279        dp1, dp2 = input_dp.demux(
1280            num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=5
1281        )
1282        output1, output2 = list(dp1), list(dp2)
1283        self.assertEqual(list(range(5, 10)), output1)
1284        self.assertEqual(list(range(0, 5)), output2)
1285
1286        # Functional Test: values of the same classification are lumped together, and unlimited buffer
1287        with warnings.catch_warnings(record=True) as wa:
1288            dp1, dp2 = input_dp.demux(
1289                num_instances=2,
1290                classifier_fn=lambda x: 0 if x >= 5 else 1,
1291                buffer_size=-1,
1292            )
1293            exp_l = 1 if HAS_DILL else 2
1294            self.assertEqual(len(wa), exp_l)
1295            self.assertRegex(str(wa[-1].message), r"Unlimited buffer size is set")
1296        output1, output2 = list(dp1), list(dp2)
1297        self.assertEqual(list(range(5, 10)), output1)
1298        self.assertEqual(list(range(0, 5)), output2)
1299
1300        # Functional Test: classifier returns a value outside of [0, num_instances - 1]
1301        dp0 = input_dp.demux(num_instances=1, classifier_fn=lambda x: x % 2)
1302        it = iter(dp0[0])
1303        with self.assertRaises(ValueError):
1304            next(it)
1305            next(it)
1306
1307        # Reset Test: DataPipe resets when a new iterator is created, even if this datapipe hasn't been read
1308        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
1309        _ = iter(dp1)
1310        output2 = []
1311        with self.assertRaisesRegex(RuntimeError, r"iterator has been invalidated"):
1312            for i, n2 in enumerate(dp2):
1313                output2.append(n2)
1314                if i == 4:
1315                    with warnings.catch_warnings(record=True) as wa:
1316                        _ = iter(dp1)  # This will reset all child DataPipes
1317                        self.assertEqual(len(wa), 1)
1318                        self.assertRegex(
1319                            str(wa[0].message), r"child DataPipes are not exhausted"
1320                        )
1321        self.assertEqual(list(range(1, 10, 2)), output2)
1322
1323        # Reset Test: DataPipe resets when some of it has been read
1324        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
1325        output1, output2 = [], []
1326        for n1, n2 in zip(dp1, dp2):
1327            output1.append(n1)
1328            output2.append(n2)
1329            if n1 == 4:
1330                break
1331        with warnings.catch_warnings(record=True) as wa:
1332            i1 = iter(dp1)  # Reset all child DataPipes
1333            self.assertEqual(len(wa), 1)
1334            self.assertRegex(
1335                str(wa[0].message), r"Some child DataPipes are not exhausted"
1336            )
1337            for n1, n2 in zip(dp1, dp2):
1338                output1.append(n1)
1339                output2.append(n2)
1340            self.assertEqual([0, 2, 4] + list(range(0, 10, 2)), output1)
1341            self.assertEqual([1, 3, 5] + list(range(1, 10, 2)), output2)
1342            self.assertEqual(len(wa), 1)
1343            self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")
1344
1345        # Reset Test: DataPipe resets, even when not all child DataPipes are exhausted
1346        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
1347        output1 = list(dp1)
1348        self.assertEqual(list(range(0, 10, 2)), output1)
1349        with warnings.catch_warnings(record=True) as wa:
1350            self.assertEqual(
1351                list(range(0, 10, 2)), list(dp1)
1352            )  # Reset even when dp2 is not read
1353            self.assertEqual(len(wa), 1)
1354            self.assertRegex(
1355                str(wa[0].message), r"Some child DataPipes are not exhausted"
1356            )
1357        output2 = []
1358        for i, n2 in enumerate(dp2):
1359            output2.append(n2)
1360            if i == 1:
1361                self.assertEqual(list(range(1, 5, 2)), output2)
1362                with warnings.catch_warnings(record=True) as wa:
1363                    self.assertEqual(
1364                        list(range(0, 10, 2)), list(dp1)
1365                    )  # Can reset even when dp2 is partially read
1366                    self.assertEqual(len(wa), 1)
1367                    self.assertRegex(
1368                        str(wa[0].message), r"Some child DataPipes are not exhausted"
1369                    )
1370                break
1371        output2 = list(dp2)  # output2 has to read from beginning again
1372        self.assertEqual(list(range(1, 10, 2)), output2)
1373
1374        # Functional Test: drop_none = True
1375        dp1, dp2 = input_dp.demux(
1376            num_instances=2,
1377            classifier_fn=lambda x: x % 2 if x % 5 != 0 else None,
1378            drop_none=True,
1379        )
1380        self.assertEqual([2, 4, 6, 8], list(dp1))
1381        self.assertEqual([1, 3, 7, 9], list(dp2))
1382
1383        # Functional Test: drop_none = False
1384        dp1, dp2 = input_dp.demux(
1385            num_instances=2,
1386            classifier_fn=lambda x: x % 2 if x % 5 != 0 else None,
1387            drop_none=False,
1388        )
1389        it1 = iter(dp1)
1390        with self.assertRaises(ValueError):
1391            next(it1)
1392
1393        # __len__ Test: __len__ not implemented
1394        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
1395        with self.assertRaises(TypeError):
1396            len(
1397                dp1
1398            )  # It is not implemented as we do not know the length of each child in advance
1399        with self.assertRaises(TypeError):
1400            len(dp2)
1401
1402        # Pickle Test:
1403        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=odd_or_even)
1404        traverse_dps(dp1)  # This should not raise any error
1405        for _ in zip(dp1, dp2):
1406            pass
1407        traverse_dps(dp2)  # This should not raise any error either
1408
1409    def test_map_iterdatapipe(self):
1410        target_length = 10
1411        input_dp = dp.iter.IterableWrapper(range(target_length))
1412
1413        def fn(item, dtype=torch.float, *, sum=False):
1414            data = torch.tensor(item, dtype=dtype)
1415            return data if not sum else data.sum()
1416
1417        # Functional Test: apply to each element correctly
1418        map_dp = input_dp.map(fn)
1419        self.assertEqual(target_length, len(map_dp))
1420        for x, y in zip(map_dp, range(target_length)):
1421            self.assertEqual(x, torch.tensor(y, dtype=torch.float))
1422
1423        # Functional Test: works with partial function
1424        map_dp = input_dp.map(partial(fn, dtype=torch.int, sum=True))
1425        for x, y in zip(map_dp, range(target_length)):
1426            self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())
1427
1428        # __len__ Test: inherits length from source DataPipe
1429        self.assertEqual(target_length, len(map_dp))
1430
1431        input_dp_nl = IDP_NoLen(range(target_length))
1432        map_dp_nl = input_dp_nl.map(lambda x: x)
1433        for x, y in zip(map_dp_nl, range(target_length)):
1434            self.assertEqual(x, torch.tensor(y, dtype=torch.float))
1435
1436        # __len__ Test: inherits length from source DataPipe - raises error when invalid
1437        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
1438            len(map_dp_nl)
1439
1440        # Reset Test: DataPipe resets properly
1441        n_elements_before_reset = 5
1442        res_before_reset, res_after_reset = reset_after_n_next_calls(
1443            map_dp, n_elements_before_reset
1444        )
1445        self.assertEqual(list(range(n_elements_before_reset)), res_before_reset)
1446        self.assertEqual(list(range(10)), res_after_reset)
1447
1448    @suppress_warnings  # Suppress warning for lambda fn
1449    def test_map_tuple_list_with_col_iterdatapipe(self):
1450        def fn_11(d):
1451            return -d
1452
1453        def fn_1n(d):
1454            return -d, d
1455
1456        def fn_n1(d0, d1):
1457            return d0 + d1
1458
1459        def fn_nn(d0, d1):
1460            return -d0, -d1, d0 + d1
1461
1462        def fn_n1_def(d0, d1=1):
1463            return d0 + d1
1464
1465        def fn_n1_kwargs(d0, d1, **kwargs):
1466            return d0 + d1
1467
1468        def fn_n1_pos(d0, d1, *args):
1469            return d0 + d1
1470
1471        def fn_n1_sep_pos(d0, *args, d1):
1472            return d0 + d1
1473
1474        def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
1475            return d0 + d1
1476
1477        p_fn_n1 = partial(fn_n1, d1=1)
1478        p_fn_cmplx = partial(fn_cmplx, d2=2)
1479        p_fn_cmplx_large_arg = partial(
1480            fn_cmplx, d2={i: list(range(i)) for i in range(10_000)}
1481        )
1482
1483        def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
1484            for constr in (list, tuple):
1485                datapipe = dp.iter.IterableWrapper(
1486                    [constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))]
1487                )
1488                if ref_fn is None:
1489                    with self.assertRaises(error):
1490                        res_dp = datapipe.map(fn, input_col, output_col)
1491                        list(res_dp)
1492                else:
1493                    res_dp = datapipe.map(fn, input_col, output_col)
1494                    ref_dp = datapipe.map(ref_fn)
1495                    self.assertEqual(list(res_dp), list(ref_dp))
1496                    # Reset
1497                    self.assertEqual(list(res_dp), list(ref_dp))
1498
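        # `input_col` picks which element(s) of each tuple/list are passed to `fn`; `output_col`
        # picks where the result is written. By default the result replaces the input column
        # (the left-most one when several are given), and -1 appends it at the end.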
1499        _helper(lambda data: data, fn_n1_def, 0, 1)
1500        _helper(
1501            lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2
1502        )
1503        _helper(lambda data: data, p_fn_n1, 0, 1)
1504        _helper(lambda data: data, p_fn_cmplx, 0, 1)
1505        _helper(lambda data: data, p_fn_cmplx_large_arg, 0, 1)
1506        _helper(
1507            lambda data: (data[0], data[1], data[0] + data[1]), p_fn_cmplx, [0, 1], 2
1508        )
1509        _helper(lambda data: (data[0] + data[1],), fn_n1_pos, [0, 1, 2])
1510
1511        # Replacing with one input column and default output column
1512        _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
1513        _helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1)
1514        # The index of input column is out of range
1515        _helper(None, fn_1n, 3, error=IndexError)
1516        # Unmatched input columns with fn arguments
1517        _helper(None, fn_n1, 1, error=ValueError)
1518        _helper(None, fn_n1, [0, 1, 2], error=ValueError)
1519        _helper(None, operator.add, 0, error=ValueError)
1520        _helper(None, operator.add, [0, 1, 2], error=ValueError)
1521        _helper(None, fn_cmplx, 0, 1, ValueError)
1522        _helper(None, fn_n1_pos, 1, error=ValueError)
1523        _helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError)
1524        _helper(None, p_fn_n1, [0, 1], error=ValueError)
1525        _helper(None, fn_1n, [1, 2], error=ValueError)
1526        # _helper(None, p_fn_cmplx, [0, 1, 2], error=ValueError)
1527        _helper(None, fn_n1_sep_pos, [0, 1, 2], error=ValueError)
1528        # Fn has keyword-only arguments
1529        _helper(None, fn_n1_kwargs, 1, error=ValueError)
1530        _helper(None, fn_cmplx, [0, 1], 2, ValueError)
1531
1532        # Replacing with multiple input columns and default output column (the left-most input column)
1533        _helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
1534        _helper(
1535            lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])),
1536            fn_nn,
1537            [2, 1],
1538        )
1539
1540        # output_col can only be specified when input_col is not None
1541        _helper(None, fn_n1, None, 1, error=ValueError)
1542        # output_col can only be single-element list or tuple
1543        _helper(None, fn_n1, None, [0, 1], error=ValueError)
1544        # Single-element list as output_col
1545        _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
1546        # Replacing with one input column and single specified output column
1547        _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
1548        _helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
1549        # The index of output column is out of range
1550        _helper(None, fn_1n, 1, 3, error=IndexError)
1551        _helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
1552        _helper(
1553            lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]),
1554            fn_nn,
1555            [1, 2],
1556            0,
1557        )
1558
1559        # Appending the output at the end
1560        _helper(lambda data: (*data, -data[1]), fn_11, 1, -1)
1561        _helper(lambda data: (*data, (-data[1], data[1])), fn_1n, 1, -1)
1562        _helper(lambda data: (*data, data[0] + data[2]), fn_n1, [0, 2], -1)
1563        _helper(
1564            lambda data: (*data, (-data[1], -data[2], data[1] + data[2])),
1565            fn_nn,
1566            [1, 2],
1567            -1,
1568        )
1569
1570        # Handling built-in functions (e.g. `dict`, `iter`, `int`, `str`) whose signatures cannot be inspected
1571        _helper(lambda data: (str(data[0]), data[1], data[2]), str, 0)
1572        _helper(lambda data: (data[0], data[1], int(data[2])), int, 2)
1573
1574        # Handle nn.Module and Callable (without __name__ implemented)
1575        _helper(lambda data: (data[0] + 1, data[1], data[2]), Add1Module(), 0)
1576        _helper(lambda data: (data[0] + 1, data[1], data[2]), Add1Callable(), 0)
1577
1578    @suppress_warnings  # Suppress warning for lambda fn
1579    @skipIfTorchDynamo()
1580    def test_map_dict_with_col_iterdatapipe(self):
1581        def fn_11(d):
1582            return -d
1583
1584        def fn_1n(d):
1585            return -d, d
1586
1587        def fn_n1(d0, d1):
1588            return d0 + d1
1589
1590        def fn_nn(d0, d1):
1591            return -d0, -d1, d0 + d1
1592
1593        def fn_n1_def(d0, d1=1):
1594            return d0 + d1
1595
1596        p_fn_n1 = partial(fn_n1, d1=1)
1597
1598        def fn_n1_pos(d0, d1, *args):
1599            return d0 + d1
1600
1601        def fn_n1_kwargs(d0, d1, **kwargs):
1602            return d0 + d1
1603
1604        def fn_kwonly(*, d0, d1):
1605            return d0 + d1
1606
1607        def fn_has_nondefault_kwonly(d0, *, d1):
1608            return d0 + d1
1609
1610        def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
1611            return d0 + d1
1612
1613        p_fn_cmplx = partial(fn_cmplx, d2=2)
1614        p_fn_cmplx_large_arg = partial(
1615            fn_cmplx, d2={i: list(range(i)) for i in range(10_000)}
1616        )
1617
1618        # Prevent in-place modification so that the DataPipe can be reset and re-compared
1619        def _dict_update(data, newdata, remove_idx=None):
1620            _data = dict(data)
1621            _data.update(newdata)
1622            if remove_idx:
1623                for idx in remove_idx:
1624                    del _data[idx]
1625            return _data
1626
1627        def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
1628            datapipe = dp.iter.IterableWrapper(
1629                [
1630                    {"x": 0, "y": 1, "z": 2},
1631                    {"x": 3, "y": 4, "z": 5},
1632                    {"x": 6, "y": 7, "z": 8},
1633                ]
1634            )
1635            if ref_fn is None:
1636                with self.assertRaises(error):
1637                    res_dp = datapipe.map(fn, input_col, output_col)
1638                    list(res_dp)
1639            else:
1640                res_dp = datapipe.map(fn, input_col, output_col)
1641                ref_dp = datapipe.map(ref_fn)
1642                self.assertEqual(list(res_dp), list(ref_dp))
1643                # Reset
1644                self.assertEqual(list(res_dp), list(ref_dp))
1645
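        # For dict elements, `input_col`/`output_col` are keys rather than indices; writing to
        # a key that does not exist yet adds a new entry (see the "Adding new key" cases below).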
1646        _helper(lambda data: data, fn_n1_def, "x", "y")
1647        _helper(lambda data: data, p_fn_n1, "x", "y")
1648        _helper(lambda data: data, p_fn_cmplx, "x", "y")
1649        _helper(lambda data: data, p_fn_cmplx_large_arg, "x", "y")
1650        _helper(
1651            lambda data: _dict_update(data, {"z": data["x"] + data["y"]}),
1652            p_fn_cmplx,
1653            ["x", "y", "z"],
1654            "z",
1655        )
1656
1657        _helper(
1658            lambda data: _dict_update(data, {"z": data["x"] + data["y"]}),
1659            fn_n1_def,
1660            ["x", "y"],
1661            "z",
1662        )
1663
1664        _helper(None, fn_n1_pos, "x", error=ValueError)
1665        _helper(None, fn_n1_kwargs, "x", error=ValueError)
1666        # Functions with non-default keyword-only arguments are not supported
1667        _helper(None, fn_kwonly, ["x", "y"], error=ValueError)
1668        _helper(None, fn_has_nondefault_kwonly, ["x", "y"], error=ValueError)
1669        _helper(None, fn_cmplx, ["x", "y"], error=ValueError)
1670
1671        # Replacing with one input column and default output column
1672        _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
1673        _helper(
1674            lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y"
1675        )
1676        # The key of input column is not in dict
1677        _helper(None, fn_1n, "a", error=KeyError)
1678        # Unmatched input columns with fn arguments
1679        _helper(None, fn_n1, "y", error=ValueError)
1680        _helper(None, fn_1n, ["x", "y"], error=ValueError)
1681        _helper(None, fn_n1_def, ["x", "y", "z"], error=ValueError)
1682        _helper(None, p_fn_n1, ["x", "y"], error=ValueError)
1683        _helper(None, fn_n1_kwargs, ["x", "y", "z"], error=ValueError)
1684        # Replacing with multiple input columns and default output column (the left-most input column)
1685        _helper(
1686            lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]),
1687            fn_n1,
1688            ["z", "x"],
1689        )
1690        _helper(
1691            lambda data: _dict_update(
1692                data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]
1693            ),
1694            fn_nn,
1695            ["z", "y"],
1696        )
1697
1698        # output_col can only be specified when input_col is not None
1699        _helper(None, fn_n1, None, "x", error=ValueError)
1700        # output_col can only be single-element list or tuple
1701        _helper(None, fn_n1, None, ["x", "y"], error=ValueError)
1702        # Single-element list as output_col
1703        _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
1704        # Replacing with one input column and single specified output column
1705        _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", "x")
1706        _helper(
1707            lambda data: _dict_update(data, {"z": (-data["y"], data["y"])}),
1708            fn_1n,
1709            "y",
1710            "z",
1711        )
1712        _helper(
1713            lambda data: _dict_update(data, {"y": data["x"] + data["z"]}),
1714            fn_n1,
1715            ["x", "z"],
1716            "y",
1717        )
1718        _helper(
1719            lambda data: _dict_update(
1720                data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])}
1721            ),
1722            fn_nn,
1723            ["y", "z"],
1724            "x",
1725        )
1726
1727        # Adding new key to dict for the output
1728        _helper(lambda data: _dict_update(data, {"a": -data["y"]}), fn_11, "y", "a")
1729        _helper(
1730            lambda data: _dict_update(data, {"a": (-data["y"], data["y"])}),
1731            fn_1n,
1732            "y",
1733            "a",
1734        )
1735        _helper(
1736            lambda data: _dict_update(data, {"a": data["x"] + data["z"]}),
1737            fn_n1,
1738            ["x", "z"],
1739            "a",
1740        )
1741        _helper(
1742            lambda data: _dict_update(
1743                data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])}
1744            ),
1745            fn_nn,
1746            ["y", "z"],
1747            "a",
1748        )
1749
1750    def test_collate_iterdatapipe(self):
1751        arrs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
1752        input_dp = dp.iter.IterableWrapper(arrs)
1753
1754        def _collate_fn(batch, default_type=torch.float):
1755            return torch.tensor(sum(batch), dtype=default_type)
1756
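        # With no `collate_fn`, Collator falls back to PyTorch's default collation, which is
        # expected here to turn each list of numbers into a tensor.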
1757        # Functional Test: uses the default collate function when a custom one is not specified
1758        collate_dp = input_dp.collate()
1759        for x, y in zip(arrs, collate_dp):
1760            self.assertEqual(torch.tensor(x), y)
1761
1762        # Functional Test: custom collate function
1763        collate_dp = input_dp.collate(collate_fn=_collate_fn)
1764        for x, y in zip(arrs, collate_dp):
1765            self.assertEqual(torch.tensor(sum(x), dtype=torch.float), y)
1766
1767        # Functional Test: custom, partial collate function
1768        collate_dp = input_dp.collate(partial(_collate_fn, default_type=torch.int))
1769        for x, y in zip(arrs, collate_dp):
1770            self.assertEqual(torch.tensor(sum(x), dtype=torch.int), y)
1771
1772        # Reset Test: reset the DataPipe and results are still correct
1773        n_elements_before_reset = 1
1774        res_before_reset, res_after_reset = reset_after_n_next_calls(
1775            collate_dp, n_elements_before_reset
1776        )
1777        self.assertEqual([torch.tensor(6, dtype=torch.int)], res_before_reset)
1778        for x, y in zip(arrs, res_after_reset):
1779            self.assertEqual(torch.tensor(sum(x), dtype=torch.int), y)
1780
1781        # __len__ Test: __len__ is inherited
1782        self.assertEqual(len(input_dp), len(collate_dp))
1783
1784        # __len__ Test: verify that it has no valid __len__ when the source doesn't have it
1785        input_dp_nl = IDP_NoLen(arrs)
1786        collate_dp_nl = input_dp_nl.collate()
1787        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
1788            len(collate_dp_nl)
1789        for x, y in zip(arrs, collate_dp_nl):
1790            self.assertEqual(torch.tensor(x), y)
1791
1792    def test_batch_iterdatapipe(self):
1793        arrs = list(range(10))
1794        input_dp = dp.iter.IterableWrapper(arrs)
1795
1796        # Functional Test: raise error when input argument `batch_size = 0`
1797        with self.assertRaises(AssertionError):
1798            input_dp.batch(batch_size=0)
1799
1800        # Functional Test: by default, do not drop the last batch
1801        bs = 3
1802        batch_dp = input_dp.batch(batch_size=bs)
1803        self.assertEqual(len(batch_dp), 4)
1804        for i, batch in enumerate(batch_dp):
1805            self.assertEqual(len(batch), 1 if i == 3 else bs)
1806            self.assertEqual(batch, arrs[i * bs : i * bs + len(batch)])
1807
1808        # Functional Test: Drop the last batch when specified
1809        bs = 4
1810        batch_dp = input_dp.batch(batch_size=bs, drop_last=True)
1811        for i, batch in enumerate(batch_dp):
1812            self.assertEqual(batch, arrs[i * bs : i * bs + len(batch)])
1813
1814        # __len__ Test: verify that the length of each batch is correct
1815        for i, batch in enumerate(batch_dp):
1816            self.assertEqual(len(batch), bs)
1817
1818        # __len__ Test: overall length is correct; length is unavailable if the source DataPipe has no valid length
1819        self.assertEqual(len(batch_dp), 2)
1820        input_dp_nl = IDP_NoLen(range(10))
1821        batch_dp_nl = input_dp_nl.batch(batch_size=2)
1822        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
1823            len(batch_dp_nl)
1824
1825        # Reset Test: Ensures that the DataPipe can properly reset
1826        n_elements_before_reset = 1
1827        res_before_reset, res_after_reset = reset_after_n_next_calls(
1828            batch_dp, n_elements_before_reset
1829        )
1830        self.assertEqual([[0, 1, 2, 3]], res_before_reset)
1831        self.assertEqual([[0, 1, 2, 3], [4, 5, 6, 7]], res_after_reset)
1832
1833    def test_unbatch_iterdatapipe(self):
1834        target_length = 6
1835        prebatch_dp = dp.iter.IterableWrapper(range(target_length))
1836
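        # `unbatch` flattens one level of nesting by default; `unbatch_level` controls how many
        # levels are flattened, and -1 flattens all of them.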
1837        # Functional Test: Unbatch DataPipe should be the same as pre-batch DataPipe
1838        input_dp = prebatch_dp.batch(3)
1839        unbatch_dp = input_dp.unbatch()
1840        self.assertEqual(len(list(unbatch_dp)), target_length)  # Total number of elements is as expected
1841        for i, res in zip(range(target_length), unbatch_dp):
1842            self.assertEqual(i, res)
1843
1844        # Functional Test: unbatch works for an input with nested levels
1845        input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
1846        unbatch_dp = input_dp.unbatch()
1847        self.assertEqual(len(list(unbatch_dp)), target_length)
1848        for i, res in zip(range(target_length), unbatch_dp):
1849            self.assertEqual(i, res)
1850
1851        input_dp = dp.iter.IterableWrapper([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
1852
1853        # Functional Test: unbatch flattens only the top level of a doubly-nested input
1854        unbatch_dp = input_dp.unbatch()
1855        expected_dp = [[0, 1], [2, 3], [4, 5], [6, 7]]
1856        self.assertEqual(len(list(unbatch_dp)), 4)
1857        for j, res in zip(expected_dp, unbatch_dp):
1858            self.assertEqual(j, res)
1859
1860        # Functional Test: unbatching multiple levels at the same time
1861        unbatch_dp = input_dp.unbatch(unbatch_level=2)
1862        expected_dp2 = [0, 1, 2, 3, 4, 5, 6, 7]
1863        self.assertEqual(len(list(unbatch_dp)), 8)
1864        for i, res in zip(expected_dp2, unbatch_dp):
1865            self.assertEqual(i, res)
1866
1867        # Functional Test: unbatching all levels at the same time
1868        unbatch_dp = input_dp.unbatch(unbatch_level=-1)
1869        self.assertEqual(len(list(unbatch_dp)), 8)
1870        for i, res in zip(expected_dp2, unbatch_dp):
1871            self.assertEqual(i, res)
1872
1873        # Functional Test: raises error when input unbatch_level is less than -1
1874        input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
1875        with self.assertRaises(ValueError):
1876            unbatch_dp = input_dp.unbatch(unbatch_level=-2)
1877            for i in unbatch_dp:
1878                print(i)
1879
1880        # Functional Test: raises error when input unbatch_level is too high
1881        with self.assertRaises(IndexError):
1882            unbatch_dp = input_dp.unbatch(unbatch_level=5)
1883            for i in unbatch_dp:
1884                print(i)
1885
1886        # Reset Test: unbatch_dp resets properly
1887        input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
1888        unbatch_dp = input_dp.unbatch(unbatch_level=-1)
1889        n_elements_before_reset = 3
1890        res_before_reset, res_after_reset = reset_after_n_next_calls(
1891            unbatch_dp, n_elements_before_reset
1892        )
1893        self.assertEqual([0, 1, 2], res_before_reset)
1894        self.assertEqual([0, 1, 2, 3, 4, 5], res_after_reset)
1895
1896    def test_filter_datapipe(self):
1897        input_ds = dp.iter.IterableWrapper(range(10))
1898
1899        def _filter_fn(data, val):
1900            return data >= val
1901
1902        # Functional Test: filter works with partial function
1903        filter_dp = input_ds.filter(partial(_filter_fn, val=5))
1904        self.assertEqual(list(filter_dp), list(range(5, 10)))
1905
1906        def _non_bool_fn(data):
1907            return 1
1908
1909        # Functional Test: filter function must return bool
1910        filter_dp = input_ds.filter(filter_fn=_non_bool_fn)
1911        with self.assertRaises(ValueError):
1912            temp = list(filter_dp)
1913
1914        # Functional Test: Specify input_col
1915        tuple_input_ds = dp.iter.IterableWrapper([(d - 1, d, d + 1) for d in range(10)])
1916
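        # As with `map`, `input_col` selects which tuple element(s) are passed to the filter
        # function, while the whole element is kept or dropped based on the returned bool.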
1917        # Single input_col
1918        input_col_1_dp = tuple_input_ds.filter(partial(_filter_fn, val=5), input_col=1)
1919        self.assertEqual(
1920            list(input_col_1_dp), [(d - 1, d, d + 1) for d in range(5, 10)]
1921        )
1922
1923        # Multiple input_col
1924        def _mul_filter_fn(a, b):
1925            return a + b < 10
1926
1927        input_col_2_dp = tuple_input_ds.filter(_mul_filter_fn, input_col=[0, 2])
1928        self.assertEqual(list(input_col_2_dp), [(d - 1, d, d + 1) for d in range(5)])
1929
1930        # Invalid input_col: fn expects two arguments but a single column is given
1931        with self.assertRaises(ValueError):
1932            tuple_input_ds.filter(_mul_filter_fn, input_col=0)
1933
1934        p_mul_filter_fn = partial(_mul_filter_fn, b=1)
1935        out = tuple_input_ds.filter(p_mul_filter_fn, input_col=0)
1936        self.assertEqual(list(out), [(d - 1, d, d + 1) for d in range(10)])
1937
1938        def _mul_filter_fn_with_defaults(a, b=1):
1939            return a + b < 10
1940
1941        out = tuple_input_ds.filter(_mul_filter_fn_with_defaults, input_col=0)
1942        self.assertEqual(list(out), [(d - 1, d, d + 1) for d in range(10)])
1943
1944        def _mul_filter_fn_with_kw_only(*, a, b):
1945            return a + b < 10
1946
1947        with self.assertRaises(ValueError):
1948            tuple_input_ds.filter(_mul_filter_fn_with_kw_only, input_col=0)
1949
1950        def _mul_filter_fn_with_kw_only_1_default(*, a, b=1):
1951            return a + b < 10
1952
1953        with self.assertRaises(ValueError):
1954            tuple_input_ds.filter(_mul_filter_fn_with_kw_only_1_default, input_col=0)
1955
1956        # __len__ Test: DataPipe has no valid len
1957        with self.assertRaisesRegex(TypeError, r"has no len"):
1958            len(filter_dp)
1959
1960        # Reset Test: DataPipe resets correctly
1961        filter_dp = input_ds.filter(partial(_filter_fn, val=5))
1962        n_elements_before_reset = 3
1963        res_before_reset, res_after_reset = reset_after_n_next_calls(
1964            filter_dp, n_elements_before_reset
1965        )
1966        self.assertEqual(list(range(5, 10))[:n_elements_before_reset], res_before_reset)
1967        self.assertEqual(list(range(5, 10)), res_after_reset)
1968
1969    def test_sampler_iterdatapipe(self):
1970        input_dp = dp.iter.IterableWrapper(range(10))
1971        # Default SequentialSampler
1972        sampled_dp = dp.iter.Sampler(input_dp)  # type: ignore[var-annotated]
1973        self.assertEqual(len(sampled_dp), 10)
1974        for i, x in enumerate(sampled_dp):
1975            self.assertEqual(x, i)
1976
1977        # RandomSampler
1978        random_sampled_dp = dp.iter.Sampler(
1979            input_dp, sampler=RandomSampler, sampler_kwargs={"replacement": True}
1980        )  # type: ignore[var-annotated] # noqa: B950
1981
1982        # Requires `__len__` to build SamplerDataPipe
1983        input_dp_nolen = IDP_NoLen(range(10))
1984        with self.assertRaises(AssertionError):
1985            sampled_dp = dp.iter.Sampler(input_dp_nolen)
1986
1987    def test_stream_reader_iterdatapipe(self):
1988        from io import StringIO
1989
1990        input_dp = dp.iter.IterableWrapper(
1991            [("f1", StringIO("abcde")), ("f2", StringIO("bcdef"))]
1992        )
1993        expected_res = ["abcde", "bcdef"]
1994
1995        # Functional Test: Read full chunk
1996        dp1 = input_dp.read_from_stream()
1997        self.assertEqual([d[1] for d in dp1], expected_res)
1998
1999        # Functional Test: Read one character at a time (chunk=1)
2000        dp2 = input_dp.read_from_stream(chunk=1)
2001        self.assertEqual([d[1] for d in dp2], [c for s in expected_res for c in s])
2002
2003        # `__len__` Test
2004        with self.assertRaises(TypeError):
2005            len(dp1)
2006
2007    def test_shuffler_iterdatapipe(self):
2008        input_dp = dp.iter.IterableWrapper(list(range(10)))
2009
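        # Shuffler takes its randomness from the global torch seed unless an explicit seed is
        # given via set_seed(); set_shuffle(False) makes it a pass-through.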
2010        with self.assertRaises(AssertionError):
2011            shuffle_dp = input_dp.shuffle(buffer_size=0)
2012
2013        # Functional Test: No seed
2014        shuffler_dp = input_dp.shuffle()
2015        self.assertEqual(set(range(10)), set(shuffler_dp))
2016
2017        # Functional Test: With global seed
2018        torch.manual_seed(123)
2019        shuffler_dp = input_dp.shuffle()
2020        res = list(shuffler_dp)
2021        torch.manual_seed(123)
2022        self.assertEqual(list(shuffler_dp), res)
2023
2024        # Functional Test: Set seed
2025        shuffler_dp = input_dp.shuffle().set_seed(123)
2026        res = list(shuffler_dp)
2027        shuffler_dp.set_seed(123)
2028        self.assertEqual(list(shuffler_dp), res)
2029
2030        # Functional Test: deactivate shuffling via set_shuffle
2031        unshuffled_dp = input_dp.shuffle().set_shuffle(False)
2032        self.assertEqual(list(unshuffled_dp), list(input_dp))
2033
2034        # Reset Test:
2035        shuffler_dp = input_dp.shuffle()
2036        n_elements_before_reset = 5
2037        res_before_reset, res_after_reset = reset_after_n_next_calls(
2038            shuffler_dp, n_elements_before_reset
2039        )
2040        self.assertEqual(5, len(res_before_reset))
2041        for x in res_before_reset:
2042            self.assertTrue(x in set(range(10)))
2043        self.assertEqual(set(range(10)), set(res_after_reset))
2044
2045        # __len__ Test: returns the length of the input DataPipe
2046        shuffler_dp = input_dp.shuffle()
2047        self.assertEqual(10, len(shuffler_dp))
2048        exp = list(range(100))
2049
2050        # Serialization Test
2051        from torch.utils.data.datapipes._hook_iterator import _SnapshotState
2052
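        # The helper pickles a partially-consumed shuffler and then fast-forwards the copy's
        # source DataPipe by `_number_of_samples_yielded`, so the restored copy is expected to
        # resume exactly where the original iterator left off.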
2053        def _serialization_helper(bs):
2054            shuffler_dp = input_dp.shuffle(buffer_size=bs)
2055            it = iter(shuffler_dp)
2056            for _ in range(2):
2057                next(it)
2058            shuffler_dp_copy = pickle.loads(pickle.dumps(shuffler_dp))
2059            _simple_graph_snapshot_restoration(
2060                shuffler_dp_copy.datapipe,
2061                shuffler_dp.datapipe._number_of_samples_yielded,
2062            )
2063
2064            exp = list(it)
2065            shuffler_dp_copy._snapshot_state = _SnapshotState.Restored
2066            self.assertEqual(exp, list(shuffler_dp_copy))
2067
2068        buffer_sizes = [2, 5, 15]
2069        for bs in buffer_sizes:
2070            _serialization_helper(bs)
2071
2072    def test_zip_iterdatapipe(self):
2073        # Functional Test: raises TypeError when an input is not of type `IterDataPipe`
2074        with self.assertRaises(TypeError):
2075            dp.iter.Zipper(dp.iter.IterableWrapper(range(10)), list(range(10)))  # type: ignore[arg-type]
2076
2077        # Functional Test: raises TypeError when an input does not have valid length
2078        zipped_dp = dp.iter.Zipper(
2079            dp.iter.IterableWrapper(range(10)), IDP_NoLen(range(5))
2080        )  # type: ignore[var-annotated]
2081        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
2082            len(zipped_dp)
2083
2084        # Functional Test: zips the results properly
2085        exp = [(i, i) for i in range(5)]
2086        self.assertEqual(list(zipped_dp), exp)
2087
2088        # Functional Test: zips the inputs properly even when lengths are different (zips to the shortest)
2089        zipped_dp = dp.iter.Zipper(
2090            dp.iter.IterableWrapper(range(10)), dp.iter.IterableWrapper(range(5))
2091        )
2092
2093        # __len__ Test: length matches the length of the shortest input
2094        self.assertEqual(len(zipped_dp), 5)
2095
2096        # Reset Test:
2097        n_elements_before_reset = 3
2098        res_before_reset, res_after_reset = reset_after_n_next_calls(
2099            zipped_dp, n_elements_before_reset
2100        )
2101        expected_res = [(i, i) for i in range(5)]
2102        self.assertEqual(expected_res[:n_elements_before_reset], res_before_reset)
2103        self.assertEqual(expected_res, res_after_reset)
2104
2105
2106class TestFunctionalMapDataPipe(TestCase):
2107    def _serialization_test_helper(self, datapipe, use_dill):
2108        if use_dill:
2109            serialized_dp = dill.dumps(datapipe)
2110            deserialized_dp = dill.loads(serialized_dp)
2111        else:
2112            serialized_dp = pickle.dumps(datapipe)
2113            deserialized_dp = pickle.loads(serialized_dp)
2114        try:
2115            self.assertEqual(list(datapipe), list(deserialized_dp))
2116        except AssertionError as e:
2117            print(f"{datapipe} is failing.")
2118            raise e
2119
2120    def _serialization_test_for_single_dp(self, dp, use_dill=False):
2121        # 1. Testing for serialization before any iteration starts
2122        self._serialization_test_helper(dp, use_dill)
2123        # 2. Testing for serialization after DataPipe is partially read
2124        it = iter(dp)
2125        _ = next(it)
2126        self._serialization_test_helper(dp, use_dill)
2127        # 3. Testing for serialization after DataPipe is fully read
2128        _ = list(dp)
2129        self._serialization_test_helper(dp, use_dill)
2130
2131    def test_serializable(self):
2132        picklable_datapipes: List = [
2133            (dp.map.Batcher, None, (2,), {}),
2134            (dp.map.Concater, None, (dp.map.SequenceWrapper(range(10)),), {}),
2135            (dp.map.Mapper, None, (), {}),
2136            (dp.map.Mapper, None, (_fake_fn,), {}),
2137            (dp.map.Mapper, None, (partial(_fake_add, 1),), {}),
2138            (dp.map.SequenceWrapper, range(10), (), {}),
2139            (dp.map.Shuffler, dp.map.SequenceWrapper([0] * 5), (), {}),
2140            (dp.map.Zipper, None, (dp.map.SequenceWrapper(range(10)),), {}),
2141        ]
2142        for dpipe, custom_input, dp_args, dp_kwargs in picklable_datapipes:
2143            if custom_input is None:
2144                custom_input = dp.map.SequenceWrapper(range(10))
2145            datapipe = dpipe(custom_input, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
2146            self._serialization_test_for_single_dp(datapipe)
2147
2148    def test_serializable_with_dill(self):
2149        """Only for DataPipes that take in a function as an argument"""
2150        input_dp = dp.map.SequenceWrapper(range(10))
2151
2152        datapipes_with_lambda_fn: List[
2153            Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
2154        ] = [
2155            (dp.map.Mapper, (lambda_fn1,), {}),
2156        ]
2157
2158        def _local_fns():
2159            def _fn1(x):
2160                return x
2161
2162            return _fn1
2163
2164        fn1 = _local_fns()
2165
2166        datapipes_with_local_fn: List[
2167            Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
2168        ] = [
2169            (dp.map.Mapper, (fn1,), {}),
2170        ]
2171
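        # With dill installed, lambdas and locally-defined functions can be serialized directly;
        # without it, constructing such a DataPipe emits a warning and pickling it fails,
        # as asserted in both branches below.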
2172        if HAS_DILL:
2173            for dpipe, dp_args, dp_kwargs in (
2174                datapipes_with_lambda_fn + datapipes_with_local_fn
2175            ):
2176                _ = dill.dumps(dpipe(input_dp, *dp_args, **dp_kwargs))  # type: ignore[call-arg]
2177        else:
2178            msgs = (
2179                r"^Lambda function is not supported by pickle",
2180                r"^Local function is not supported by pickle",
2181            )
2182            for dps, msg in zip(
2183                (datapipes_with_lambda_fn, datapipes_with_local_fn), msgs
2184            ):
2185                for dpipe, dp_args, dp_kwargs in dps:
2186                    with self.assertWarnsRegex(UserWarning, msg):
2187                        datapipe = dpipe(input_dp, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
2188                    with self.assertRaises((pickle.PicklingError, AttributeError)):
2189                        pickle.dumps(datapipe)
2190
2191    def test_docstring(self):
2192        """
2193        Ensure functional form of MapDataPipe has the correct docstring from
2194        the class form.
2195
2196        Regression test for https://github.com/pytorch/data/issues/792.
2197        """
2198        input_dp = dp.map.SequenceWrapper(range(10))
2199
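        # Each functional form (e.g. `input_dp.map`) should expose the class docstring,
        # including the "(functional name: ``...``)" tag, an Args section, and an example.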
2200        for dp_funcname in [
2201            "batch",
2202            "concat",
2203            "map",
2204            "shuffle",
2205            "zip",
2206        ]:
2207            if sys.version_info >= (3, 9):
2208                docstring = pydoc.render_doc(
2209                    thing=getattr(input_dp, dp_funcname), forceload=True
2210                )
2211            elif sys.version_info < (3, 9):
2212                # pydoc works differently on Python 3.8, see
2213                # https://docs.python.org/3/whatsnew/3.9.html#pydoc
2214                docstring = getattr(input_dp, dp_funcname).__doc__
2215            assert f"(functional name: ``{dp_funcname}``)" in docstring
2216            assert "Args:" in docstring
2217            assert "Example:" in docstring or "Examples:" in docstring
2218
2219    def test_sequence_wrapper_datapipe(self):
2220        seq = list(range(10))
2221        input_dp = dp.map.SequenceWrapper(seq)
2222
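        # SequenceWrapper deep-copies the input sequence by default, so mutating `seq` later
        # does not affect the DataPipe; `deepcopy=False` shares the underlying data instead.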
2223        # Functional Test: all elements are equal in the same order
2224        self.assertEqual(seq, list(input_dp))
2225
2226        # Functional Test: confirm deepcopy works by default
2227        seq.append(11)
2228        self.assertEqual(list(range(10)), list(input_dp))  # input_dp shouldn't have 11
2229
2230        # Functional Test: non-deepcopy version reflects changes made to the source sequence
2231        seq2 = [1, 2, 3]
2232        input_dp_non_deep = dp.map.SequenceWrapper(seq2, deepcopy=False)
2233        seq2.append(4)
2234        self.assertEqual(list(seq2), list(input_dp_non_deep))  # should have 4
2235
2236        # Reset Test: reset the DataPipe
2237        seq = list(range(10))
2238        n_elements_before_reset = 5
2239        res_before_reset, res_after_reset = reset_after_n_next_calls(
2240            input_dp, n_elements_before_reset
2241        )
2242        self.assertEqual(list(range(5)), res_before_reset)
2243        self.assertEqual(seq, res_after_reset)
2244
2245        # __len__ Test: inherits length from sequence
2246        self.assertEqual(len(seq), len(input_dp))
2247
2248    def test_concat_mapdatapipe(self):
2249        input_dp1 = dp.map.SequenceWrapper(range(10))
2250        input_dp2 = dp.map.SequenceWrapper(range(5))
2251
2252        with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
2253            dp.map.Concater()
2254
2255        with self.assertRaisesRegex(
2256            TypeError, r"Expected all inputs to be `MapDataPipe`"
2257        ):
2258            dp.map.Concater(input_dp1, ())  # type: ignore[arg-type]
2259
2260        concat_dp = input_dp1.concat(input_dp2)
2261        self.assertEqual(len(concat_dp), 15)
2262        for index in range(15):
2263            self.assertEqual(
2264                concat_dp[index], (list(range(10)) + list(range(5)))[index]
2265            )
2266        self.assertEqual(list(concat_dp), list(range(10)) + list(range(5)))
2267
2268    def test_zip_mapdatapipe(self):
2269        input_dp1 = dp.map.SequenceWrapper(range(10))
2270        input_dp2 = dp.map.SequenceWrapper(range(5))
2271        input_dp3 = dp.map.SequenceWrapper(range(15))
2272
2273        # Functional Test: requires at least one input DataPipe
2274        with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
2275            dp.map.Zipper()
2276
2277        # Functional Test: all inputs must be MapDataPipes
2278        with self.assertRaisesRegex(
2279            TypeError, r"Expected all inputs to be `MapDataPipe`"
2280        ):
2281            dp.map.Zipper(input_dp1, ())  # type: ignore[arg-type]
2282
2283        # Functional Test: Zip the elements up as tuples
2284        zip_dp = input_dp1.zip(input_dp2, input_dp3)
2285        self.assertEqual([(i, i, i) for i in range(5)], [zip_dp[i] for i in range(5)])
2286
2287        # Functional Test: Raise IndexError when the index equals or exceeds the length of the shortest DataPipe
2288        with self.assertRaisesRegex(IndexError, r"out of range"):
2289            input_dp1.zip(input_dp2, input_dp3)[5]
2290
2291        # Functional Test: Ensure `zip` can combine `Batcher` with others
2292        dp1 = dp.map.SequenceWrapper(range(10))
2293        batch_dp1 = dp1.batch(2)
2294        dp2 = dp.map.SequenceWrapper(range(10))
2295        batch_dp2 = dp2.batch(3)
2296        zip_dp1 = batch_dp1.zip(batch_dp2)
2297        self.assertEqual(4, len(list(zip_dp1)))
2298        zip_dp2 = batch_dp1.zip(dp2)
2299        self.assertEqual(5, len(list(zip_dp2)))
2300
2301        # __len__ Test: returns the length of the shortest DataPipe
2302        zip_dp = input_dp1.zip(input_dp2, input_dp3)
2303        self.assertEqual(5, len(zip_dp))
2304
2305    def test_shuffler_mapdatapipe(self):
2306        input_dp1 = dp.map.SequenceWrapper(range(10))
2307        input_dp2 = dp.map.SequenceWrapper({"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
2308
2309        # Functional Test: Assumes 0-based integer indices when `indices` is not given
2310        shuffler_dp = input_dp1.shuffle()
2311        self.assertEqual(set(range(10)), set(shuffler_dp))
2312
2313        # Functional Test: Custom indices are working
2314        shuffler_dp = input_dp2.shuffle(indices=["a", "b", "c", "d", "e"])
2315        self.assertEqual(set(range(1, 6)), set(shuffler_dp))
2316
2317        # Functional Test: With global seed
2318        torch.manual_seed(123)
2319        shuffler_dp = input_dp1.shuffle()
2320        res = list(shuffler_dp)
2321        torch.manual_seed(123)
2322        self.assertEqual(list(shuffler_dp), res)
2323
2324        # Functional Test: Set seed
2325        shuffler_dp = input_dp1.shuffle().set_seed(123)
2326        res = list(shuffler_dp)
2327        shuffler_dp.set_seed(123)
2328        self.assertEqual(list(shuffler_dp), res)
2329
2330        # Functional Test: deactivate shuffling via set_shuffle
2331        unshuffled_dp = input_dp1.shuffle().set_shuffle(False)
2332        self.assertEqual(list(unshuffled_dp), list(input_dp1))
2333
2334        # Reset Test:
2335        shuffler_dp = input_dp1.shuffle()
2336        n_elements_before_reset = 5
2337        res_before_reset, res_after_reset = reset_after_n_next_calls(
2338            shuffler_dp, n_elements_before_reset
2339        )
2340        self.assertEqual(5, len(res_before_reset))
2341        for x in res_before_reset:
2342            self.assertTrue(x in set(range(10)))
2343        self.assertEqual(set(range(10)), set(res_after_reset))
2344
2345        # __len__ Test: returns the length of the input DataPipe
2346        shuffler_dp = input_dp1.shuffle()
2347        self.assertEqual(10, len(shuffler_dp))
2348
2349        # Serialization Test
2350        from torch.utils.data.datapipes._hook_iterator import _SnapshotState
2351
2352        shuffler_dp = input_dp1.shuffle()
2353        it = iter(shuffler_dp)
2354        for _ in range(2):
2355            next(it)
2356        shuffler_dp_copy = pickle.loads(pickle.dumps(shuffler_dp))
2357
2358        exp = list(it)
2359        shuffler_dp_copy._snapshot_state = _SnapshotState.Restored
2360        self.assertEqual(exp, list(shuffler_dp_copy))
2361
2362    def test_map_mapdatapipe(self):
2363        arr = range(10)
2364        input_dp = dp.map.SequenceWrapper(arr)
2365
2366        def fn(item, dtype=torch.float, *, sum=False):
2367            data = torch.tensor(item, dtype=dtype)
2368            return data if not sum else data.sum()
2369
2370        map_dp = input_dp.map(fn)
2371        self.assertEqual(len(input_dp), len(map_dp))
2372        for index in arr:
2373            self.assertEqual(
2374                map_dp[index], torch.tensor(input_dp[index], dtype=torch.float)
2375            )
2376
2377        map_dp = input_dp.map(partial(fn, dtype=torch.int, sum=True))
2378        self.assertEqual(len(input_dp), len(map_dp))
2379        for index in arr:
2380            self.assertEqual(
2381                map_dp[index], torch.tensor(input_dp[index], dtype=torch.int).sum()
2382            )
2383
2384    def test_batch_mapdatapipe(self):
2385        arr = list(range(13))
2386        input_dp = dp.map.SequenceWrapper(arr)
2387
2388        # Functional Test: batches top level by default
2389        batch_dp = dp.map.Batcher(input_dp, batch_size=2)
2390        self.assertEqual(
2391            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12]], list(batch_dp)
2392        )
2393
2394        # Functional Test: drop_last on command
2395        batch_dp = dp.map.Batcher(input_dp, batch_size=2, drop_last=True)
2396        self.assertEqual(
2397            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]], list(batch_dp)
2398        )
2399
2400        # Functional Test: nested batching
2401        batch_dp_2 = batch_dp.batch(batch_size=3)
2402        self.assertEqual(
2403            [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]], list(batch_dp_2)
2404        )
2405
2406        # Reset Test:
2407        n_elements_before_reset = 3
2408        res_before_reset, res_after_reset = reset_after_n_next_calls(
2409            batch_dp, n_elements_before_reset
2410        )
2411        self.assertEqual([[0, 1], [2, 3], [4, 5]], res_before_reset)
2412        self.assertEqual(
2413            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]], res_after_reset
2414        )
2415
2416        # __len__ Test:
2417        self.assertEqual(6, len(batch_dp))
2418        self.assertEqual(2, len(batch_dp_2))
2419
2420
2421# Metaclass conflict for Python 3.6
2422# Multiple inheritance with NamedTuple is not supported for Python 3.9 and later
2423_generic_namedtuple_allowed = sys.version_info >= (3, 7) and sys.version_info < (3, 9)
2424if _generic_namedtuple_allowed:
2425
2426    class InvalidData(NamedTuple, Generic[T_co]):
2427        name: str
2428        data: T_co
2429
2430
2431class TestTyping(TestCase):
2432    def test_isinstance(self):
2433        class A(IterDataPipe):
2434            pass
2435
2436        class B(IterDataPipe):
2437            pass
2438
2439        a = A()
2440        self.assertTrue(isinstance(a, A))
2441        self.assertFalse(isinstance(a, B))
2442
2443    def test_protocol(self):
2444        try:
2445            from typing import Protocol  # type: ignore[attr-defined]
2446        except ImportError:
2447            from typing import _Protocol  # type: ignore[attr-defined]
2448
2449            Protocol = _Protocol
2450
2451        class P(Protocol):
2452            pass
2453
2454        class A(IterDataPipe[P]):
2455            pass
2456
2457    @skipTyping
2458    def test_subtype(self):
2459        from torch.utils.data.datapipes._typing import issubtype
2460
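        # issubtype(sub, par) is the structural subtype check used by the DataPipe typing
        # machinery; it handles Any, TypeVars, Union and subscripted generics as exercised below.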
2461        basic_type = (int, str, bool, float, complex, list, tuple, dict, set, T_co)
2462        for t in basic_type:
2463            self.assertTrue(issubtype(t, t))
2464            self.assertTrue(issubtype(t, Any))
2465            if t == T_co:
2466                self.assertTrue(issubtype(Any, t))
2467            else:
2468                self.assertFalse(issubtype(Any, t))
2469        for t1, t2 in itertools.product(basic_type, basic_type):
2470            if t1 == t2 or t2 == T_co:
2471                self.assertTrue(issubtype(t1, t2))
2472            else:
2473                self.assertFalse(issubtype(t1, t2))
2474
2475        T = TypeVar("T", int, str)
2476        S = TypeVar("S", bool, Union[str, int], Tuple[int, T])  # type: ignore[valid-type]
2477        types = (
2478            (int, Optional[int]),
2479            (List, Union[int, list]),
2480            (Tuple[int, str], S),
2481            (Tuple[int, str], tuple),
2482            (T, S),
2483            (S, T_co),
2484            (T, Union[S, Set]),
2485        )
2486        for sub, par in types:
2487            self.assertTrue(issubtype(sub, par))
2488            self.assertFalse(issubtype(par, sub))
2489
2490        subscriptable_types = {
2491            List: 1,
2492            Tuple: 2,  # use 2 parameters
2493            Set: 1,
2494            Dict: 2,
2495        }
2496        for subscript_type, n in subscriptable_types.items():
2497            for ts in itertools.combinations(types, n):
2498                subs, pars = zip(*ts)
2499                sub = subscript_type[subs]  # type: ignore[index]
2500                par = subscript_type[pars]  # type: ignore[index]
2501                self.assertTrue(issubtype(sub, par))
2502                self.assertFalse(issubtype(par, sub))
2503                # Non-recursive check
2504                self.assertTrue(issubtype(par, sub, recursive=False))
2505
2506    @skipTyping
2507    def test_issubinstance(self):
2508        from torch.utils.data.datapipes._typing import issubinstance
2509
2510        basic_data = (1, "1", True, 1.0, complex(1.0, 0.0))
2511        basic_type = (int, str, bool, float, complex)
2512        S = TypeVar("S", bool, Union[str, int])
2513        for d in basic_data:
2514            self.assertTrue(issubinstance(d, Any))
2515            self.assertTrue(issubinstance(d, T_co))
2516            if type(d) in (bool, int, str):
2517                self.assertTrue(issubinstance(d, S))
2518            else:
2519                self.assertFalse(issubinstance(d, S))
2520            for t in basic_type:
2521                if type(d) == t:
2522                    self.assertTrue(issubinstance(d, t))
2523                else:
2524                    self.assertFalse(issubinstance(d, t))
2525        # list/set
2526        dt = (([1, "1", 2], List), (set({1, "1", 2}), Set))
2527        for d, t in dt:
2528            self.assertTrue(issubinstance(d, t))
2529            self.assertTrue(issubinstance(d, t[T_co]))  # type: ignore[index]
2530            self.assertFalse(issubinstance(d, t[int]))  # type: ignore[index]
2531
2532        # dict
2533        d = {"1": 1, "2": 2.0}
2534        self.assertTrue(issubinstance(d, Dict))
2535        self.assertTrue(issubinstance(d, Dict[str, T_co]))
2536        self.assertFalse(issubinstance(d, Dict[str, int]))
2537
2538        # tuple
2539        d = (1, "1", 2)
2540        self.assertTrue(issubinstance(d, Tuple))
2541        self.assertTrue(issubinstance(d, Tuple[int, str, T_co]))
2542        self.assertFalse(issubinstance(d, Tuple[int, Any]))
2543        self.assertFalse(issubinstance(d, Tuple[int, int, int]))
2544
2545    # Static checking of type annotations at class-definition time
2546    @skipTyping
2547    def test_compile_time(self):
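        # Defining a DataPipe whose `__iter__` return annotation is incompatible with the
        # declared element type raises TypeError at class-definition ("compile") time,
        # before any instance is iterated.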
2548        with self.assertRaisesRegex(TypeError, r"Expected 'Iterator' as the return"):
2549
2550            class InvalidDP1(IterDataPipe[int]):
2551                def __iter__(self) -> str:  # type: ignore[misc, override]
2552                    yield 0
2553
2554        with self.assertRaisesRegex(TypeError, r"Expected return type of '__iter__'"):
2555
2556            class InvalidDP2(IterDataPipe[Tuple]):
2557                def __iter__(self) -> Iterator[int]:  # type: ignore[override]
2558                    yield 0
2559
2560        with self.assertRaisesRegex(TypeError, r"Expected return type of '__iter__'"):
2561
2562            class InvalidDP3(IterDataPipe[Tuple[int, str]]):
2563                def __iter__(self) -> Iterator[tuple]:  # type: ignore[override]
2564                    yield (0,)
2565
2566        if _generic_namedtuple_allowed:
2567            with self.assertRaisesRegex(
2568                TypeError, r"is not supported by Python typing"
2569            ):
2570
2571                class InvalidDP4(IterDataPipe["InvalidData[int]"]):  # type: ignore[type-arg, misc]
2572                    pass
2573
2574        class DP1(IterDataPipe[Tuple[int, str]]):
2575            def __init__(self, length):
2576                self.length = length
2577
2578            def __iter__(self) -> Iterator[Tuple[int, str]]:
2579                for d in range(self.length):
2580                    yield d, str(d)
2581
2582        self.assertTrue(issubclass(DP1, IterDataPipe))
2583        dp1 = DP1(10)
2584        self.assertTrue(DP1.type.issubtype(dp1.type) and dp1.type.issubtype(DP1.type))  # type: ignore[attr-defined]
2585        dp1_ = DP1(5)
2586        self.assertEqual(dp1.type, dp1_.type)
2587
2588        with self.assertRaisesRegex(TypeError, r"is not a generic class"):
2589
2590            class InvalidDP5(DP1[tuple]):  # type: ignore[type-arg]
2591                def __iter__(self) -> Iterator[tuple]:  # type: ignore[override]
2592                    yield (0,)
2593
2594        class DP2(IterDataPipe[T_co]):
2595            def __iter__(self) -> Iterator[T_co]:
2596                yield from range(10)  # type: ignore[misc]
2597
2598        self.assertTrue(issubclass(DP2, IterDataPipe))
2599        dp2 = DP2()  # type: ignore[var-annotated]
2600        self.assertTrue(DP2.type.issubtype(dp2.type) and dp2.type.issubtype(DP2.type))  # type: ignore[attr-defined]
2601        dp2_ = DP2()  # type: ignore[var-annotated]
2602        self.assertEqual(dp2.type, dp2_.type)
2603
2604        class DP3(IterDataPipe[Tuple[T_co, str]]):
2605            r"""DataPipe without a fixed type, but with an __init__ function"""
2606
2607            def __init__(self, datasource):
2608                self.datasource = datasource
2609
2610            def __iter__(self) -> Iterator[Tuple[T_co, str]]:
2611                for d in self.datasource:
2612                    yield d, str(d)
2613
2614        self.assertTrue(issubclass(DP3, IterDataPipe))
2615        dp3 = DP3(range(10))  # type: ignore[var-annotated]
2616        self.assertTrue(DP3.type.issubtype(dp3.type) and dp3.type.issubtype(DP3.type))  # type: ignore[attr-defined]
2617        dp3_ = DP3(5)  # type: ignore[var-annotated]
2618        self.assertEqual(dp3.type, dp3_.type)
2619
2620        class DP4(IterDataPipe[tuple]):
2621            r"""DataPipe without __iter__ annotation"""
2622
2623            def __iter__(self):
2624                raise NotImplementedError
2625
2626        self.assertTrue(issubclass(DP4, IterDataPipe))
2627        dp4 = DP4()
2628        self.assertTrue(dp4.type.param == tuple)
2629
2630        class DP5(IterDataPipe):
2631            r"""DataPipe without type annotation"""
2632
2633            def __iter__(self) -> Iterator[str]:
2634                raise NotImplementedError
2635
2636        self.assertTrue(issubclass(DP5, IterDataPipe))
2637        dp5 = DP5()
2638        from torch.utils.data.datapipes._typing import issubtype
2639
2640        self.assertTrue(
2641            issubtype(dp5.type.param, Any) and issubtype(Any, dp5.type.param)
2642        )
2643
2644        class DP6(IterDataPipe[int]):
2645            r"""DataPipe with plain Iterator"""
2646
2647            def __iter__(self) -> Iterator:
2648                raise NotImplementedError
2649
2650        self.assertTrue(issubclass(DP6, IterDataPipe))
2651        dp6 = DP6()
2652        self.assertTrue(dp6.type.param == int)
2653
2654        class DP7(IterDataPipe[Awaitable[T_co]]):
2655            r"""DataPipe with abstract base class"""
2656
2657        self.assertTrue(issubclass(DP7, IterDataPipe))
2658        self.assertTrue(DP7.type.param == Awaitable[T_co])  # type: ignore[attr-defined]
2659
2660        class DP8(DP7[str]):
2661            r"""DataPipe subclass from a DataPipe with abc type"""
2662
2663        self.assertTrue(issubclass(DP8, IterDataPipe))
2664        self.assertTrue(DP8.type.param == Awaitable[str])  # type: ignore[attr-defined]
2665
2666    @skipTyping
2667    def test_construct_time(self):
2668        class DP0(IterDataPipe[Tuple]):
2669            @argument_validation
2670            def __init__(self, dp: IterDataPipe):
2671                self.dp = dp
2672
2673            def __iter__(self) -> Iterator[Tuple]:
2674                for d in self.dp:
2675                    yield d, str(d)
2676
2677        class DP1(IterDataPipe[int]):
2678            @argument_validation
2679            def __init__(self, dp: IterDataPipe[Tuple[int, str]]):
2680                self.dp = dp
2681
2682            def __iter__(self) -> Iterator[int]:
2683                for a, b in self.dp:
2684                    yield a
2685
2686        # Non-DataPipe input with DataPipe hint
2687        datasource = [(1, "1"), (2, "2"), (3, "3")]
2688        with self.assertRaisesRegex(
2689            TypeError, r"Expected argument 'dp' as a IterDataPipe"
2690        ):
2691            dp0 = DP0(datasource)
2692
2693        dp0 = DP0(dp.iter.IterableWrapper(range(10)))
2694        with self.assertRaisesRegex(
2695            TypeError, r"Expected type of argument 'dp' as a subtype"
2696        ):
2697            dp1 = DP1(dp0)
2698
2699    @skipTyping
2700    def test_runtime(self):
2701        class DP(IterDataPipe[Tuple[int, T_co]]):
2702            def __init__(self, datasource):
2703                self.ds = datasource
2704
2705            @runtime_validation
2706            def __iter__(self) -> Iterator[Tuple[int, T_co]]:
2707                yield from self.ds
2708
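        # With @runtime_validation, every yielded element is checked against the declared type
        # hint during iteration; runtime_validation_disabled() temporarily turns the check off.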
2709        dss = ([(1, "1"), (2, "2")], [(1, 1), (2, "2")])
2710        for ds in dss:
2711            dp0 = DP(ds)  # type: ignore[var-annotated]
2712            self.assertEqual(list(dp0), ds)
2713            # Reset __iter__
2714            self.assertEqual(list(dp0), ds)
2715
2716        dss = (
2717            [(1, 1), ("2", 2)],  # type: ignore[assignment, list-item]
2718            [[1, "1"], [2, "2"]],  # type: ignore[list-item]
2719            [1, "1", 2, "2"],
2720        )
2721        for ds in dss:
2722            dp0 = DP(ds)
2723            with self.assertRaisesRegex(
2724                RuntimeError, r"Expected an instance as subtype"
2725            ):
2726                list(dp0)
2727
2728            with runtime_validation_disabled():
2729                self.assertEqual(list(dp0), ds)
2730                with runtime_validation_disabled():
2731                    self.assertEqual(list(dp0), ds)
2732
2733            with self.assertRaisesRegex(
2734                RuntimeError, r"Expected an instance as subtype"
2735            ):
2736                list(dp0)
2737
2738    @skipTyping
2739    def test_reinforce(self):
2740        T = TypeVar("T", int, str)
2741
2742        class DP(IterDataPipe[T]):
2743            def __init__(self, ds):
2744                self.ds = ds
2745
2746            @runtime_validation
2747            def __iter__(self) -> Iterator[T]:
2748                yield from self.ds
2749
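        # reinforce_type() narrows the declared element type at runtime; the new type must be
        # an actual type and a subtype of the existing annotation, and yielded data is then
        # validated against it unless runtime validation is disabled.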
2750        ds = list(range(10))
2751        # Valid type reinforcement
2752        dp0 = DP(ds).reinforce_type(int)
2753        self.assertTrue(dp0.type, int)
2754        self.assertEqual(list(dp0), ds)
2755
2756        # Invalid type
2757        with self.assertRaisesRegex(TypeError, r"'expected_type' must be a type"):
2758            dp1 = DP(ds).reinforce_type(1)
2759
2760        # Type is not subtype
2761        with self.assertRaisesRegex(
2762            TypeError, r"Expected 'expected_type' as subtype of"
2763        ):
2764            dp2 = DP(ds).reinforce_type(float)
2765
2766        # Invalid data at runtime
2767        dp3 = DP(ds).reinforce_type(str)
2768        with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"):
2769            list(dp3)
2770
2771        # Context Manager to disable the runtime validation
2772        with runtime_validation_disabled():
2773            self.assertEqual(list(dp3), ds)
2774
2775
2776class NumbersDataset(IterDataPipe):
2777    def __init__(self, size=10):
2778        self.size = size
2779
2780    def __iter__(self):
2781        yield from range(self.size)
2782
2783    def __len__(self):
2784        return self.size
2785
2786
2787class TestGraph(TestCase):
2788    class CustomIterDataPipe(IterDataPipe):
2789        def add_v(self, x):
2790            return x + self.v
2791
2792        def __init__(self, source_dp, v=1):
2793            self._dp = source_dp.map(self.add_v)
2794            self.v = v
2795
2796        def __iter__(self):
2797            yield from self._dp
2798
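        # Deliberately unhashable: `test_traverse_unhashable_datapipe` relies on this
        # to verify that graph traversal does not need to hash DataPipes.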
2799        def __hash__(self):
2800            raise NotImplementedError
2801
2802    def test_simple_traverse(self):
2803        numbers_dp = NumbersDataset(size=50)
2804        shuffled_dp = numbers_dp.shuffle()
2805        sharded_dp = shuffled_dp.sharding_filter()
2806        mapped_dp = sharded_dp.map(lambda x: x * 10)
2807        graph = traverse_dps(mapped_dp)
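        # `traverse_dps` returns a nested dict of the form
        # {id(datapipe): (datapipe, {<graph of the DataPipes it reads from>})},
        # with source DataPipes mapping to an empty dict, as spelled out below.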
2808        expected: Dict[Any, Any] = {
2809            id(mapped_dp): (
2810                mapped_dp,
2811                {
2812                    id(sharded_dp): (
2813                        sharded_dp,
2814                        {
2815                            id(shuffled_dp): (
2816                                shuffled_dp,
2817                                {id(numbers_dp): (numbers_dp, {})},
2818                            )
2819                        },
2820                    )
2821                },
2822            )
2823        }
2824        self.assertEqual(expected, graph)
2825
2826        dps = torch.utils.data.graph_settings.get_all_graph_pipes(graph)
2827        self.assertEqual(len(dps), 4)
2828        for datapipe in (numbers_dp, shuffled_dp, sharded_dp, mapped_dp):
2829            self.assertTrue(datapipe in dps)
2830
2831    def test_traverse_forked(self):
2832        numbers_dp = NumbersDataset(size=50)
2833        dp0, dp1, dp2 = numbers_dp.fork(num_instances=3)
2834        dp0_upd = dp0.map(lambda x: x * 10)
2835        dp1_upd = dp1.filter(lambda x: x % 3 == 1)
2836        combined_dp = dp0_upd.mux(dp1_upd, dp2)
2837        graph = traverse_dps(combined_dp)
2838        expected = {
2839            id(combined_dp): (
2840                combined_dp,
2841                {
2842                    id(dp0_upd): (
2843                        dp0_upd,
2844                        {
2845                            id(dp0): (
2846                                dp0,
2847                                {
2848                                    id(dp0.main_datapipe): (
2849                                        dp0.main_datapipe,
2850                                        {
2851                                            id(dp0.main_datapipe.main_datapipe): (
2852                                                dp0.main_datapipe.main_datapipe,
2853                                                {},
2854                                            )
2855                                        },
2856                                    )
2857                                },
2858                            )
2859                        },
2860                    ),
2861                    id(dp1_upd): (
2862                        dp1_upd,
2863                        {
2864                            id(dp1): (
2865                                dp1,
2866                                {
2867                                    id(dp1.main_datapipe): (
2868                                        dp1.main_datapipe,
2869                                        {
2870                                            id(dp1.main_datapipe.main_datapipe): (
2871                                                dp1.main_datapipe.main_datapipe,
2872                                                {},
2873                                            )
2874                                        },
2875                                    )
2876                                },
2877                            )
2878                        },
2879                    ),
2880                    id(dp2): (
2881                        dp2,
2882                        {
2883                            id(dp2.main_datapipe): (
2884                                dp2.main_datapipe,
2885                                {
2886                                    id(dp2.main_datapipe.main_datapipe): (
2887                                        dp2.main_datapipe.main_datapipe,
2888                                        {},
2889                                    )
2890                                },
2891                            )
2892                        },
2893                    ),
2894                },
2895            )
2896        }
2897        self.assertEqual(expected, graph)
2898
2899        dps = torch.utils.data.graph_settings.get_all_graph_pipes(graph)
2900        self.assertEqual(len(dps), 8)
2901        for _dp in [
2902            numbers_dp,
2903            dp0.main_datapipe,
2904            dp0,
2905            dp1,
2906            dp2,
2907            dp0_upd,
2908            dp1_upd,
2909            combined_dp,
2910        ]:
2911            self.assertTrue(_dp in dps)
2912
2913    def test_traverse_mapdatapipe(self):
2914        source_dp = dp.map.SequenceWrapper(range(10))
2915        map_dp = source_dp.map(partial(_fake_add, 1))
2916        graph = traverse_dps(map_dp)
2917        expected: Dict[Any, Any] = {
2918            id(map_dp): (map_dp, {id(source_dp): (source_dp, {})})
2919        }
2920        self.assertEqual(expected, graph)
2921
2922    def test_traverse_mixdatapipe(self):
2923        source_map_dp = dp.map.SequenceWrapper(range(10))
2924        iter_dp = dp.iter.IterableWrapper(source_map_dp)
2925        graph = traverse_dps(iter_dp)
2926        expected: Dict[Any, Any] = {
2927            id(iter_dp): (iter_dp, {id(source_map_dp): (source_map_dp, {})})
2928        }
2929        self.assertEqual(expected, graph)
2930
2931    def test_traverse_circular_datapipe(self):
2932        source_iter_dp = dp.iter.IterableWrapper(list(range(10)))
2933        circular_dp = TestGraph.CustomIterDataPipe(source_iter_dp)
2934        graph = traverse_dps(circular_dp)
2935        # See issue: https://github.com/pytorch/data/issues/535
2936        expected: Dict[Any, Any] = {
2937            id(circular_dp): (
2938                circular_dp,
2939                {
2940                    id(circular_dp._dp): (
2941                        circular_dp._dp,
2942                        {id(source_iter_dp): (source_iter_dp, {})},
2943                    )
2944                },
2945            )
2946        }
2947        self.assertEqual(expected, graph)
2948
2949        dps = torch.utils.data.graph_settings.get_all_graph_pipes(graph)
2950        self.assertEqual(len(dps), 3)
2951        for _dp in [circular_dp, circular_dp._dp, source_iter_dp]:
2952            self.assertTrue(_dp in dps)
2953
2954    def test_traverse_unhashable_datapipe(self):
2955        source_iter_dp = dp.iter.IterableWrapper(list(range(10)))
2956        unhashable_dp = TestGraph.CustomIterDataPipe(source_iter_dp)
2957        graph = traverse_dps(unhashable_dp)
2958        with self.assertRaises(NotImplementedError):
2959            hash(unhashable_dp)
2960        expected: Dict[Any, Any] = {
2961            id(unhashable_dp): (
2962                unhashable_dp,
2963                {
2964                    id(unhashable_dp._dp): (
2965                        unhashable_dp._dp,
2966                        {id(source_iter_dp): (source_iter_dp, {})},
2967                    )
2968                },
2969            )
2970        }
2971        self.assertEqual(expected, graph)
2972
2973
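# `unbatch` is used below as `collate_fn` with `batch_size=1`: the DataLoader wraps
# each sample in a single-element batch, and `unbatch` unwraps it again.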
2974def unbatch(x):
2975    return x[0]
2976
2977
2978class TestSerialization(TestCase):
2979    @skipIfNoDill
2980    def test_spawn_lambdas_iter(self):
2981        idp = dp.iter.IterableWrapper(range(3)).map(lambda x: x + 1).shuffle()
2982        dl = DataLoader(
2983            idp,
2984            num_workers=2,
2985            shuffle=True,
2986            multiprocessing_context="spawn",
2987            collate_fn=unbatch,
2988            batch_size=1,
2989        )
2990        result = list(dl)
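        # With two workers and no sharding_filter in the pipeline, each worker
        # iterates the full 3-element pipe, so every value appears twice.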
2991        self.assertEqual([1, 1, 2, 2, 3, 3], sorted(result))
2992
2993    @skipIfNoDill
2994    def test_spawn_lambdas_map(self):
2995        mdp = dp.map.SequenceWrapper(range(3)).map(lambda x: x + 1).shuffle()
2996        dl = DataLoader(
2997            mdp,
2998            num_workers=2,
2999            shuffle=True,
3000            multiprocessing_context="spawn",
3001            collate_fn=unbatch,
3002            batch_size=1,
3003        )
3004        result = list(dl)
3005        self.assertEqual([1, 1, 2, 2, 3, 3], sorted(result))
3006
3007
3008class TestCircularSerialization(TestCase):
3009    class CustomIterDataPipe(IterDataPipe):
3010        @staticmethod
3011        def add_one(x):
3012            return x + 1
3013
3014        @classmethod
3015        def classify(cls, x):
3016            return 0
3017
3018        def add_v(self, x):
3019            return x + self.v
3020
3021        def __init__(self, fn, source_dp=None):
3022            self.fn = fn
3023            self.source_dp = (
3024                source_dp if source_dp else dp.iter.IterableWrapper([1, 2, 4])
3025            )
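            # `self._dp` captures the bound methods `self.add_one` and `self.add_v`,
            # creating a reference cycle (self -> self._dp -> bound method -> self)
            # that the circular-serialization tests below exercise.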
3026            self._dp = (
3027                self.source_dp.map(self.add_one)
3028                .map(self.add_v)
3029                .demux(2, self.classify)[0]
3030            )
3031            self.v = 1
3032
3033        def __iter__(self):
3034            yield from self._dp
3035
3036    def test_circular_serialization_with_pickle(self):
3037        # Test for circular reference issue with pickle
3038        dp1 = TestCircularSerialization.CustomIterDataPipe(fn=_fake_fn)
3039        self.assertTrue(list(dp1) == list(pickle.loads(pickle.dumps(dp1))))
3040
3041        child_1 = dp1._dp
3042        dm_1 = child_1.main_datapipe
3043        m2_1 = dm_1.main_datapipe
3044        m1_1 = m2_1.datapipe
3045        src_1 = m1_1.datapipe
3046
3047        res1 = traverse_dps(dp1)
3048        exp_res_1 = {
3049            id(dp1): (
3050                dp1,
3051                {
3052                    id(src_1): (src_1, {}),
3053                    id(child_1): (
3054                        child_1,
3055                        {
3056                            id(dm_1): (
3057                                dm_1,
3058                                {
3059                                    id(m2_1): (
3060                                        m2_1,
3061                                        {id(m1_1): (m1_1, {id(src_1): (src_1, {})})},
3062                                    )
3063                                },
3064                            )
3065                        },
3066                    ),
3067                },
3068            )
3069        }
3070        self.assertEqual(res1, exp_res_1)
3071        dp2 = TestCircularSerialization.CustomIterDataPipe(fn=_fake_fn, source_dp=dp1)
3072        self.assertTrue(list(dp2) == list(pickle.loads(pickle.dumps(dp2))))
3073
3074        child_2 = dp2._dp
3075        dm_2 = child_2.main_datapipe
3076        m2_2 = dm_2.main_datapipe
3077        m1_2 = m2_2.datapipe
3078
3079        res2 = traverse_dps(dp2)
3080        exp_res_2 = {
3081            id(dp2): (
3082                dp2,
3083                {
3084                    id(dp1): (
3085                        dp1,
3086                        {
3087                            id(src_1): (src_1, {}),
3088                            id(child_1): (
3089                                child_1,
3090                                {
3091                                    id(dm_1): (
3092                                        dm_1,
3093                                        {
3094                                            id(m2_1): (
3095                                                m2_1,
3096                                                {
3097                                                    id(m1_1): (
3098                                                        m1_1,
3099                                                        {id(src_1): (src_1, {})},
3100                                                    )
3101                                                },
3102                                            )
3103                                        },
3104                                    )
3105                                },
3106                            ),
3107                        },
3108                    ),
3109                    id(child_2): (
3110                        child_2,
3111                        {
3112                            id(dm_2): (
3113                                dm_2,
3114                                {
3115                                    id(m2_2): (
3116                                        m2_2,
3117                                        {
3118                                            id(m1_2): (
3119                                                m1_2,
3120                                                {
3121                                                    id(dp1): (
3122                                                        dp1,
3123                                                        {
3124                                                            id(src_1): (src_1, {}),
3125                                                            id(child_1): (
3126                                                                child_1,
3127                                                                {
3128                                                                    id(dm_1): (
3129                                                                        dm_1,
3130                                                                        {
3131                                                                            id(m2_1): (
3132                                                                                m2_1,
3133                                                                                {
3134                                                                                    id(
3135                                                                                        m1_1
3136                                                                                    ): (
3137                                                                                        m1_1,
3138                                                                                        {
3139                                                                                            id(
3140                                                                                                src_1
3141                                                                                            ): (
3142                                                                                                src_1,
3143                                                                                                {},
3144                                                                                            )
3145                                                                                        },
3146                                                                                    )
3147                                                                                },
3148                                                                            )
3149                                                                        },
3150                                                                    )
3151                                                                },
3152                                                            ),
3153                                                        },
3154                                                    ),
3155                                                },
3156                                            )
3157                                        },
3158                                    )
3159                                },
3160                            )
3161                        },
3162                    ),
3163                },
3164            )
3165        }
3166        self.assertEqual(res2, exp_res_2)
3167
3168    class LambdaIterDataPipe(CustomIterDataPipe):
3169        def __init__(self, fn, source_dp=None):
3170            super().__init__(fn, source_dp)
3171            self.container = [
3172                lambda x: x + 1,
3173            ]
3174            self.lambda_fn = lambda x: x + 1
3175            self._dp = (
3176                self.source_dp.map(self.add_one)
3177                .map(self.lambda_fn)
3178                .map(self.add_v)
3179                .demux(2, self.classify)[0]
3180            )
3181
3182    @skipIfNoDill
3183    @skipIf(True, "Dill Tests")
3184    def test_circular_serialization_with_dill(self):
3185        # Test for circular reference issue with dill
3186        dp1 = TestCircularSerialization.LambdaIterDataPipe(lambda x: x + 1)
3187        self.assertTrue(list(dp1) == list(dill.loads(dill.dumps(dp1))))
3188
3189        child_1 = dp1._dp
3190        dm_1 = child_1.main_datapipe
3191        m2_1 = dm_1.main_datapipe
3192        m1_1 = m2_1.datapipe
3193        src_1 = m1_1.datapipe
3194
3195        res1 = traverse_dps(dp1)
3196
3197        exp_res_1 = {
3198            id(dp1): (
3199                dp1,
3200                {
3201                    id(src_1): (src_1, {}),
3202                    id(child_1): (
3203                        child_1,
3204                        {
3205                            id(dm_1): (
3206                                dm_1,
3207                                {
3208                                    id(m2_1): (
3209                                        m2_1,
3210                                        {id(m1_1): (m1_1, {id(src_1): (src_1, {})})},
3211                                    )
3212                                },
3213                            )
3214                        },
3215                    ),
3216                },
3217            )
3218        }
3219
3220        self.assertEqual(res1, exp_res_1)
3221
3222        dp2 = TestCircularSerialization.LambdaIterDataPipe(fn=_fake_fn, source_dp=dp1)
3223        self.assertTrue(list(dp2) == list(dill.loads(dill.dumps(dp2))))
3224
3225        child_2 = dp2._dp
3226        dm_2 = child_2.main_datapipe
3227        m2_2 = dm_2.main_datapipe
3228        m1_2 = m2_2.datapipe
3229
3230        res2 = traverse_dps(dp2)
3231        exp_res_2 = {
3232            id(dp2): (
3233                dp2,
3234                {
3235                    id(dp1): (
3236                        dp1,
3237                        {
3238                            id(src_1): (src_1, {}),
3239                            id(child_1): (
3240                                child_1,
3241                                {
3242                                    id(dm_1): (
3243                                        dm_1,
3244                                        {
3245                                            id(m2_1): (
3246                                                m2_1,
3247                                                {
3248                                                    id(m1_1): (
3249                                                        m1_1,
3250                                                        {id(src_1): (src_1, {})},
3251                                                    )
3252                                                },
3253                                            )
3254                                        },
3255                                    )
3256                                },
3257                            ),
3258                        },
3259                    ),
3260                    id(child_2): (
3261                        child_2,
3262                        {
3263                            id(dm_2): (
3264                                dm_2,
3265                                {
3266                                    id(m2_2): (
3267                                        m2_2,
3268                                        {
3269                                            id(m1_2): (
3270                                                m1_2,
3271                                                {
3272                                                    id(dp1): (
3273                                                        dp1,
3274                                                        {
3275                                                            id(src_1): (src_1, {}),
3276                                                            id(child_1): (
3277                                                                child_1,
3278                                                                {
3279                                                                    id(dm_1): (
3280                                                                        dm_1,
3281                                                                        {
3282                                                                            id(m2_1): (
3283                                                                                m2_1,
3284                                                                                {
3285                                                                                    id(
3286                                                                                        m1_1
3287                                                                                    ): (
3288                                                                                        m1_1,
3289                                                                                        {
3290                                                                                            id(
3291                                                                                                src_1
3292                                                                                            ): (
3293                                                                                                src_1,
3294                                                                                                {},
3295                                                                                            )
3296                                                                                        },
3297                                                                                    )
3298                                                                                },
3299                                                                            )
3300                                                                        },
3301                                                                    )
3302                                                                },
3303                                                            ),
3304                                                        },
3305                                                    ),
3306                                                },
3307                                            )
3308                                        },
3309                                    )
3310                                },
3311                            )
3312                        },
3313                    ),
3314                },
3315            )
3316        }
3317        self.assertEqual(res2, exp_res_2)
3318
3319
3320class CustomShardingIterDataPipe(IterDataPipe):
3321    def __init__(self, dp):
3322        self.dp = dp
3323        self.num_of_instances = 1
3324        self.instance_id = 0
3325
3326    def apply_sharding(self, num_of_instances, instance_id):
3327        self.num_of_instances = num_of_instances
3328        self.instance_id = instance_id
3329
3330    def __iter__(self):
3331        for i, d in enumerate(self.dp):
3332            if i % self.num_of_instances == self.instance_id:
3333                yield d
3334
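
# Hedged usage sketch (added for illustration, not part of the original suite):
# a legacy DataPipe that exposes `apply_sharding(num_of_instances, instance_id)`,
# like the class above, keeps every `num_of_instances`-th element starting at
# position `instance_id`.
def _custom_sharding_example():
    pipe = CustomShardingIterDataPipe(dp.iter.IterableWrapper(range(6)))
    pipe.apply_sharding(2, 1)
    return list(pipe)  # -> [1, 3, 5]
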
3335
3336class TestSharding(TestCase):
3337    def _get_pipeline(self):
3338        numbers_dp = NumbersDataset(size=10)
3339        dp0, dp1 = numbers_dp.fork(num_instances=2)
3340        dp0_upd = dp0.map(_mul_10)
3341        dp1_upd = dp1.filter(_mod_3_test)
3342        combined_dp = dp0_upd.mux(dp1_upd)
3343        return combined_dp
3344
3345    def _get_dill_pipeline(self):
3346        numbers_dp = NumbersDataset(size=10)
3347        dp0, dp1 = numbers_dp.fork(num_instances=2)
3348        dp0_upd = dp0.map(lambda x: x * 10)
3349        dp1_upd = dp1.filter(lambda x: x % 3 == 1)
3350        combined_dp = dp0_upd.mux(dp1_upd)
3351        return combined_dp
3352
3353    def test_simple_sharding(self):
3354        sharded_dp = self._get_pipeline().sharding_filter()
3355        torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, 1)
3356        items = list(sharded_dp)
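        # `_get_pipeline()` yields [0, 1, 10, 4, 20, 7]; shard 1 of 3 keeps the
        # elements at positions 1 and 4 of that stream, i.e. [1, 20].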
3357        self.assertEqual([1, 20], items)
3358
3359        all_items = [0, 1, 10, 4, 20, 7]
3360        items = []
3361        for i in range(3):
3362            sharded_dp = self._get_pipeline().sharding_filter()
3363            torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, i)
3364            items += list(sharded_dp)
3365        self.assertEqual(sorted(all_items), sorted(items))
3366
3367    def test_sharding_groups(self):
3368        def construct_sharded_pipe():
3369            sharding_pipes = []
3370            dp = NumbersDataset(size=90)
3371            dp = dp.sharding_filter(
3372                sharding_group_filter=SHARDING_PRIORITIES.DISTRIBUTED
3373            )
3374            sharding_pipes.append(dp)
3375            dp = dp.sharding_filter(
3376                sharding_group_filter=SHARDING_PRIORITIES.MULTIPROCESSING
3377            )
3378            sharding_pipes.append(dp)
3379            dp = dp.sharding_filter(sharding_group_filter=300)
3380            sharding_pipes.append(dp)
3381            return dp, sharding_pipes
3382
3383        dp, sharding_pipes = construct_sharded_pipe()
3384
3385        for pipe in sharding_pipes:
3386            pipe.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DISTRIBUTED)
3387            pipe.apply_sharding(
3388                5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING
3389            )
3390            pipe.apply_sharding(3, 1, sharding_group=300)
3391
3392        actual = list(dp)
3393        expected = [17, 47, 77]
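        # From 90 numbers, the DISTRIBUTED filter (2 shards, id 1) keeps 1, 3, ..., 89;
        # the MULTIPROCESSING filter (5 shards, id 3) then keeps 7, 17, 27, ..., 87;
        # the group-300 filter (3 shards, id 1) finally keeps 17, 47, 77.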
3394        self.assertEqual(expected, actual)
3395        self.assertEqual(3, len(dp))
3396
3397        dp, _ = construct_sharded_pipe()
3398        dp.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DEFAULT)
3399        with self.assertRaises(Exception):
3400            dp.apply_sharding(5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING)
3401
3402        dp, _ = construct_sharded_pipe()
3403        dp.apply_sharding(5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING)
3404        with self.assertRaises(Exception):
3405            dp.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DEFAULT)
3406
3407    # Test tud.datapipes.iter.grouping.SHARDING_PRIORITIES for backward compatibility
3408    # TODO: Remove this test once tud.datapipes.iter.grouping.SHARDING_PRIORITIES is deprecated
3409    def test_sharding_groups_in_legacy_grouping_package(self):
3410        with self.assertWarnsRegex(
3411            FutureWarning,
3412            r"Please use `SHARDING_PRIORITIES` "
3413            "from the `torch.utils.data.datapipes.iter.sharding`",
3414        ):
3415            from torch.utils.data.datapipes.iter.grouping import (
3416                SHARDING_PRIORITIES as LEGACY_SHARDING_PRIORITIES,
3417            )
3418
3419        def construct_sharded_pipe():
3420            sharding_pipes = []
3421            dp = NumbersDataset(size=90)
3422            dp = dp.sharding_filter(
3423                sharding_group_filter=LEGACY_SHARDING_PRIORITIES.DISTRIBUTED
3424            )
3425            sharding_pipes.append(dp)
3426            dp = dp.sharding_filter(
3427                sharding_group_filter=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
3428            )
3429            sharding_pipes.append(dp)
3430            dp = dp.sharding_filter(sharding_group_filter=300)
3431            sharding_pipes.append(dp)
3432            return dp, sharding_pipes
3433
3434        dp, sharding_pipes = construct_sharded_pipe()
3435
3436        for pipe in sharding_pipes:
3437            pipe.apply_sharding(
3438                2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DISTRIBUTED
3439            )
3440            pipe.apply_sharding(
3441                5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
3442            )
3443            pipe.apply_sharding(3, 1, sharding_group=300)
3444
3445        actual = list(dp)
3446        expected = [17, 47, 77]
3447        self.assertEqual(expected, actual)
3448        self.assertEqual(3, len(dp))
3449
3450        dp, _ = construct_sharded_pipe()
3451        dp.apply_sharding(2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DEFAULT)
3452        with self.assertRaises(Exception):
3453            dp.apply_sharding(
3454                5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
3455            )
3456
3457        dp, _ = construct_sharded_pipe()
3458        dp.apply_sharding(
3459            5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
3460        )
3461        with self.assertRaises(Exception):
3462            dp.apply_sharding(2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DEFAULT)
3463
3464    def test_legacy_custom_sharding(self):
3465        dp = self._get_pipeline()
3466        sharded_dp = CustomShardingIterDataPipe(dp)
3467        torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, 1)
3468        items = list(sharded_dp)
3469        self.assertEqual([1, 20], items)
3470
3471    def test_sharding_length(self):
3472        numbers_dp = dp.iter.IterableWrapper(range(13))
3473        sharded_dp0 = numbers_dp.sharding_filter()
3474        torch.utils.data.graph_settings.apply_sharding(sharded_dp0, 3, 0)
3475        sharded_dp1 = numbers_dp.sharding_filter()
3476        torch.utils.data.graph_settings.apply_sharding(sharded_dp1, 3, 1)
3477        sharded_dp2 = numbers_dp.sharding_filter()
3478        torch.utils.data.graph_settings.apply_sharding(sharded_dp2, 3, 2)
3479        self.assertEqual(13, len(numbers_dp))
3480        self.assertEqual(5, len(sharded_dp0))
3481        self.assertEqual(4, len(sharded_dp1))
3482        self.assertEqual(4, len(sharded_dp2))
3483
3484        numbers_dp = dp.iter.IterableWrapper(range(1))
3485        sharded_dp0 = numbers_dp.sharding_filter()
3486        torch.utils.data.graph_settings.apply_sharding(sharded_dp0, 2, 0)
3487        sharded_dp1 = numbers_dp.sharding_filter()
3488        torch.utils.data.graph_settings.apply_sharding(sharded_dp1, 2, 1)
3489        self.assertEqual(1, len(sharded_dp0))
3490        self.assertEqual(0, len(sharded_dp1))
3491
3492    def test_old_dataloader(self):
3493        dp0 = self._get_pipeline()
3494        expected = list(dp0)
3495
3496        dp0 = self._get_pipeline().sharding_filter()
3497        dl = DataLoader(dp0, batch_size=1, shuffle=False, num_workers=2)
3498        items = list(dl)
3499
3500        self.assertEqual(sorted(expected), sorted(items))
3501
3502    def test_legacy_custom_sharding_with_old_dataloader(self):
3503        dp0 = self._get_pipeline()
3504        expected = list(dp0)
3505
3506        dp0 = self._get_pipeline()
3507        dp0 = CustomShardingIterDataPipe(dp0)
3508        dl = DataLoader(dp0, batch_size=1, shuffle=False, num_workers=2)
3509        items = list(dl)
3510
3511        self.assertEqual(sorted(expected), sorted(items))
3512
3513    def test_multi_sharding(self):
3514        # Raises an error when sharding is applied multiple times on a single branch
3515        numbers_dp = dp.iter.IterableWrapper(range(13))
3516        sharded_dp = numbers_dp.sharding_filter()
3517        sharded_dp = sharded_dp.sharding_filter()
3518        with self.assertRaisesRegex(
3519            RuntimeError, "Sharding twice on a single pipeline"
3520        ):
3521            torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, 0)
3522
3523        # Raises an error when sharding is applied on both the data source and a branch
3524        numbers_dp = dp.iter.IterableWrapper(range(13)).sharding_filter()
3525        dp1, dp2 = numbers_dp.fork(2)
3526        sharded_dp = dp1.sharding_filter()
3527        zip_dp = dp2.zip(sharded_dp)
3528        with self.assertRaisesRegex(
3529            RuntimeError, "Sharding twice on a single pipeline"
3530        ):
3531            torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0)
3532
3533        # Raises an error when sharding is applied on both a branch and the end of the pipeline
3534        numbers_dp = dp.iter.IterableWrapper(range(13))
3535        dp1, dp2 = numbers_dp.fork(2)
3536        sharded_dp = dp1.sharding_filter()
3537        zip_dp = dp2.zip(sharded_dp).sharding_filter()
3538        with self.assertRaisesRegex(
3539            RuntimeError, "Sharding twice on a single pipeline"
3540        ):
3541            torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0)
3542
3543        # Single sharding_filter on data source
3544        numbers_dp = dp.iter.IterableWrapper(range(13)).sharding_filter()
3545        dp1, dp2 = numbers_dp.fork(2)
3546        zip_dp = dp1.zip(dp2)
3547        torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0)
3548        self.assertEqual(list(zip_dp), [(i * 3, i * 3) for i in range(13 // 3 + 1)])
3549
3550        # Single sharding_filter per branch
3551        numbers_dp = dp.iter.IterableWrapper(range(13))
3552        dp1, dp2 = numbers_dp.fork(2)
3553        sharded_dp1 = dp1.sharding_filter()
3554        sharded_dp2 = dp2.sharding_filter()
3555        zip_dp = sharded_dp1.zip(sharded_dp2)
3556        torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0)
3557        self.assertEqual(list(zip_dp), [(i * 3, i * 3) for i in range(13 // 3 + 1)])
3558
3559
3560class TestIterDataPipeSingletonConstraint(TestCase):
3561    r"""
3562    Each `IterDataPipe` can only have one active iterator. Whenever a new iterator is created, older
3563    iterators are invalidated. These tests aim to ensure `IterDataPipe` follows this behavior.
3564    """
3565
3566    def _check_single_iterator_invalidation_logic(self, source_dp: IterDataPipe):
3567        r"""
3568        Given an IterDataPipe, verifies that its iterator can be read and reset, and that
3569        the creation of a second iterator invalidates the first one.
3570        """
3571        it1 = iter(source_dp)
3572        self.assertEqual(list(range(10)), list(it1))
3573        it1 = iter(source_dp)
3574        self.assertEqual(
3575            list(range(10)), list(it1)
3576        )  # A fresh iterator can be read in full again
3577        it1 = iter(source_dp)
3578        self.assertEqual(0, next(it1))
3579        it2 = iter(source_dp)  # This should invalidate `it1`
3580        self.assertEqual(0, next(it2))  # Should read from the beginning again
3581        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3582            next(it1)
3583
3584    def test_iterdatapipe_singleton_generator(self):
3585        r"""
3586        Testing for the case where IterDataPipe's `__iter__` is a generator function.
3587        """
3588
3589        # Functional Test: Check if invalidation logic is correct
3590        source_dp: IterDataPipe = dp.iter.IterableWrapper(range(10))
3591        self._check_single_iterator_invalidation_logic(source_dp)
3592
3593        # Functional Test: extend the test to a pipeline
3594        dps = source_dp.map(_fake_fn).filter(_fake_filter_fn)
3595        self._check_single_iterator_invalidation_logic(dps)
3596
3597        # Functional Test: multiple simultaneous references to the same DataPipe fails
3598        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3599            for _ in zip(source_dp, source_dp):
3600                pass
3601
3602        # Functional Test: sequential references work
3603        for _ in zip(list(source_dp), list(source_dp)):
3604            pass
3605
3606    def test_iterdatapipe_singleton_self_next(self):
3607        r"""
3608        Testing for the case where IterDataPipe's `__iter__` returns `self` and there is a `__next__` method.
3609        Note that the following DataPipe is a singleton by default (because `__iter__` returns `self`).
3610        """
3611
3612        class _CustomIterDP_Self(IterDataPipe):
3613            def __init__(self, iterable):
3614                self.source = iterable
3615                self.iterable = iter(iterable)
3616
3617            def __iter__(self):
3618                self.reset()
3619                return self
3620
3621            def __next__(self):
3622                return next(self.iterable)
3623
3624            def reset(self):
3625                self.iterable = iter(self.source)
3626
3627        # Functional Test: Check that every `__iter__` call returns the same object
3628        source_dp = _CustomIterDP_Self(range(10))
3629        res = list(source_dp)
3630        it = iter(source_dp)
3631        self.assertEqual(res, list(it))
3632
3633        # Functional Test: Check if invalidation logic is correct
3634        source_dp = _CustomIterDP_Self(range(10))
3635        self._check_single_iterator_invalidation_logic(source_dp)
3636        self.assertEqual(
3637            1, next(source_dp)
3638        )  # `source_dp` is still valid and can be read
3639
3640        # Functional Test: extend the test to a pipeline
3641        source_dp = _CustomIterDP_Self(
3642            dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn)
3643        )
3644        self._check_single_iterator_invalidation_logic(source_dp)
3645        self.assertEqual(
3646            1, next(source_dp)
3647        )  # `source_dp` is still valid and can be read
3648
3649        # Functional Test: multiple simultaneous references to the same DataPipe fails
3650        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3651            for _ in zip(source_dp, source_dp):
3652                pass
3653
3654    def test_iterdatapipe_singleton_new_object(self):
3655        r"""
3656        Testing for the case where IterDataPipe's `__iter__` is not a generator function and does not return `self`,
3657        and there isn't a `__next__` method.
3658        """
3659
3660        class _CustomIterDP(IterDataPipe):
3661            def __init__(self, iterable):
3662                self.iterable = iter(iterable)
3663
3664            def __iter__(self):  # Note that this doesn't reset
3665                return self.iterable  # Intentionally not returning `self`
3666
3667        # Functional Test: Check if invalidation logic is correct
3668        source_dp = _CustomIterDP(range(10))
3669        it1 = iter(source_dp)
3670        self.assertEqual(0, next(it1))
3671        it2 = iter(source_dp)
3672        self.assertEqual(1, next(it2))
3673        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3674            next(it1)
3675
3676        # Functional Test: extend the test to a pipeline
3677        source_dp = _CustomIterDP(
3678            dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn)
3679        )
3680        it1 = iter(source_dp)
3681        self.assertEqual(0, next(it1))
3682        it2 = iter(source_dp)
3683        self.assertEqual(1, next(it2))
3684        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3685            next(it1)
3686
3687        # Functional Test: multiple simultaneous references to the same DataPipe fails
3688        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3689            for _ in zip(source_dp, source_dp):
3690                pass
3691
3692    def test_iterdatapipe_singleton_buggy(self):
3693        r"""
3694        Buggy test case where IterDataPipe's `__iter__` returns a new object, but also has
3695        a `__next__` method.
3696        """
3697
3698        class _CustomIterDP(IterDataPipe):
3699            def __init__(self, iterable):
3700                self.source = iterable
3701                self.iterable = iter(iterable)
3702
3703            def __iter__(self):
3704                return iter(self.source)  # Intentionally not returning `self`
3705
3706            def __next__(self):
3707                return next(self.iterable)
3708
3709        # Functional Test: Check if invalidation logic is correct
3710        source_dp = _CustomIterDP(range(10))
3711        self._check_single_iterator_invalidation_logic(source_dp)
3712        self.assertEqual(0, next(source_dp))  # `__next__` is unrelated to `__iter__`
3713
3714        # Functional Test: Special case to show `__next__` is unrelated to `__iter__`
3715        source_dp = _CustomIterDP(range(10))
3716        self.assertEqual(0, next(source_dp))
3717        it1 = iter(source_dp)
3718        self.assertEqual(0, next(it1))
3719        self.assertEqual(1, next(source_dp))
3720        it2 = iter(source_dp)  # invalidates `it1`
3721        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3722            next(it1)
3723        self.assertEqual(2, next(source_dp))  # not impacted by the creation of `it2`
3724        self.assertEqual(
3725            list(range(10)), list(it2)
3726        )  # `it2` still works because it is a new object
3727
3728    def test_iterdatapipe_singleton_constraint_multiple_outputs(self):
3729        r"""
3730        Testing for the case where IterDataPipe has multiple child DataPipes as outputs.
3731        """
3732        # Functional Test: all previous related iterators should be invalidated when a new iterator
3733        #                  is created from a ChildDataPipe
3734        source_dp: IterDataPipe = dp.iter.IterableWrapper(range(10))
3735        cdp1, cdp2 = source_dp.fork(num_instances=2)
3736        it1, it2 = iter(cdp1), iter(cdp2)
3737        self.assertEqual(list(range(10)), list(it1))
3738        self.assertEqual(list(range(10)), list(it2))
3739        it1, it2 = iter(cdp1), iter(cdp2)
3740        with warnings.catch_warnings(record=True) as wa:
3741            it3 = iter(cdp1)  # This should invalidate `it1` and `it2`
3742            self.assertEqual(len(wa), 1)
3743            self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")
3744        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3745            next(it1)
3746        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3747            next(it2)
3748        self.assertEqual(0, next(it3))
3749        # The next line should not invalidate anything, as there was no new iterator created
3750        # for `cdp2` after `it2` was invalidated
3751        it4 = iter(cdp2)
3752        self.assertEqual(1, next(it3))  # An error shouldn't be raised here
3753        self.assertEqual(list(range(10)), list(it4))
3754
3755        # Functional Test: invalidation when a new iterator is created from `source_dp`
3756        source_dp = dp.iter.IterableWrapper(range(10))
3757        cdp1, cdp2 = source_dp.fork(num_instances=2)
3758        it1, it2 = iter(cdp1), iter(cdp2)
3759        self.assertEqual(list(range(10)), list(it1))
3760        self.assertEqual(list(range(10)), list(it2))
3761        it1, it2 = iter(cdp1), iter(cdp2)
3762        self.assertEqual(0, next(it1))
3763        self.assertEqual(0, next(it2))
3764        it3 = iter(source_dp)  # note that a new iterator is created from `source_dp`
3765        self.assertEqual(
3766            0, next(it3)
3767        )  # `it3` should invalidate `it1` and `it2` since they both use `source_dp`
3768        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3769            next(it1)
3770        self.assertEqual(1, next(it3))
3771
3772        # Functional Test: Extending the test to a pipeline
3773        source_dp = (
3774            dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn)
3775        )
3776        cdp1, cdp2 = source_dp.fork(num_instances=2)
3777        it1, it2 = iter(cdp1), iter(cdp2)
3778        self.assertEqual(list(range(10)), list(it1))
3779        self.assertEqual(list(range(10)), list(it2))
3780        it1, it2 = iter(cdp1), iter(cdp2)
3781        with warnings.catch_warnings(record=True) as wa:
3782            it3 = iter(cdp1)  # This should invalidate `it1` and `it2`
3783            self.assertEqual(len(wa), 1)
3784            self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")
3785        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3786            next(it1)
3787        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3788            next(it2)
3789        with warnings.catch_warnings(record=True) as wa:
3790            it1, it2 = iter(cdp1), iter(cdp2)
3791            self.assertEqual(len(wa), 1)
3792            self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")
3793        self.assertEqual(0, next(it1))
3794        self.assertEqual(0, next(it2))
3795        it3 = iter(source_dp)  # note that a new iterator is created from `source_dp`
3796        self.assertEqual(
3797            0, next(it3)
3798        )  # `it3` should invalidate `it1` and `it2` since they both use `source_dp`
3799        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
3800            next(it1)
3801        self.assertEqual(1, next(it3))
3802
3803
3804class TestIterDataPipeCountSampleYielded(TestCase):
3805    def _yield_count_test_helper(self, datapipe, n_expected_samples):
3806        # Functional Test: Check if number of samples yielded is as expected
3807        res = list(datapipe)
3808        self.assertEqual(len(res), datapipe._number_of_samples_yielded)
3809
3810        # Functional Test: Check if the count is correct when DataPipe is partially read
3811        it = iter(datapipe)
3812        res = []
3813        for i, value in enumerate(it):
3814            res.append(value)
3815            if i == n_expected_samples - 1:
3816                break
3817        self.assertEqual(n_expected_samples, datapipe._number_of_samples_yielded)
3818
3819        # Functional Test: Check for reset behavior and if iterator also works
3820        it = iter(datapipe)  # reset the DataPipe
3821        res = list(it)
3822        self.assertEqual(len(res), datapipe._number_of_samples_yielded)
3823
3824    def test_iterdatapipe_sample_yielded_generator_function(self):
3825        # Functional Test: `__iter__` is a generator function
3826        datapipe: IterDataPipe = dp.iter.IterableWrapper(range(10))
3827        self._yield_count_test_helper(datapipe, n_expected_samples=5)
3828
3829    def test_iterdatapipe_sample_yielded_generator_function_exception(self):
3830        # Functional Test: `__iter__` is a custom generator function with exception
3831        class _CustomGeneratorFnDataPipe(IterDataPipe):
3832            # This class's `__iter__` raises a RuntimeError partway through iteration
3833            def __iter__(self):
3834                yield 0
3835                yield 1
3836                yield 2
3837                raise RuntimeError("Custom test error after yielding 3 elements")
3838                yield 3
3839
3840        # Functional Test: Ensure the count is correct even when exception is raised
3841        datapipe: IterDataPipe = _CustomGeneratorFnDataPipe()
3842        with self.assertRaisesRegex(
3843            RuntimeError, "Custom test error after yielding 3 elements"
3844        ):
3845            list(datapipe)
3846        self.assertEqual(3, datapipe._number_of_samples_yielded)
3847
3848        # Functional Test: Check for reset behavior and if iterator also works
3849        it = iter(datapipe)  # reset the DataPipe
3850        with self.assertRaisesRegex(
3851            RuntimeError, "Custom test error after yielding 3 elements"
3852        ):
3853            list(it)
3854        self.assertEqual(3, datapipe._number_of_samples_yielded)
3855
3856    def test_iterdatapipe_sample_yielded_return_self(self):
3857        class _CustomGeneratorDataPipe(IterDataPipe):
3858            # This class's `__iter__` is not a generator function
3859            def __init__(self) -> None:
3860                self.source = iter(range(10))
3861
3862            def __iter__(self):
3863                return self.source
3864
3865            def reset(self):
3866                self.source = iter(range(10))
3867
3868        datapipe: IterDataPipe = _CustomGeneratorDataPipe()
3869        self._yield_count_test_helper(datapipe, n_expected_samples=5)
3870
3871    def test_iterdatapipe_sample_yielded_next(self):
3872        class _CustomNextDataPipe(IterDataPipe):
3873            # This class's `__iter__` returns `self` and has a `__next__`
3874            def __init__(self) -> None:
3875                self.source = iter(range(10))
3876
3877            def __iter__(self):
3878                return self
3879
3880            def __next__(self):
3881                return next(self.source)
3882
3883            def reset(self):
3884                self.source = iter(range(10))
3885
3886        datapipe: IterDataPipe = _CustomNextDataPipe()
3887        self._yield_count_test_helper(datapipe, n_expected_samples=5)
3888
3889    def test_iterdatapipe_sample_yielded_next_exception(self):
3890        class _CustomNextDataPipe(IterDataPipe):
3891            # This class's `__iter__` returns `self` and has a `__next__`
3892            def __init__(self) -> None:
3893                self.source = iter(range(10))
3894                self.count = 0
3895
3896            def __iter__(self):
3897                return self
3898
3899            def __next__(self):
3900                if self.count == 3:
3901                    raise RuntimeError("Custom test error after yielding 3 elements")
3902                self.count += 1
3903                return next(self.source)
3904
3905            def reset(self):
3906                self.count = 0
3907                self.source = iter(range(10))
3908
3909        # Functional Test: Ensure the count is correct even when exception is raised
3910        datapipe: IterDataPipe = _CustomNextDataPipe()
3911        with self.assertRaisesRegex(
3912            RuntimeError, "Custom test error after yielding 3 elements"
3913        ):
3914            list(datapipe)
3915        self.assertEqual(3, datapipe._number_of_samples_yielded)
3916
3917        # Functional Test: Check for reset behavior and if iterator also works
3918        it = iter(datapipe)  # reset the DataPipe
3919        with self.assertRaisesRegex(
3920            RuntimeError, "Custom test error after yielding 3 elements"
3921        ):
3922            list(it)
3923        self.assertEqual(3, datapipe._number_of_samples_yielded)
3924
3925
3926class _CustomNonGeneratorTestDataPipe(IterDataPipe):
3927    def __init__(self) -> None:
3928        self.n = 10
3929        self.source = list(range(self.n))
3930
3931    # This class's `__iter__` is not a generator function
3932    def __iter__(self):
3933        return iter(self.source)
3934
3935    def __len__(self):
3936        return self.n
3937
3938
3939class _CustomSelfNextTestDataPipe(IterDataPipe):
3940    def __init__(self) -> None:
3941        self.n = 10
3942        self.iter = iter(range(self.n))
3943
3944    def __iter__(self):
3945        return self
3946
3947    def __next__(self):
3948        return next(self.iter)
3949
3950    def reset(self):
3951        self.iter = iter(range(self.n))
3952
3953    def __len__(self):
3954        return self.n
3955
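
# Hedged sketch (added for illustration, not part of the original suite):
# `_simple_graph_snapshot_restoration(pipe, n)` fast-forwards `pipe` so that the
# next full iteration resumes after the first `n` samples, which is exactly what
# the helper below verifies for more elaborate graphs.
def _snapshot_restore_example():
    pipe = dp.iter.IterableWrapper(range(10))
    _simple_graph_snapshot_restoration(pipe, 3)
    return list(pipe)  # -> [3, 4, 5, 6, 7, 8, 9]
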
3956
3957class TestIterDataPipeGraphFastForward(TestCase):
3958    def _fast_forward_graph_test_helper(
3959        self, datapipe, fast_forward_fn, expected_res, n_iterations=3, rng=None
3960    ):
3961        if rng is None:
3962            rng = torch.Generator()
3963        rng = rng.manual_seed(0)
3964        torch.utils.data.graph_settings.apply_random_seed(datapipe, rng)
3965
3966        # Test Case: fast forward works with list
3967        rng.manual_seed(0)
3968        fast_forward_fn(datapipe, n_iterations, rng)
3969        actual_res = list(datapipe)
3970        self.assertEqual(len(datapipe) - n_iterations, len(actual_res))
3971        self.assertEqual(expected_res[n_iterations:], actual_res)
3972
3973        # Test Case: fast forward works with iterator
3974        rng.manual_seed(0)
3975        fast_forward_fn(datapipe, n_iterations, rng)
3976        it = iter(datapipe)
3977        actual_res = list(it)
3978        self.assertEqual(len(datapipe) - n_iterations, len(actual_res))
3979        self.assertEqual(expected_res[n_iterations:], actual_res)
3980        with self.assertRaises(StopIteration):
3981            next(it)
3982
3983    def test_simple_snapshot_graph(self):
3984        graph1 = dp.iter.IterableWrapper(range(10))
3985        res1 = list(range(10))
3986        self._fast_forward_graph_test_helper(
3987            graph1, _simple_graph_snapshot_restoration, expected_res=res1
3988        )
3989
3990        graph2 = graph1.map(_mul_10)
3991        res2 = [10 * x for x in res1]
3992        self._fast_forward_graph_test_helper(
3993            graph2, _simple_graph_snapshot_restoration, expected_res=res2
3994        )
3995
3996        rng = torch.Generator()
3997        graph3 = graph2.shuffle()
3998        rng.manual_seed(0)
3999        torch.utils.data.graph_settings.apply_random_seed(graph3, rng)
4000        res3 = list(graph3)
4001        self._fast_forward_graph_test_helper(
4002            graph3, _simple_graph_snapshot_restoration, expected_res=res3
4003        )
4004
4005        graph4 = graph3.map(_mul_10)
4006        res4 = [10 * x for x in res3]
4007        self._fast_forward_graph_test_helper(
4008            graph4, _simple_graph_snapshot_restoration, expected_res=res4
4009        )
4010
4011        batch_size = 2
4012        graph5 = graph4.batch(batch_size)
4013        res5 = [
4014            res4[i : i + batch_size] for i in range(0, len(res4), batch_size)
4015        ]  # .batch(2)
4016        self._fast_forward_graph_test_helper(
4017            graph5, _simple_graph_snapshot_restoration, expected_res=res5
4018        )
4019
4020        # With `fork` and `zip`
4021        cdp1, cdp2 = graph5.fork(2)
4022        graph6 = cdp1.zip(cdp2)
4023        rng = rng.manual_seed(100)
4024        torch.utils.data.graph_settings.apply_random_seed(graph6, rng)
4025        res6 = [(x, x) for x in res5]
4026        self._fast_forward_graph_test_helper(
4027            graph6, _simple_graph_snapshot_restoration, expected_res=res6
4028        )
4029
4030        # With `fork` and `concat`
4031        graph7 = cdp1.concat(cdp2)
4032        res7 = res5 * 2
4033        self._fast_forward_graph_test_helper(
4034            graph7, _simple_graph_snapshot_restoration, expected_res=res7
4035        )
4036
4037        # Raises an exception if the graph has already been restored
4038        with self.assertRaisesRegex(
4039            RuntimeError, "Snapshot restoration cannot be applied."
4040        ):
4041            _simple_graph_snapshot_restoration(graph7, 1)
4042            _simple_graph_snapshot_restoration(graph7, 1)
4043
4044    def test_simple_snapshot_custom_non_generator(self):
4045        graph = _CustomNonGeneratorTestDataPipe()
4046        self._fast_forward_graph_test_helper(
4047            graph, _simple_graph_snapshot_restoration, expected_res=range(10)
4048        )
4049
4050    def test_simple_snapshot_custom_self_next(self):
4051        graph = _CustomSelfNextTestDataPipe()
4052        self._fast_forward_graph_test_helper(
4053            graph, _simple_graph_snapshot_restoration, expected_res=range(10)
4054        )
4055
4056    def _snapshot_test_helper(self, datapipe, expected_res, n_iter=3, rng=None):
4057        """
4058        Extends the fast-forward tests above with a serialization/deserialization round trip.
4059        """
4060        if rng is None:
4061            rng = torch.Generator()
4062        rng.manual_seed(0)
4063        torch.utils.data.graph_settings.apply_random_seed(datapipe, rng)
4064        it = iter(datapipe)
4065        for _ in range(n_iter):
4066            next(it)
4067        serialized_graph = pickle.dumps(datapipe)
4068        deserialized_graph = pickle.loads(serialized_graph)
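        # The count of samples yielded so far should survive the pickle round trip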
4069        self.assertEqual(n_iter, datapipe._number_of_samples_yielded)
4070        self.assertEqual(n_iter, deserialized_graph._number_of_samples_yielded)
4071
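        # Restore the deserialized copy with a generator seeded identically to the one
        # above so any shuffles in the graph replay the same order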
4072        rng_for_deserialized = torch.Generator()
4073        rng_for_deserialized.manual_seed(0)
4074        _simple_graph_snapshot_restoration(
4075            deserialized_graph, n_iter, rng=rng_for_deserialized
4076        )
4077        self.assertEqual(expected_res[n_iter:], list(it))
4078        self.assertEqual(expected_res[n_iter:], list(deserialized_graph))
4079
4080    def test_simple_snapshot_graph_with_serialization(self):
4081        graph1 = dp.iter.IterableWrapper(range(10))
4082        res1 = list(range(10))
4083        self._snapshot_test_helper(graph1, expected_res=res1)
4084
4085        graph2 = graph1.map(_mul_10)
4086        res2 = [10 * x for x in res1]
4087        self._snapshot_test_helper(graph2, expected_res=res2)
4088
4089        rng = torch.Generator()
4090        graph3 = graph2.shuffle()
4091        rng.manual_seed(0)
4092        torch.utils.data.graph_settings.apply_random_seed(graph3, rng)
4093        res3 = list(graph3)
4094        self._snapshot_test_helper(graph3, expected_res=res3)
4095
4096        graph4 = graph3.map(_mul_10)
4097        res4 = [10 * x for x in res3]
4098        self._snapshot_test_helper(graph4, expected_res=res4)
4099
4100        batch_size = 2
4101        graph5 = graph4.batch(batch_size)
4102        res5 = [
4103            res4[i : i + batch_size] for i in range(0, len(res4), batch_size)
4104        ]  # .batch(2)
4105        self._snapshot_test_helper(graph5, expected_res=res5)
4106
4107        # With `fork` and `zip`
4108        cdp1, cdp2 = graph5.fork(2)
4109        graph6 = cdp1.zip(cdp2)
4110        res6 = [(x, x) for x in res5]
4111        self._snapshot_test_helper(graph6, expected_res=res6)
4112
4113        # With `fork` and `concat`
4114        graph7 = cdp1.concat(cdp2)
4115        res7 = res5 * 2
4116        self._snapshot_test_helper(graph7, expected_res=res7)
4117
4118    def test_simple_snapshot_graph_repeated(self):
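        # Serialize, restore, then serialize and restore once more to confirm that a
        # restored graph can itself be snapshotted and fast-forwarded again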
4119        cdp1, cdp2 = (
4120            dp.iter.IterableWrapper(range(10))
4121            .map(_mul_10)
4122            .shuffle()
4123            .map(_mul_10)
4124            .map(_mul_10)
4125            .fork(2)
4126        )
4127        graph = cdp1.zip(cdp2)
4128
4129        rng = torch.Generator()
4130        rng.manual_seed(0)
4131        torch.utils.data.graph_settings.apply_random_seed(graph, rng)
4132
4133        # Get expected result
4134        expected_res = list(graph)
4135
4136        rng.manual_seed(0)
4137        torch.utils.data.graph_settings.apply_random_seed(graph, rng)
4138        it = iter(graph)
4139        n_iter = 3
4140        for _ in range(n_iter):
4141            next(it)
4142
4143        # First serialization/deserialization
4144        serialized_graph = pickle.dumps(graph)
4145        deserialized_graph = pickle.loads(serialized_graph)
4146
4147        rng_for_deserialized = torch.Generator()
4148        rng_for_deserialized.manual_seed(0)
4149        _simple_graph_snapshot_restoration(
4150            deserialized_graph,
4151            deserialized_graph._number_of_samples_yielded,
4152            rng=rng_for_deserialized,
4153        )
4154
4155        it = iter(deserialized_graph)
4156        # Get the next element and ensure it is as expected
4157        self.assertEqual(expected_res[3], next(it))
4158
4159        # Serialize/deserialize and fast-forward again to ensure it still works
4160        serialized_graph2 = pickle.dumps(deserialized_graph)
4161        deserialized_graph2 = pickle.loads(serialized_graph2)
4162
4163        rng_for_deserialized = torch.Generator()
4164        rng_for_deserialized.manual_seed(0)
4165        _simple_graph_snapshot_restoration(
4166            deserialized_graph2,
4167            deserialized_graph._number_of_samples_yielded,  # == 4 here: 3 fast-forwarded samples plus the extra next(it) above
4168            rng=rng_for_deserialized,
4169        )
4170
4171        # Get the next element and ensure it is as expected
4172        self.assertEqual(expected_res[4:], list(deserialized_graph2))
4173
4174
4175if __name__ == "__main__":
4176    run_tests()
4177