xref: /aosp_15_r20/external/pytorch/test/test_dataloader.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dataloader"]
2
3import ctypes
4import errno
5import faulthandler
6import functools
7import gc
8import itertools
9import math
10import operator
11import os
12import signal
13import sys
14import tempfile
15import time
16import unittest
17import warnings
18
19import torch
20import torch.utils.data.datapipes as dp
21from torch import multiprocessing as mp
22from torch._utils import ExceptionWrapper
23from torch.testing._internal.common_device_type import instantiate_device_type_tests
24from torch.testing._internal.common_utils import (
25    IS_CI,
26    IS_JETSON,
27    IS_MACOS,
28    IS_SANDCASTLE,
29    IS_WINDOWS,
30    load_tests,
31    NO_MULTIPROCESSING_SPAWN,
32    parametrize,
33    run_tests,
34    skipIfNoDill,
35    skipIfRocm,
36    slowTest,
37    TEST_CUDA,
38    TEST_NUMPY,
39    TEST_WITH_ASAN,
40    TEST_WITH_ROCM,
41    TEST_WITH_TSAN,
42    TestCase,
43)
44from torch.utils.data import (
45    _utils,
46    ChainDataset,
47    ConcatDataset,
48    DataLoader,
49    Dataset,
50    IterableDataset,
51    IterDataPipe,
52    StackDataset,
53    Subset,
54    TensorDataset,
55)
56from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
57from torch.utils.data.datapipes.iter import IterableWrapper
58from torch.utils.data.dataset import random_split
59
60
61try:
62    import psutil
63
64    HAS_PSUTIL = True
65except ModuleNotFoundError:
66    HAS_PSUTIL = False
67    psutil = None
68    err_msg = (
69        "psutil not found. Some critical data loader tests relying on it "
70        "(e.g., TestDataLoader.test_proper_exit) will not run."
71    )
72    if IS_CI:
73        raise ModuleNotFoundError(err_msg) from None
74    else:
75        warnings.warn(err_msg)
76
77
78try:
79    import numpy as np
80
81    HAS_NUMPY = True
82except ModuleNotFoundError:
83    HAS_NUMPY = False
84    np = None
85skipIfNoNumpy = unittest.skipIf(not HAS_NUMPY, "no NumPy")
86
87# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
88# sharding on sandcastle. This line silences flake warnings
89load_tests = load_tests
90
91TEST_CUDA_IPC = (
92    torch.cuda.is_available()
93    and sys.platform != "darwin"
94    and sys.platform != "win32"
95    and not IS_JETSON
96    and not TEST_WITH_ROCM
97)  # https://github.com/pytorch/pytorch/issues/90940
98
99TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1
100
101if not NO_MULTIPROCESSING_SPAWN:
102    # We want to use `spawn` if able because some of our tests check that the
103    # data loader terminiates gracefully. To prevent hanging in the testing
104    # process, such data loaders are run in a separate subprocess.
105    #
106    # We also want to test the `pin_memory=True` configuration, thus `spawn` is
107    # required to launch such processes and they initialize the CUDA context.
108    #
109    # Mixing different start method is a recipe for disaster (e.g., using a fork
110    # `mp.Event` with a spawn `mp.Process` segfaults). So we set this globally
111    # to avoid bugs.
112    #
113    # Get a multiprocessing context because some test / third party library will
114    # set start_method when imported, and setting again triggers `RuntimeError`.
115    mp = mp.get_context(method="spawn")
116
117
118# 60s of timeout?
119# Yes, in environments where physical CPU resources are shared, e.g., CI, the
120# time for a inter-process communication can be highly varying.  With 15~17s of
121# timeout, we have observed flakiness in some CI builds (see
122# pytorch/pytorch#14501, pytorch/pytorch#16608).  We follow the CPython
123# multiprocessing setup and set the timeout to 60s here:
124#
125# https://github.com/python/cpython/blob/e8113f51a8bdf33188ee30a1c038a298329e7bfa/Lib/test/_test_multiprocessing.py#L73
126JOIN_TIMEOUT = 60.0  # seconds
127
128
129supported_multiprocessing_contexts = [None] + list(
130    torch.multiprocessing.get_all_start_methods()
131)
132
133
134# collate_fn that returns the batch cloned; defined globally here for pickle purposes.
135def _clone_collate(b):
136    return [x.clone() for x in b]
137
138
139@unittest.skipIf(
140    TEST_WITH_TSAN,
141    "Fails with TSAN with the following error: starting new threads after multi-threaded "
142    "fork is not supported. Dying (set die_after_fork=0 to override)",
143)
144class TestDatasetRandomSplit(TestCase):
145    def test_lengths_must_equal_dataset_size(self):
146        with self.assertRaises(ValueError):
147            random_split([1, 2, 3, 4], [1, 2])
148
149    def test_splits_have_correct_size(self):
150        splits = random_split([1, 2, 3, 4, 5, 6], [2, 4])
151        self.assertEqual(len(splits), 2)
152        self.assertEqual(len(splits[0]), 2)
153        self.assertEqual(len(splits[1]), 4)
154
155        splits = random_split([1, 2, 3, 4, 5, 6], [0.5, 0.5])
156        self.assertEqual(len(splits), 2)
157        self.assertEqual(len(splits[0]), 3)
158        self.assertEqual(len(splits[1]), 3)
159
160        # Odd size splits
161        self.assertEqual(
162            len(
163                random_split(
164                    range(3), [0.5, 0.5], generator=torch.Generator().manual_seed(1)
165                )
166            ),
167            2,
168        )
169
170        # Odd sized round-robin splits
171        splits = random_split(
172            range(106), [0.1, 0.2, 0.3, 0.4], generator=torch.Generator().manual_seed(1)
173        )
174        self.assertEqual(len(splits[0]), 11)
175        self.assertEqual(len(splits[1]), 22)
176        self.assertEqual(len(splits[2]), 31)
177        self.assertEqual(len(splits[3]), 42)
178
179    def test_splits_are_mutually_exclusive(self):
180        data = [5, 2, 3, 4, 1, 6]
181        splits = random_split(data, [2, 4])
182        all_values = []
183        all_values.extend(list(splits[0]))
184        all_values.extend(list(splits[1]))
185        data.sort()
186        all_values.sort()
187        self.assertListEqual(data, all_values)
188
189        splits = random_split(data, [0.33, 0.67])
190        all_values = []
191        all_values.extend(list(splits[0]))
192        all_values.extend(list(splits[1]))
193        data.sort()
194        all_values.sort()
195        self.assertListEqual(data, all_values)
196
197        data = [1, 2, 3, 4]
198        splits = random_split(data, [0.25, 0.75])
199        all_values = []
200        all_values.extend(list(splits[0]))
201        all_values.extend(list(splits[1]))
202        data.sort()
203        all_values.sort()
204        self.assertListEqual(data, all_values)
205
206    def test_splits_indexing_type(self):
207        r"""Indices generated by random_split
208        should be of integer type
209        """
210
211        class CustomDataset:
212            def __init__(self, test_object, custom_list):
213                self.data = custom_list
214                self.test_object = test_object
215
216            def __getitem__(self, key):
217                self.test_object.assertEqual(type(key), int)
218                return self.data[key]
219
220            def __len__(self):
221                return len(self.data)
222
223        x = [1, 2, 3, 4, 5]
224        dataset = CustomDataset(self, x)
225        dataset = random_split(dataset, [5])[0]
226        data_loader = DataLoader(dataset)
227        for batch in data_loader:
228            pass
229
230        # fractional splitting
231        dataset = CustomDataset(self, x)
232        dataset = random_split(dataset, [1.0])[0]
233        data_loader = DataLoader(dataset)
234        for batch in data_loader:
235            pass
236
237    def test_splits_reproducibility(self):
238        self.assertEqual(
239            [
240                list(x)
241                for x in random_split(
242                    range(10), [3, 7], generator=torch.Generator().manual_seed(1)
243                )
244            ],
245            [[5, 6, 1], [2, 0, 8, 9, 3, 7, 4]],
246        )
247        self.assertEqual(
248            random_split(
249                range(100), [60, 40], generator=torch.Generator().manual_seed(42)
250            ),
251            random_split(
252                range(100), [60, 40], generator=torch.Generator().manual_seed(42)
253            ),
254        )
255        self.assertEqual(
256            random_split(
257                range(100), [0.5, 0.5], generator=torch.Generator().manual_seed(42)
258            ),
259            random_split(
260                range(100), [0.5, 0.5], generator=torch.Generator().manual_seed(42)
261            ),
262        )
263        self.assertEqual(
264            random_split(
265                range(100),
266                [0.33, 0.33, 0.34],
267                generator=torch.Generator().manual_seed(42),
268            ),
269            random_split(
270                range(100),
271                [0.33, 0.33, 0.34],
272                generator=torch.Generator().manual_seed(42),
273            ),
274        )
275
276    def test_incomplete_fractional_splits(self):
277        with self.assertRaises(ValueError):
278            # should raise since the sum of fractions is not 1
279            random_split([1, 2, 3, 4], [0.1])
280
281        with self.assertRaises(ValueError):
282            # should raise since fraction > 1
283            random_split([1, 2, 3, 4], [1.1])
284
285    def test_splits_generator(self):
286        # A random_split without a specific generator should affect the default one
287        state = torch.get_rng_state()
288        a = torch.rand(10)
289        torch.set_rng_state(state)
290        random_split(range(10), [5, 5])
291        b = torch.rand(10)
292        self.assertNotEqual(a, b)
293
294        # A random_split with a specific generator should not affect the default one
295        state = torch.get_rng_state()
296        a = torch.rand(10)
297        torch.set_rng_state(state)
298        random_split(range(10), [5, 5], generator=torch.Generator().manual_seed(42))
299        b = torch.rand(10)
300        self.assertEqual(a, b)
301
302    def test_slicing_of_subset_of_dataset(self):
303        # Testing slicing a subset initialized with a dataset
304        dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5]))
305        subset_of_dataset = Subset(dataset, [0, 1, 2, 3, 4])
306        self.assertEqual(subset_of_dataset[:], dataset[:])
307        self.assertEqual(subset_of_dataset[1:2], dataset[1:2])
308        self.assertEqual(subset_of_dataset[0:-1:2], dataset[0:-1:2])
309        # Testing slicing of subset from random split
310        subset1, subset2 = random_split(dataset, [3, 2])
311        self.assertEqual(subset1[:], dataset[subset1.indices[:]])
312        self.assertEqual(subset1[0:2], dataset[subset1.indices[0:2]])
313        self.assertEqual(subset1[0:-1:2], dataset[subset1.indices[0:-1:2]])
314
315    def test_slicing_of_subset_of_subset(self):
316        # Testing slicing a subset initialized with a subset
317        dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5]))
318        subset_of_dataset = Subset(dataset, [0, 1, 2, 3, 4])
319        subset_of_subset = Subset(subset_of_dataset, [0, 1, 2, 3, 4])
320        self.assertEqual(subset_of_subset[:], dataset[:])
321        self.assertEqual(subset_of_subset[0:2], dataset[0:2])
322        self.assertEqual(subset_of_subset[0:-1:2], dataset[0:-1:2])
323        # Testing slicing of subset of subset from random split
324        subset1, subset2 = random_split(dataset, [4, 1])
325        subset_of_subset1, subset_of_subset2 = random_split(subset1, [3, 1])
326        idx = [subset1.indices[i] for i in subset_of_subset1.indices]
327        self.assertEqual(subset_of_subset1[:], dataset[idx.copy()])
328        self.assertEqual(subset_of_subset1[0:2], dataset[idx[0:2]])
329        self.assertEqual(subset_of_subset1[0:-1:2], dataset[idx[0:-1:2]])
330
331
332class CUDACountingDataset(Dataset):
333    def __init__(self, n):
334        super().__init__()
335        self.n = n
336
337    def __getitem__(self, i):
338        return torch.as_tensor(i, device="cuda")
339
340    def __len__(self):
341        return self.n
342
343
344class CountingDataset(Dataset):
345    def __init__(self, n):
346        super().__init__()
347        self.n = n
348
349    def __getitem__(self, i):
350        return i
351
352    def __len__(self):
353        return self.n
354
355
356class CountingIterableDataset(IterableDataset):
357    def __init__(self, n):
358        super().__init__()
359        self.n = n
360
361    def __iter__(self):
362        return iter(range(self.n))
363
364    def __len__(self):
365        return self.n
366
367
368@unittest.skipIf(
369    TEST_WITH_TSAN,
370    "Fails with TSAN with the following error: starting new threads after multi-threaded "
371    "fork is not supported. Dying (set die_after_fork=0 to override)",
372)
373class TestTensorDataset(TestCase):
374    def test_len(self):
375        source = TensorDataset(torch.randn(15, 10, 2, 3, 4, 5), torch.randperm(15))
376        self.assertEqual(len(source), 15)
377
378    def test_getitem(self):
379        t = torch.randn(15, 10, 2, 3, 4, 5)
380        l = torch.randn(15, 10)
381        source = TensorDataset(t, l)
382        for i in range(15):
383            self.assertEqual(t[i], source[i][0])
384            self.assertEqual(l[i], source[i][1])
385
386    def test_getitem_1d(self):
387        t = torch.randn(15)
388        l = torch.randn(15)
389        source = TensorDataset(t, l)
390        for i in range(15):
391            self.assertEqual(t[i], source[i][0])
392            self.assertEqual(l[i], source[i][1])
393
394    def test_single_tensor(self):
395        t = torch.randn(5, 10)
396        source = TensorDataset(t)
397        self.assertEqual(len(source), 5)
398        for i in range(5):
399            self.assertEqual(t[i], source[i][0])
400
401    def test_many_tensors(self):
402        t0 = torch.randn(5, 10, 2, 3, 4, 5)
403        t1 = torch.randn(5, 10)
404        t2 = torch.randn(5, 10, 2, 5)
405        t3 = torch.randn(5, 10, 3, 7)
406        source = TensorDataset(t0, t1, t2, t3)
407        self.assertEqual(len(source), 5)
408        for i in range(5):
409            self.assertEqual(t0[i], source[i][0])
410            self.assertEqual(t1[i], source[i][1])
411            self.assertEqual(t2[i], source[i][2])
412            self.assertEqual(t3[i], source[i][3])
413
414
415@unittest.skipIf(
416    TEST_WITH_TSAN,
417    "Fails with TSAN with the following error: starting new threads after multi-threaded "
418    "fork is not supported. Dying (set die_after_fork=0 to override)",
419)
420class TestStackDataset(TestCase):
421    def test_empty(self):
422        with self.assertRaisesRegex(
423            ValueError, "At least one dataset should be passed"
424        ):
425            StackDataset()
426
427    def test_mixed(self):
428        with self.assertRaisesRegex(ValueError, "Supported either"):
429            StackDataset(
430                TensorDataset(torch.randn(15, 10)), a=TensorDataset(torch.randn(10, 15))
431            )
432
433    def test_size_mismatch(self):
434        with self.assertRaisesRegex(ValueError, "Size mismatch between datasets"):
435            StackDataset(
436                TensorDataset(torch.randn(15, 10)), TensorDataset(torch.randn(10, 15))
437            )
438        with self.assertRaisesRegex(ValueError, "Size mismatch between datasets"):
439            StackDataset(
440                a=TensorDataset(torch.randn(15, 10)),
441                b=TensorDataset(torch.randn(10, 15)),
442            )
443
444    def test_len(self):
445        source = StackDataset(
446            TensorDataset(torch.randn(15, 10)), TensorDataset(torch.randn(15))
447        )
448        self.assertEqual(len(source), 15)
449        source = StackDataset(TensorDataset(torch.randn(15, 10)))
450        self.assertEqual(len(source), 15)
451        source = StackDataset(
452            a=TensorDataset(torch.randn(15, 10)), b=TensorDataset(torch.randn(15))
453        )
454        self.assertEqual(len(source), 15)
455        source = StackDataset(a=TensorDataset(torch.randn(15, 10)))
456        self.assertEqual(len(source), 15)
457
458    def test_single(self):
459        t = TensorDataset(torch.randn(15, 10))
460        source = StackDataset(t)
461        for i in range(15):
462            self.assertEqual(t[i], source[i][0])
463        source = StackDataset(a=t)
464        for i in range(15):
465            self.assertEqual(t[i], source[i]["a"])
466
467    def test_getitem(self):
468        t = TensorDataset(torch.randn(15, 10))
469        l = TensorDataset(torch.randn(15, 5, 4))
470        source = StackDataset(t, l)
471        for i in range(15):
472            self.assertEqual(t[i], source[i][0])
473            self.assertEqual(l[i], source[i][1])
474        source = StackDataset(a=t, b=l)
475        for i in range(15):
476            self.assertEqual(t[i], source[i]["a"])
477            self.assertEqual(l[i], source[i]["b"])
478
479    def test_getitems(self):
480        class GetItemsDataset(Dataset):
481            def __init__(self) -> None:
482                self.data = torch.randn(4)
483
484            def __getitem__(self, item):
485                return self.data[item]
486
487            def __getitems__(self, items):
488                return self.data[items]
489
490            def __len__(self):
491                return 4
492
493        t = GetItemsDataset()
494        l = [1, 2, 3, 4]
495
496        source = StackDataset(t, l)
497        batch = source.__getitems__([0, 1, 2, 3])
498        for i in range(4):
499            self.assertEqual(t[i], batch[i][0])
500            self.assertEqual(l[i], batch[i][1])
501
502        source = StackDataset(t=t, l=l)
503        batch = source.__getitems__([0, 1, 2, 3])
504        for i in range(4):
505            self.assertEqual(t[i], batch[i]["t"])
506            self.assertEqual(l[i], batch[i]["l"])
507
508    def test_getitems_raises_index_error(self):
509        class GetItemsDataset(Dataset):
510            def __init__(self) -> None:
511                self.data = torch.randn(4)
512
513            def __getitem__(self, item):
514                return self.data[item]
515
516            def __getitems__(self, items):
517                return self.data[items]
518
519            def __len__(self):
520                return 4
521
522        t = GetItemsDataset()
523        l = [1, 2, 3, 4]
524
525        source = StackDataset(t, l)
526
527        with self.assertRaises(IndexError):
528            source.__getitems__([0, 4])
529
530    def test_getitems_value_error(self):
531        class GetItemsDataset(Dataset):
532            def __init__(self) -> None:
533                self.data = torch.randn(4)
534
535            def __getitem__(self, item):
536                return self.data[item]
537
538            def __getitems__(self, items):
539                return self.data[items][:-1]  # return less
540
541            def __len__(self):
542                return 4
543
544        t = GetItemsDataset()
545        l = [1, 2, 3, 4]
546
547        source = StackDataset(t, l)
548
549        with self.assertRaisesRegex(
550            ValueError, "Nested dataset's output size mismatch. Expected 4, got 3"
551        ):
552            source.__getitems__([0, 1, 2, 3])
553
554
555@unittest.skipIf(
556    TEST_WITH_TSAN,
557    "Fails with TSAN with the following error: starting new threads after multi-threaded "
558    "fork is not supported. Dying (set die_after_fork=0 to override)",
559)
560class TestConcatDataset(TestCase):
561    def test_concat_two_singletons(self):
562        result = ConcatDataset([[0], [1]])
563        self.assertEqual(2, len(result))
564        self.assertEqual(0, result[0])
565        self.assertEqual(1, result[1])
566
567    def test_concat_two_non_singletons(self):
568        result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
569        self.assertEqual(10, len(result))
570        self.assertEqual(0, result[0])
571        self.assertEqual(5, result[5])
572
573    def test_concat_two_non_singletons_with_empty(self):
574        # Adding an empty dataset somewhere is correctly handled
575        result = ConcatDataset([[0, 1, 2, 3, 4], [], [5, 6, 7, 8, 9]])
576        self.assertEqual(10, len(result))
577        self.assertEqual(0, result[0])
578        self.assertEqual(5, result[5])
579
580    def test_concat_raises_index_error(self):
581        result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
582        with self.assertRaises(IndexError):
583            # this one goes to 11
584            result[11]
585
586    def test_add_dataset(self):
587        d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
588        d2 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
589        d3 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
590        result = d1 + d2 + d3
591        self.assertEqual(21, len(result))
592        self.assertEqual(0, (d1[0][0] - result[0][0]).abs().sum())
593        self.assertEqual(0, (d2[0][0] - result[7][0]).abs().sum())
594        self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum())
595
596    def test_iterable_dataset_err(self):
597        d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
598        it1 = CountingIterableDataset(5)
599        it2 = CountingIterableDataset(10)
600
601        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
602            ConcatDataset([d1, it2, it1])
603
604        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
605            ConcatDataset([it2])
606
607        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
608            ConcatDataset([it1, d1])
609
610
611# takes in dummy var so this can also be used as a `worker_init_fn`
612def set_faulthander_if_available(_=None):
613    faulthandler.enable(sys.__stderr__)
614    if not IS_WINDOWS:
615        # windows does not have faulthandler.register
616        # chain=False prevents the default behavior of killing the process
617        faulthandler.register(signal.SIGUSR1, file=sys.__stderr__, chain=False)
618
619
620set_faulthander_if_available()
621
622
623# Process `pid` must have called `set_faulthander_if_available`
624def print_traces_of_all_threads(pid):
625    if not IS_WINDOWS:
626        # use the custom signal if available
627        os.kill(pid, signal.SIGUSR1)
628    else:
629        # otherwise we can still use the handler given by faulthandler.enable()
630        # at the cost of killing the process.
631        os.kill(pid, signal.SIGSEGV)
632
633    # wait in parent process to give subprocess some time to print
634    time.sleep(5)
635
636
637# The following `ErrorTrackingProcess` stores the first encountered exception in
638# its `.exception` attribute.
639# Inspired by https://stackoverflow.com/a/33599967
640class ErrorTrackingProcess(mp.Process):
641    # Why no *args?
642    #   py2 doesn't support def fn(x, *args, key=val, **kwargs)
643    # Setting disable_stderr=True may generate a lot of unrelated error outputs
644    # but could be helpful for debugging.
645    def __init__(self, disable_stderr=True, **kwargs):
646        super().__init__(**kwargs)
647        self._pconn, self._cconn = mp.Pipe()
648        self._exception = None
649        self.disable_stderr = disable_stderr
650
651    def run(self):
652        set_faulthander_if_available()
653        if self.disable_stderr:
654            # Disable polluting stderr with errors that are supposed to happen.
655            with open(os.devnull, "w") as devnull:
656                os.dup2(devnull.fileno(), sys.stderr.fileno())
657        try:
658            super().run()
659            self._cconn.send(None)
660        except Exception:
661            self._cconn.send(ExceptionWrapper(sys.exc_info()))
662            raise
663
664    def print_traces_of_all_threads(self):
665        assert (
666            self.is_alive()
667        ), "can only use print_traces_of_all_threads if the process is alive"
668        assert (
669            not self.disable_stderr
670        ), "do not disable stderr if you use print_traces_of_all_threads"
671        # On platforms without `SIGUSR1`, `set_faulthander_if_available` sets
672        # `faulthandler.enable()`, and `print_traces_of_all_threads` may kill
673        # the process. So let's poll the exception first
674        _ = self.exception
675        print_traces_of_all_threads(self.pid)
676
677    @property
678    def exception(self):
679        if self._pconn.poll():
680            self._exception = self._pconn.recv()
681        if self._exception is None:
682            return None
683        else:
684            return self._exception.exc_type(self._exception.exc_msg)
685
686    # ESRCH means that os.kill can't finds alive proc
687    def send_signal(self, signum, ignore_ESRCH=False):
688        try:
689            os.kill(self.pid, signum)
690        except OSError as e:
691            if not ignore_ESRCH or e.errno != errno.ESRCH:
692                raise
693
694
695class ErrorDataset(Dataset):
696    def __init__(self, size):
697        self.size = size
698
699    def __len__(self):
700        return self.size
701
702
703class SegfaultDataset(Dataset):
704    def __init__(self, size):
705        self.size = size
706
707    def __getitem__(self, idx):
708        return ctypes.string_at(0)
709
710    def __len__(self):
711        return self.size
712
713
714class SleepDataset(Dataset):
715    def __init__(self, size, sleep_sec):
716        self.size = size
717        self.sleep_sec = sleep_sec
718        self.sleeped = False
719
720    def __getitem__(self, idx):
721        if not self.sleeped:
722            time.sleep(self.sleep_sec)
723            self.sleeped = True
724        return idx
725
726    def __len__(self):
727        return self.size
728
729
730class SeedDataset(Dataset):
731    def __init__(self, size):
732        self.size = size
733
734    def __getitem__(self, idx):
735        return torch.initial_seed()
736
737    def __len__(self):
738        return self.size
739
740
741class WorkerSpecificIterableDataset(IterableDataset):
742    def __init__(self, sizes_for_all_workers):
743        self.sizes_for_all_workers = sizes_for_all_workers
744
745    def __iter__(self):
746        worker_info = torch.utils.data.get_worker_info()
747        assert worker_info is not None
748        return iter(range(self.sizes_for_all_workers[worker_info.id]))
749
750    def __len__(self):
751        return sum(self.sizes_for_all_workers)
752
753
754# Inspired by https://stackoverflow.com/a/26703365
755# If all workers will call `sync_once`, they will be blocked until all workers
756# reach the call (i.e., acting like a barrier).
757# This can be used to ensure that each worker at least processes one data.
758class SynchronizedDataset(Dataset):
759    def __init__(self, size, batch_size, num_workers):
760        assert size >= num_workers * batch_size
761        self.count = mp.Value("i", 0, lock=True)
762        self.barrier = mp.Semaphore(0)
763        self.num_workers = num_workers
764        self.size = size
765
766    def sync_once(self):
767        with self.count.get_lock():
768            self.count.value += 1
769            if self.count.value == self.num_workers:
770                self.barrier.release()
771        self.barrier.acquire()
772        self.barrier.release()
773
774    def __getitem__(self, idx):
775        raise NotImplementedError
776
777    def __len__(self):
778        return self.size
779
780
781class EmptyTensorDataset(torch.utils.data.Dataset):
782    def __init__(self, len):
783        self.len = len
784
785    def __len__(self):
786        return self.len
787
788    def __getitem__(self, any):
789        return torch.empty(0)
790
791
792class SynchronizedSeedDataset(SynchronizedDataset):
793    def __getitem__(self, idx):
794        self.sync_once()
795        return torch.initial_seed()
796
797
798def _test_timeout(persistent_workers):
799    dataset = SleepDataset(10, 3)
800    dataloader = DataLoader(
801        dataset,
802        batch_size=2,
803        num_workers=2,
804        timeout=1,
805        persistent_workers=persistent_workers,
806    )
807    _ = next(iter(dataloader))
808
809
810def _test_timeout_pin_memory(persistent_workers):
811    dataset = SleepDataset(10, 3)
812    dataloader = DataLoader(
813        dataset,
814        batch_size=2,
815        num_workers=2,
816        timeout=1,
817        pin_memory=True,
818        persistent_workers=persistent_workers,
819    )
820    _ = next(iter(dataloader))
821
822
823def _test_large_sampler_indices(persistent_workers):
824    # See
825    #   test_large_sampler_indices
826    #   https://github.com/pytorch/pytorch/issues/48666
827
828    dataloader = torch.utils.data.DataLoader(
829        EmptyTensorDataset(10000000),
830        batch_size=40960,
831        persistent_workers=persistent_workers,
832        num_workers=1,
833    )
834
835    it = iter(dataloader)
836
837    for x in it:
838        assert x.numel() == 0
839        raise RuntimeError("My Error")
840
841
842def disable_stderr(worker_id):
843    r"""
844    Avoids printing "ERROR: Unexpected segmentation fault encountered in worker."
845    from workers. Since worker signal handler prints with low-level write(),
846    this has to be done on OS level via dup.
847
848    This is used as worker_init_fn for test_segfault.
849    """
850    sys.stderr.flush()  # flush library buffers that dup2 knows nothing about
851    # Can't use a with-block because otherwise the fd will be closed when this
852    # function ends.
853    with open(os.devnull, "w") as devnull:
854        os.dup2(devnull.fileno(), sys.stderr.fileno())
855
856
857def _test_segfault():
858    dataset = SegfaultDataset(10)
859    dataloader = DataLoader(
860        dataset, batch_size=2, num_workers=2, worker_init_fn=disable_stderr
861    )
862    _ = next(iter(dataloader))
863
864
865def _test_no_segfault():
866    dataset = [1, 2, 3]
867    num_threads = torch.get_num_threads()
868    if num_threads < 4:
869        torch.set_num_threads(4)
870    else:
871        torch.set_num_threads(num_threads)
872    mp_ctx = torch.multiprocessing.get_context(method="fork")
873    dataloader = DataLoader(
874        dataset,
875        num_workers=1,
876        worker_init_fn=disable_stderr,
877        multiprocessing_context=mp_ctx,
878    )
879    _ = next(iter(dataloader))
880
881
882class TestProperExitDataset(Dataset):
883    def __init__(self, size, error_event):
884        self.size = size
885        self.error_event = error_event
886
887    def __len__(self):
888        return self.size
889
890    def __getitem__(self, idx):
891        worker_info = torch.utils.data.get_worker_info()
892        if (
893            self.error_event is not None
894            and self.error_event.is_set()
895            and worker_info.id == worker_info.num_workers - 1
896        ):
897            # only error in the last worker
898            raise RuntimeError("Worker error")
899        return torch.tensor([idx])
900
901
902class TestProperExitIterableDataset(IterableDataset):
903    def __init__(self, size, error_event):
904        self.error_event = error_event
905        self.size = size
906        self.remaining = size
907
908    def __len__(self):
909        return self.size
910
911    def __iter__(self):
912        return self
913
914    def __next__(self):
915        worker_info = torch.utils.data.get_worker_info()
916        if (
917            self.error_event is not None
918            and self.error_event.is_set()
919            and worker_info.id == worker_info.num_workers - 1
920        ):
921            # only error in the last worker
922            raise RuntimeError("Worker error")
923        self.remaining -= 1
924        if self.remaining < 0:
925            raise StopIteration
926        return torch.tensor(-1000)
927
928
929# See TestDataLoader.test_proper_exit for usage
930def _test_proper_exit(
931    is_iterable_dataset,
932    use_workers,
933    pin_memory,
934    exit_method,
935    hold_iter_reference,
936    loader_setup_event,
937    tester_setup_event,
938    persistent_workers,
939):
940    num_workers = 2 if use_workers else 0
941
942    if exit_method == "worker_error" or exit_method == "worker_kill":
943        assert use_workers is True
944
945    if exit_method == "worker_error":
946        worker_error_event = mp.Event()
947    else:
948        worker_error_event = None
949
950    if is_iterable_dataset:
951        ds = TestProperExitIterableDataset(7, worker_error_event)
952    else:
953        ds = TestProperExitDataset(12, worker_error_event)
954
955    loader = DataLoader(
956        ds,
957        batch_size=1,
958        shuffle=False,
959        num_workers=num_workers,
960        pin_memory=pin_memory,
961        worker_init_fn=set_faulthander_if_available,
962        persistent_workers=persistent_workers,
963    )
964
965    error_it = 2
966
967    if use_workers:
968        # 2 is the magical per-worker prefetch number...
969        # FIXME: change this after the number becomes configurable.
970        if is_iterable_dataset:
971            assert len(ds) * num_workers > (error_it + 2 + 1)
972        else:
973            assert len(loader) > (error_it + 2 + 1) * num_workers
974    else:
975        if is_iterable_dataset:
976            assert len(ds) > error_it + 1
977        else:
978            assert len(loader) > error_it + 1
979
980    it = iter(loader)
981    if use_workers:
982        workers = it._workers
983
984    def kill_pid(pid):
985        psutil_p = psutil.Process(pid)
986        psutil_p.kill()
987        psutil_p.wait(JOIN_TIMEOUT)
988        assert not psutil_p.is_running()
989
990    for i, _ in enumerate(it):
991        if i == 0:
992            if not hold_iter_reference:
993                del it
994                del loader
995            loader_setup_event.set()
996            tester_setup_event.wait()
997            # ensure that the workers are still alive
998            if use_workers:
999                for w in workers:
1000                    assert w.is_alive()
1001            if worker_error_event is not None:
1002                worker_error_event.set()
1003
1004        if i == error_it:
1005            if exit_method == "loader_error":
1006                raise RuntimeError("Loader error")
1007            elif exit_method == "loader_kill":
1008                kill_pid(os.getpid())
1009            elif exit_method == "worker_kill":
1010                kill_pid(workers[-1].pid)  # kill last worker
1011
1012    if not hold_iter_reference:
1013        # Tries to trigger the __del__ clean-up rather than the automatic
1014        # exiting of daemonic children. Technically it should be automatically
1015        # triggered, but I don't want to rely on the implementation detail of
1016        # Python gc.
1017        gc.collect()
1018
1019
1020class TestWorkerInfoDataset(SynchronizedDataset):
1021    def __getitem__(self, idx):
1022        self.sync_once()
1023        return torch.tensor(self.value)
1024
1025
1026# Should be used as worker_init_fn with TestWorkerInfoDataset.
1027# See _test_get_worker_info below for usage.
1028def _test_worker_info_init_fn(worker_id):
1029    worker_info = torch.utils.data.get_worker_info()
1030    assert (
1031        worker_id == worker_info.id
1032    ), "worker_init_fn and worker_info should have consistent id"
1033    assert (
1034        worker_id < worker_info.num_workers
1035    ), "worker_init_fn and worker_info should have valid id"
1036    assert (
1037        worker_info.seed == torch.initial_seed()
1038    ), "worker_init_fn and worker_info should have consistent seed"
1039    dataset = worker_info.dataset
1040    assert isinstance(
1041        dataset, TestWorkerInfoDataset
1042    ), "worker_info should have correct dataset copy"
1043    assert not hasattr(dataset, "value"), "worker_info should have correct dataset copy"
1044    # test that WorkerInfo attributes are read-only
1045    try:
1046        worker_info.id = 3999
1047    except RuntimeError as e:
1048        assert str(e) == "Cannot assign attributes to WorkerInfo objects"
1049    try:
1050        worker_info.a = 3
1051    except RuntimeError as e:
1052        assert str(e) == "Cannot assign attributes to WorkerInfo objects"
1053    for k in ["id", "num_workers", "seed", "dataset"]:
1054        assert f"{k}=" in repr(worker_info)
1055    dataset.value = [worker_id, os.getpid()]
1056
1057
1058def _test_get_worker_info():
1059    # get_worker_info returns None in main proc
1060    assert torch.utils.data.get_worker_info() is None
1061    num_workers = 2
1062    batch_size = 2
1063    dataset = TestWorkerInfoDataset(6, batch_size, num_workers)
1064    dataloader = DataLoader(
1065        dataset,
1066        batch_size=batch_size,
1067        num_workers=num_workers,
1068        worker_init_fn=_test_worker_info_init_fn,
1069    )
1070    it = iter(dataloader)
1071    data = []
1072    for d in it:
1073        data.append(d)  # noqa: PERF402
1074    worker_pids = [w.pid for w in it._workers]
1075    data = torch.cat(data, 0)
1076    for d in data:
1077        # each `d` is a [worker_id, worker_pid] pair, which is set in
1078        # _test_worker_info_init_fn
1079        assert d[1] == worker_pids[d[0]]
1080    # get_worker_info returns None in main proc after data loading
1081    assert torch.utils.data.get_worker_info() is None
1082    # main proc dataset was never assigned this attribute
1083    assert not hasattr(dataset, "value")
1084    try:
1085        _ = dataset[0]
1086    except AttributeError:
1087        return
1088    raise RuntimeError("Expected AttributeError")
1089
1090
1091# test custom init function
1092def init_fn(worker_id):
1093    torch.manual_seed(12345)
1094
1095
1096# used with test_error_in_init
1097class ErrorIterableDataset(IterableDataset):
1098    def __iter__(self):
1099        raise RuntimeError("Error in __iter__")
1100
1101
1102# used with test_error_in_init
1103def error_worker_init_fn(_):
1104    raise RuntimeError("Error in worker_init_fn")
1105
1106
1107class BulkLoadingDataset(Dataset):
1108    def __init__(self, length):
1109        self.length = length
1110
1111    def __getitem__(self, indices):
1112        assert isinstance(indices, (list, tuple))
1113        return torch.as_tensor(indices)
1114
1115    def __len__(self):
1116        return self.length
1117
1118
1119class BulkLoadingSampler(torch.utils.data.Sampler):
1120    def __init__(self, dataset, batch_size):
1121        self.dataset = dataset
1122        self.batch_size = batch_size
1123
1124    def __iter__(self):
1125        for x in torch.randperm(len(self.dataset)).split(self.batch_size):
1126            yield x.tolist()
1127
1128    def __len__(self):
1129        return int(math.ceil(len(self.dataset) / float(self.batch_size)))
1130
1131
1132class TestMultiEpochDataset(IterableDataset):
1133    def __init__(self, length):
1134        self.length = length
1135
1136    def __iter__(self):
1137        worker_info = torch.utils.data.get_worker_info()
1138        assert worker_info is not None
1139        worker_id = worker_info.id
1140        for idx in range(self.length // worker_info.num_workers):
1141            yield worker_id
1142
1143    def __len__(self):
1144        return self.length
1145
1146
1147class CustomList(list):
1148    pass
1149
1150
1151class CustomDict(dict):
1152    pass
1153
1154
1155def row_processor(row):
1156    return np.add(row, 1)
1157
1158
1159def filter_len(row):
1160    return len(row) == 4
1161
1162
1163@unittest.skipIf(
1164    TEST_WITH_TSAN,
1165    "Fails with TSAN with the following error: starting new threads after multi-threaded "
1166    "fork is not supported. Dying (set die_after_fork=0 to override)",
1167)
1168@unittest.skipIf(
1169    TEST_WITH_ASAN,
1170    "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223",
1171)
1172class TestDataLoader(TestCase):
1173    def setUp(self):
1174        super().setUp()
1175        self.data = torch.randn(100, 2, 3, 5)
1176        self.labels = torch.randperm(50).repeat(2)
1177        self.dataset = TensorDataset(self.data, self.labels)
1178        self.persistent_workers = False
1179
1180    def _get_data_loader(self, dataset, **kwargs):
1181        persistent_workers = kwargs.get("persistent_workers", self.persistent_workers)
1182        if persistent_workers and kwargs.get("num_workers", 0) == 0:
1183            persistent_workers = False
1184        kwargs["persistent_workers"] = persistent_workers
1185        return DataLoader(dataset, **kwargs)
1186
1187    def _test_sequential(self, loader):
1188        batch_size = loader.batch_size
1189        if batch_size is None:
1190            for idx, (sample, target) in enumerate(loader):
1191                self.assertEqual(sample, self.data[idx])
1192                self.assertEqual(target, self.labels[idx])
1193            self.assertEqual(idx, len(self.dataset) - 1)
1194        else:
1195            for i, (sample, target) in enumerate(loader):
1196                idx = i * batch_size
1197                self.assertEqual(sample, self.data[idx : idx + batch_size])
1198                self.assertEqual(target, self.labels[idx : idx + batch_size])
1199            self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
1200
1201    def _test_shuffle(self, loader):
1202        found_data = dict.fromkeys(range(self.data.size(0)), 0)
1203        found_labels = dict.fromkeys(range(self.labels.size(0)), 0)
1204        batch_size = loader.batch_size
1205        if batch_size is None:
1206            for i, (batch_samples, batch_targets) in enumerate(loader):
1207                sample, target = (batch_samples, batch_targets)
1208                for data_point_idx, data_point in enumerate(self.data):
1209                    if data_point.eq(sample).all():
1210                        self.assertFalse(found_data[data_point_idx])
1211                        found_data[data_point_idx] += 1
1212                        break
1213                self.assertEqual(target, self.labels[data_point_idx])
1214                found_labels[data_point_idx] += 1
1215                self.assertEqual(sum(found_data.values()), (i + 1))
1216                self.assertEqual(sum(found_labels.values()), (i + 1))
1217            self.assertEqual(i, (len(self.dataset) - 1))
1218        else:
1219            for i, (batch_samples, batch_targets) in enumerate(loader):
1220                for sample, target in zip(batch_samples, batch_targets):
1221                    for data_point_idx, data_point in enumerate(self.data):
1222                        if data_point.eq(sample).all():
1223                            self.assertFalse(found_data[data_point_idx])
1224                            found_data[data_point_idx] += 1
1225                            break
1226                    self.assertEqual(target, self.labels[data_point_idx])
1227                    found_labels[data_point_idx] += 1
1228                self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
1229                self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
1230            self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
1231
1232    def _test_error(self, loader):
1233        it = iter(loader)
1234        errors = 0
1235        while True:
1236            try:
1237                next(it)
1238            except NotImplementedError:
1239                errors += 1
1240            except StopIteration:
1241                self.assertEqual(
1242                    errors, math.ceil(float(len(loader.dataset)) / loader.batch_size)
1243                )
1244                return
1245
1246    def test_error_in_init(self):
1247        for num_workers in [0, 2]:
1248            loader = self._get_data_loader(
1249                ErrorIterableDataset(), num_workers=num_workers
1250            )
1251            with self.assertRaisesRegex(RuntimeError, "Error in __iter__"):
1252                list(iter(loader))
1253
1254        loader = self._get_data_loader(
1255            self.dataset, num_workers=2, worker_init_fn=error_worker_init_fn
1256        )
1257        with self.assertRaisesRegex(RuntimeError, "Error in worker_init_fn"):
1258            list(iter(loader))
1259
1260    def test_typing(self):
1261        from typing import List
1262
1263        # Make sure there is no TypeError
1264
1265        class SomeDatasetClass(Dataset[List[torch.Tensor]]):
1266            pass
1267
1268        def _create_dataloader(is_train: bool) -> DataLoader[List[torch.Tensor]]:
1269            pass
1270
1271    @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
1272    @unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
1273    def test_fd_limit_exceeded(self):
1274        # See NOTE [ DataLoader on Linux and open files limit ]
1275        import subprocess
1276
1277        subprocess.check_output(
1278            [
1279                sys.executable,
1280                "-c",
1281                """\
1282import torch
1283import resource
1284from torch.utils.data import DataLoader, IterableDataset
1285
1286class RandomDataset(IterableDataset):
1287    def __init__(self, len, size):
1288        super(RandomDataset).__init__()
1289        self.len = len
1290        self.size = size
1291
1292    def __iter__(self):
1293        return self
1294
1295    def __next__(self):
1296        if self.len <= 0:
1297            raise StopIteration
1298        self.len -= 1
1299        return torch.randn(self.size)
1300
1301try:
1302    keep_fds_alive = []
1303    resource.setrlimit(resource.RLIMIT_NOFILE, (100, 100))
1304    for random_t in DataLoader(RandomDataset(200, (2,2)), multiprocessing_context="fork",
1305                               num_workers=1):
1306      random_t.max(dim=0)
1307      keep_fds_alive.append(random_t)
1308except RuntimeError as e:
1309    assert "ulimit -n" in str(e)
1310    assert "set_sharing_strategy" in str(e)
1311""",
1312            ]
1313        )
1314
1315    def test_invalid_assign_after_init(self):
1316        dl = self._get_data_loader(self.dataset)
1317        for attr in ("batch_size", "sampler", "batch_sampler", "drop_last", "dataset"):
1318
1319            def fn():
1320                setattr(dl, attr, {})
1321
1322            self.assertRaises(ValueError, fn)
1323
1324    def test_sequential_nonbatch(self):
1325        self._test_sequential(self._get_data_loader(self.dataset, batch_size=None))
1326
1327    def test_sequential_batch(self):
1328        self._test_sequential(self._get_data_loader(self.dataset))
1329        self._test_sequential(self._get_data_loader(self.dataset, batch_size=2))
1330
1331    def test_bulk_loading_nobatch(self):
1332        n = 35
1333        bs = 4
1334        ds = BulkLoadingDataset(n)
1335        sampler = BulkLoadingSampler(ds, batch_size=4)
1336
1337        for num_workers in [0, 4]:
1338            dl = self._get_data_loader(
1339                ds,
1340                num_workers=num_workers,
1341                batch_size=None,
1342                sampler=sampler,
1343                pin_memory=TEST_CUDA,
1344            )
1345            self.assertFalse(dl._auto_collation)
1346            samples = list(dl)
1347            self.assertEqual(samples[0].is_pinned(), TEST_CUDA)
1348            self.assertEqual(set(torch.cat(samples, 0).tolist()), set(range(n)))
1349
1350    def test_growing_dataset(self):
1351        dataset = [torch.ones(4) for _ in range(4)]
1352        dataloader_seq = self._get_data_loader(dataset, shuffle=False)
1353        dataloader_shuffle = self._get_data_loader(dataset, shuffle=True)
1354        dataset.append(torch.ones(4))
1355        self.assertEqual(len(dataloader_seq), 5)
1356        self.assertEqual(len(dataloader_shuffle), 5)
1357
1358    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
1359    def test_sequential_pin_memory(self):
1360        loader = self._get_data_loader(self.dataset, batch_size=2, pin_memory=True)
1361        for input, target in loader:
1362            self.assertTrue(input.is_pinned())
1363            self.assertTrue(target.is_pinned())
1364
1365    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
1366    def test_multiple_dataloaders(self):
1367        for multiprocessing_context in supported_multiprocessing_contexts:
1368            loader1_it = iter(self._get_data_loader(self.dataset, num_workers=1))
1369            loader2_it = iter(
1370                self._get_data_loader(
1371                    self.dataset,
1372                    num_workers=2,
1373                    multiprocessing_context=multiprocessing_context,
1374                )
1375            )
1376            next(loader1_it)
1377            next(loader1_it)
1378            next(loader2_it)
1379            next(loader2_it)
1380            next(loader1_it)
1381            next(loader2_it)
1382            del loader1_it
1383            del loader2_it
1384
1385    @unittest.skipIf(True, "This test is disabled in pytorch/pytorch")
1386    def test_segfault(self):
1387        p = ErrorTrackingProcess(target=_test_segfault)
1388        p.start()
1389        p.join(JOIN_TIMEOUT)
1390        try:
1391            self.assertFalse(p.is_alive())
1392            self.assertNotEqual(p.exitcode, 0)
1393            if IS_WINDOWS:
1394                self.assertIsInstance(p.exception, OSError)
1395                self.assertRegex(str(p.exception), r"access violation reading ")
1396            else:
1397                self.assertIsInstance(p.exception, RuntimeError)
1398                self.assertRegex(
1399                    str(p.exception),
1400                    r"DataLoader worker \(pid \d+\) is killed by signal: ",
1401                )
1402        finally:
1403            p.terminate()
1404
1405    # Tests if the child process forked by the DataLoader segfaults due to having more than 3 threads
1406    # in the parent process after at least one set_num_threads invocation in the parent process.
1407    # After forking, set_num_threads(1) in the child process entails handling some inherited data-structures
1408    # of the Caffe2 thread-pool of the parent process, culminating in a segfault.
1409    # Reference: https://github.com/pytorch/pytorch/issues/54752
1410    @unittest.skipIf(IS_WINDOWS, "Needs fork")
1411    def test_no_segfault(self):
1412        p = ErrorTrackingProcess(target=_test_no_segfault)
1413        p.start()
1414        p.join(JOIN_TIMEOUT)
1415        try:
1416            self.assertFalse(p.is_alive())
1417            if p.exception:
1418                self.assertIsInstance(p.exception, RuntimeError)
1419                self.assertRegex(
1420                    str(p.exception),
1421                    r"DataLoader worker \(pid \d+\) is killed by signal: ",
1422                )
1423                self.fail("Segfault occurred in worker process after fork")
1424        finally:
1425            p.terminate()
1426
1427    def test_timeout(self):
1428        if TEST_CUDA and not NO_MULTIPROCESSING_SPAWN:
1429            # This test runs in a subprocess, which can only initialize CUDA with spawn.
1430            # _test_timeout_pin_memory with pin_memory=True initializes CUDA when the iterator is
1431            # constructed.
1432            targets = (_test_timeout, _test_timeout_pin_memory)
1433        else:
1434            targets = (_test_timeout,)
1435        for target in targets:
1436            p = ErrorTrackingProcess(target=target, args=(self.persistent_workers,))
1437            p.start()
1438            p.join(JOIN_TIMEOUT)
1439            try:
1440                self.assertFalse(p.is_alive())
1441                self.assertNotEqual(p.exitcode, 0)
1442                self.assertIsInstance(p.exception, RuntimeError)
1443                self.assertRegex(
1444                    str(p.exception), r"DataLoader timed out after \d+ seconds"
1445                )
1446            finally:
1447                p.terminate()
1448
1449    def test_large_sampler_indices(self):
1450        # Test that the data loader cleanly exit when the process errors
1451        #   1. having an reference to the iterator
1452        #   2. using a sampler that yields big elements s.t. _index_queues putters block
1453        #
1454        # More context: https://github.com/pytorch/pytorch/issues/48666
1455
1456        p = ErrorTrackingProcess(
1457            target=_test_large_sampler_indices, args=(self.persistent_workers,)
1458        )
1459        p.start()
1460        p.join(JOIN_TIMEOUT)
1461        try:
1462            self.assertFalse(p.is_alive())
1463            self.assertNotEqual(p.exitcode, 0)
1464            self.assertIsInstance(p.exception, RuntimeError)
1465            self.assertRegex(str(p.exception), r"My Error")
1466        finally:
1467            p.terminate()
1468
1469    def test_invalid_ctor_args_combinations(self):
1470        # general
1471        with self.assertRaisesRegex(
1472            ValueError, "num_workers option should be non-negative"
1473        ):
1474            self._get_data_loader(self.dataset, num_workers=-1)
1475        with self.assertRaisesRegex(
1476            ValueError, "timeout option should be non-negative"
1477        ):
1478            self._get_data_loader(self.dataset, timeout=-1)
1479
1480        # disable auto-batching
1481        with self.assertRaisesRegex(
1482            ValueError,
1483            "batch_size=None option disables auto-batching and is mutually exclusive",
1484        ):
1485            self._get_data_loader(self.dataset, batch_size=None, drop_last=True)
1486
1487        valid_ctx = list(torch.multiprocessing.get_all_start_methods())[-1]
1488        with self.assertRaisesRegex(
1489            ValueError, r"multi-process loading \(num_workers > 0\), but got"
1490        ):
1491            self._get_data_loader(
1492                self.dataset, num_workers=0, multiprocessing_context=valid_ctx
1493            )
1494        with self.assertRaisesRegex(
1495            ValueError, "should specify a valid start method in"
1496        ):
1497            self._get_data_loader(
1498                self.dataset, num_workers=1, multiprocessing_context="bad"
1499            )
1500        with self.assertRaisesRegex(
1501            TypeError, "multiprocessing_context option should be a valid context "
1502        ):
1503            self._get_data_loader(
1504                self.dataset, num_workers=1, multiprocessing_context=object()
1505            )
1506
1507        # map-style
1508        sampler = torch.utils.data.SequentialSampler(self.dataset)
1509        batch_sampler = torch.utils.data.BatchSampler(sampler, 3, False)
1510        with self.assertRaisesRegex(
1511            ValueError, "sampler option is mutually exclusive with shuffle"
1512        ):
1513            self._get_data_loader(
1514                self.dataset, batch_size=11, sampler=sampler, shuffle=True
1515            )
1516        with self.assertRaisesRegex(
1517            ValueError, "sampler option is mutually exclusive with shuffle"
1518        ):
1519            self._get_data_loader(
1520                self.dataset, batch_sampler=batch_sampler, sampler=sampler, shuffle=True
1521            )
1522        with self.assertRaisesRegex(
1523            ValueError, "sampler option is mutually exclusive with shuffle"
1524        ):
1525            self._get_data_loader(
1526                self.dataset, batch_sampler=batch_sampler, sampler=sampler, shuffle=3
1527            )
1528        with self.assertRaisesRegex(
1529            ValueError, "batch_sampler option is mutually exclusive with"
1530        ):
1531            self._get_data_loader(
1532                self.dataset, batch_size=11, batch_sampler=batch_sampler
1533            )
1534        with self.assertRaisesRegex(
1535            ValueError, "batch_sampler option is mutually exclusive with"
1536        ):
1537            self._get_data_loader(
1538                self.dataset, shuffle=True, batch_sampler=batch_sampler
1539            )
1540        with self.assertRaisesRegex(
1541            ValueError, "batch_sampler option is mutually exclusive with"
1542        ):
1543            self._get_data_loader(
1544                self.dataset, drop_last=True, batch_sampler=batch_sampler
1545            )
1546        with self.assertRaisesRegex(
1547            ValueError, "batch_sampler option is mutually exclusive with"
1548        ):
1549            self._get_data_loader(
1550                self.dataset, drop_last=3, batch_sampler=batch_sampler
1551            )
1552
1553        # iterable-style
1554        dataset = CountingIterableDataset(20)
1555        with self.assertRaisesRegex(
1556            ValueError, "DataLoader with IterableDataset: expected unspecified shuffle"
1557        ):
1558            self._get_data_loader(dataset, shuffle=True)
1559        with self.assertRaisesRegex(
1560            ValueError, "DataLoader with IterableDataset: expected unspecified shuffle"
1561        ):
1562            self._get_data_loader(dataset, shuffle=3)
1563        with self.assertRaisesRegex(
1564            ValueError, "DataLoader with IterableDataset: expected unspecified sampler"
1565        ):
1566            self._get_data_loader(
1567                dataset, sampler=torch.utils.data.SequentialSampler(dataset)
1568            )
1569        with self.assertRaisesRegex(
1570            ValueError, "DataLoader with IterableDataset: expected unspecified sampler"
1571        ):
1572            self._get_data_loader(dataset, sampler=3)
1573        with self.assertRaisesRegex(
1574            ValueError,
1575            "DataLoader with IterableDataset: expected unspecified batch_sampler",
1576        ):
1577            self._get_data_loader(
1578                dataset,
1579                batch_sampler=torch.utils.data.BatchSampler(
1580                    torch.utils.data.SequentialSampler(dataset), 3, False
1581                ),
1582            )
1583        with self.assertRaisesRegex(
1584            ValueError,
1585            "DataLoader with IterableDataset: expected unspecified batch_sampler",
1586        ):
1587            self._get_data_loader(dataset, batch_sampler=3)
1588
1589    def test_builtin_collection_conversion(self):
1590        for coll_ty in (list, tuple):
1591            for num_workers in (0, 1):
1592                # map-style dataset
1593                dataset = CountingDataset(20)
1594                # no auto-batching
1595                fetched = coll_ty(
1596                    self._get_data_loader(
1597                        dataset, batch_size=None, num_workers=num_workers
1598                    )
1599                )
1600                self.assertEqual(fetched, coll_ty(range(20)))
1601                # auto-batching
1602                fetched = coll_ty(
1603                    self._get_data_loader(
1604                        dataset, batch_size=2, num_workers=num_workers
1605                    )
1606                )
1607                self.assertEqual(
1608                    fetched, coll_ty(torch.tensor([i, i + 1]) for i in range(0, 20, 2))
1609                )
1610
1611                # iterable-style dataset
1612                dataset = CountingIterableDataset(20)
1613                # no auto-batching
1614                fetched = coll_ty(
1615                    self._get_data_loader(
1616                        dataset, batch_size=None, num_workers=num_workers
1617                    )
1618                )
1619                self.assertEqual(fetched, coll_ty(range(20)))
1620                # auto-batching
1621                # this IterableDataset isn't configured for each worker, so for
1622                # the equality test below to be valid, we cannot have more than 1 workers.
1623                assert num_workers in [0, 1], "invalid test"
1624                fetched = coll_ty(
1625                    self._get_data_loader(
1626                        dataset, batch_size=2, num_workers=num_workers
1627                    )
1628                )
1629                self.assertEqual(
1630                    fetched, coll_ty(torch.tensor([i, i + 1]) for i in range(0, 20, 2))
1631                )
1632
1633    def test_iterable_style_dataset(self):
1634        # [no auto-batching] single process loading
1635        dataset = CountingIterableDataset(20)
1636        dataloader = self._get_data_loader(dataset, batch_size=None)
1637        fetched = list(dataloader)
1638        self.assertEqual(len(fetched), 20)
1639        for i, d in enumerate(fetched):
1640            # non-batched should not convert ints into tensors
1641            self.assertIsInstance(d, int)
1642            self.assertEqual(d, i)
1643        # DataLoader should match len of the iterable-style dataset (if implemented)
1644        self.assertEqual(len(dataloader), len(dataset))
1645
1646        # [no auto-batching] multiprocessing loading
1647        num_workers = 3
1648        sizes_for_all_workers = [0, 4, 20]
1649        expected = sorted(
1650            functools.reduce(
1651                operator.iadd, (list(range(s)) for s in sizes_for_all_workers), []
1652            )
1653        )
1654        assert len(sizes_for_all_workers) == num_workers, "invalid test case"
1655        for prefetch_factor in [2, 3, 4]:
1656            dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
1657            dataloader = self._get_data_loader(
1658                dataset,
1659                num_workers=num_workers,
1660                batch_size=None,
1661                worker_init_fn=set_faulthander_if_available,
1662                prefetch_factor=prefetch_factor,
1663            )
1664            dataloader_iter = iter(dataloader)
1665            fetched = sorted(dataloader_iter)
1666            for a, b in zip(fetched, expected):
1667                # non-batched should not convert ints into tensors
1668                self.assertIsInstance(a, int)
1669                self.assertEqual(a, b)
1670            # DataLoader should match len of the iterable-style dataset (if implemented)
1671            self.assertEqual(len(dataloader), len(dataset))
1672            # When loading more than len(dataset) data, after accessing len(dataloader),
1673            # we should get a warning. See NOTE [ IterableDataset and __len__ ].
1674            dataset = CountingIterableDataset(20)
1675            dataloader = self._get_data_loader(
1676                dataset,
1677                num_workers=num_workers,
1678                worker_init_fn=set_faulthander_if_available,
1679                prefetch_factor=prefetch_factor,
1680            )
1681            it = iter(dataloader)
1682            for _ in range(40):
1683                self.assertNotWarn(
1684                    lambda: next(it), "Should not warn before accessing len(dataloader)"
1685                )
1686            self.assertEqual(len(dataloader), len(dataset))
1687            self.assertEqual(len(dataloader), 20)
1688            it = iter(dataloader)
1689            for _ in range(20):
1690                self.assertNotWarn(
1691                    lambda: next(it), "Should not warn before exceeding length"
1692                )
1693            for _ in range(3):
1694                with self.assertWarnsRegex(
1695                    UserWarning,
1696                    r"but [0-9]+ samples have been fetched\. For multiprocessing data-loading, this",
1697                    msg="Should always warn after exceeding length",
1698                ):
1699                    next(it)
1700        # [no auto-batching] test that workers exit gracefully
1701        workers = dataloader_iter._workers
1702        del dataloader_iter
1703        del dataloader
1704        try:
1705            for w in workers:
1706                w.join(JOIN_TIMEOUT)
1707                self.assertFalse(w.is_alive())
1708                self.assertEqual(w.exitcode, 0)
1709        finally:
1710            for w in workers:
1711                w.terminate()
1712
1713        # [auto-batching] single process loading
1714        dataset = CountingIterableDataset(20)
1715        fetched = list(self._get_data_loader(dataset, batch_size=7))
1716        self.assertEqual(len(fetched), 3)
1717        self.assertEqual(fetched[0].tolist(), list(range(7)))
1718        self.assertEqual(fetched[1].tolist(), list(range(7, 14)))
1719        self.assertEqual(fetched[2].tolist(), list(range(14, 20)))
1720
1721        # [auto-batching] multiprocessing loading
1722        num_workers = 3
1723        sizes_for_all_workers = [0, 4, 20]
1724        expected = sorted(
1725            functools.reduce(
1726                operator.iadd, (list(range(s)) for s in sizes_for_all_workers), []
1727            )
1728        )
1729        assert len(sizes_for_all_workers) == num_workers, "invalid test case"
1730        for prefetch_factor in [2, 3, 4]:
1731            dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
1732            # worker 0 should return 0 batches
1733            # worker 1 should return 1 batches
1734            # worker 2 should return 3 batches
1735            dataloader = self._get_data_loader(
1736                dataset,
1737                num_workers=num_workers,
1738                batch_size=7,
1739                prefetch_factor=prefetch_factor,
1740            )
1741            dataloader_iter = iter(dataloader)
1742            fetched = list(dataloader_iter)
1743            self.assertEqual(len(fetched), 4)
1744            fetched = {tuple(t.tolist()) for t in fetched}
1745            self.assertEqual(
1746                fetched,
1747                {
1748                    tuple(range(4)),
1749                    tuple(range(7)),
1750                    tuple(range(7, 14)),
1751                    tuple(range(14, 20)),
1752                },
1753            )
1754
1755            # [auto-batching] test that workers exit gracefully
1756            workers = dataloader_iter._workers
1757            del dataloader_iter
1758            del dataloader
1759            try:
1760                for w in workers:
1761                    w.join(JOIN_TIMEOUT)
1762                    self.assertFalse(w.is_alive())
1763                    self.assertEqual(w.exitcode, 0)
1764            finally:
1765                for w in workers:
1766                    w.terminate()
1767        # [auto-batching & drop_last] single process loading
1768        dataset = CountingIterableDataset(20)
1769        fetched = list(self._get_data_loader(dataset, batch_size=7, drop_last=True))
1770        self.assertEqual(len(fetched), 2)
1771        self.assertEqual(fetched[0].tolist(), list(range(7)))
1772        self.assertEqual(fetched[1].tolist(), list(range(7, 14)))
1773
1774        # [auto-batching & drop_last] multiprocessing loading
1775        num_workers = 3
1776        sizes_for_all_workers = [0, 4, 20]
1777        expected = sorted(
1778            functools.reduce(
1779                operator.iadd, (list(range(s)) for s in sizes_for_all_workers), []
1780            )
1781        )
1782        assert len(sizes_for_all_workers) == num_workers, "invalid test case"
1783        for prefetch_factor in [2, 3, 4]:
1784            dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
1785            # worker 0 should return 0 batches
1786            # worker 1 should return 1 batches
1787            # worker 2 should return 3 batches
1788            dataloader = self._get_data_loader(
1789                dataset,
1790                num_workers=num_workers,
1791                batch_size=7,
1792                drop_last=True,
1793                worker_init_fn=set_faulthander_if_available,
1794                prefetch_factor=prefetch_factor,
1795            )
1796            dataloader_iter = iter(dataloader)
1797            fetched = list(dataloader_iter)
1798            self.assertEqual(len(fetched), 2)
1799            fetched = {tuple(t.tolist()) for t in fetched}
1800            self.assertEqual(fetched, {tuple(range(7)), tuple(range(7, 14))})
1801
1802            # [auto-batching & drop_last] test that workers exit gracefully
1803            workers = dataloader_iter._workers
1804            del dataloader_iter
1805            del dataloader
1806            try:
1807                for w in workers:
1808                    w.join(JOIN_TIMEOUT)
1809                    self.assertFalse(w.is_alive())
1810                    self.assertEqual(w.exitcode, 0)
1811            finally:
1812                for w in workers:
1813                    w.terminate()
1814
1815    def test_chain_iterable_style_dataset(self):
1816        # chaining (concatenation)
1817        dataset1 = CountingIterableDataset(20)
1818        dataset2 = CountingIterableDataset(15)
1819        expected = list(range(20)) + list(range(15))
1820        for num_workers in [0, 1]:
1821            for chained_dataset in [
1822                dataset1 + dataset2,
1823                ChainDataset([dataset1, dataset2]),
1824            ]:
1825                fetched = list(
1826                    self._get_data_loader(chained_dataset, num_workers=num_workers)
1827                )
1828                self.assertEqual(len(fetched), len(expected))
1829                for e, d in zip(expected, fetched):
1830                    self.assertIsInstance(d, torch.Tensor)
1831                    self.assertEqual(e, d)
1832
1833        with self.assertRaisesRegex(
1834            AssertionError, "ChainDataset only supports IterableDataset"
1835        ):
1836            list(iter(dataset1 + self.dataset))
1837
1838        with self.assertRaisesRegex(
1839            AssertionError, "ChainDataset only supports IterableDataset"
1840        ):
1841            list(iter(ChainDataset([dataset1, self.dataset])))
1842
1843    @unittest.skipIf(IS_MACOS, "Not working on macos")
1844    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
1845    @skipIfRocm  # https://github.com/pytorch/pytorch/issues/90940
1846    def test_multiprocessing_contexts(self):
1847        reference = [
1848            torch.arange(3),
1849            torch.arange(3, 6),
1850            torch.arange(6, 9),
1851            torch.arange(9, 11),
1852        ]
1853        counting_ds_n = 11
1854        dl_common_args = dict(num_workers=3, batch_size=3, pin_memory=(not TEST_CUDA))
1855        for ctx in supported_multiprocessing_contexts:
1856            # windows and jetson devices don't support sharing cuda tensor; ROCm does not yet fully support IPC
1857            if (
1858                ctx in ["spawn", "forkserver"]
1859                and TEST_CUDA
1860                and not IS_WINDOWS
1861                and not IS_JETSON
1862            ):
1863                ds_cls = CUDACountingDataset
1864            else:
1865                ds_cls = CountingDataset
1866            self.assertEqual(
1867                reference,
1868                list(
1869                    self._get_data_loader(
1870                        ds_cls(counting_ds_n),
1871                        multiprocessing_context=ctx,
1872                        **dl_common_args,
1873                    )
1874                ),
1875            )
1876            if ctx is not None:
1877                # test ctx object
1878                ctx = mp.get_context(ctx)
1879                self.assertEqual(
1880                    reference,
1881                    list(
1882                        self._get_data_loader(
1883                            ds_cls(counting_ds_n),
1884                            multiprocessing_context=ctx,
1885                            **dl_common_args,
1886                        )
1887                    ),
1888                )
1889
1890    def _test_multiprocessing_iterdatapipe(self, with_dill):
1891        # Testing to make sure that function from global scope (e.g. imported from library) can be serialized
1892        # and used with multiprocess DataLoader
1893
1894        reference = [
1895            torch.as_tensor([[2, 3, 4, 5]], dtype=torch.int64),
1896            torch.as_tensor([[2, 3, 4, 5]], dtype=torch.int64),
1897        ]
1898        datapipe: IterDataPipe = IterableWrapper([[1, 2, 3, 4], [1, 2, 3, 4, 5, 6]])
1899        datapipe = datapipe.map(row_processor)
1900        datapipe = (
1901            datapipe.filter(lambda row: len(row) == 4)
1902            if with_dill
1903            else datapipe.filter(filter_len)
1904        )
1905
1906        dl_common_args = dict(
1907            num_workers=2, batch_size=2, shuffle=True, pin_memory=(not TEST_CUDA)
1908        )
1909        for ctx in supported_multiprocessing_contexts:
1910            self.assertEqual(
1911                reference,
1912                [
1913                    t.type(torch.int64)
1914                    for t in self._get_data_loader(
1915                        datapipe, multiprocessing_context=ctx, **dl_common_args
1916                    )
1917                ],
1918            )
1919            if ctx is not None:
1920                # test ctx object
1921                ctx = mp.get_context(ctx)
1922                self.assertEqual(
1923                    reference,
1924                    [
1925                        t.type(torch.int64)
1926                        for t in self._get_data_loader(
1927                            datapipe, multiprocessing_context=ctx, **dl_common_args
1928                        )
1929                    ],
1930                )
1931
1932    @skipIfNoNumpy
1933    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
1934    def test_multiprocessing_iterdatapipe(self):
1935        self._test_multiprocessing_iterdatapipe(with_dill=False)
1936
1937    @unittest.expectedFailure
1938    @skipIfNoNumpy
1939    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
1940    @skipIfNoDill
1941    def test_multiprocessing_iterdatapipe_with_dill(self):
1942        self._test_multiprocessing_iterdatapipe(with_dill=True)
1943
1944    def test_worker_seed(self):
1945        num_workers = 6
1946        batch_size = 1
1947        dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
1948        dataloader = self._get_data_loader(
1949            dataset, batch_size=batch_size, num_workers=num_workers
1950        )
1951        seeds = set()
1952        seeds.update(batch[0] for batch in dataloader)
1953        self.assertEqual(len(seeds), num_workers)
1954
1955    def test_worker_seed_reproducibility(self):
1956        def get_dataloader():
1957            return DataLoader(
1958                dataset,
1959                batch_size=batch_size,
1960                num_workers=num_workers,
1961                generator=torch.Generator().manual_seed(42),
1962            )
1963
1964        num_workers = 6
1965        batch_size = 1
1966        dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
1967        self.assertEqual(
1968            {int(batch) for batch in get_dataloader()},
1969            {int(batch) for batch in get_dataloader()},
1970        )
1971
1972    def test_multi_epochs_reproducibility(self):
1973        num_workers = 2
1974        batch_size = 10
1975        num_epochs = 3
1976
1977        dataset = TestMultiEpochDataset(batch_size * num_workers)
1978        dataloader = self._get_data_loader(
1979            dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
1980        )
1981
1982        for ind in range(num_epochs):
1983            for batch_idx, sample in enumerate(dataloader):
1984                self.assertEqual(
1985                    sample.tolist(), [batch_idx % num_workers] * batch_size
1986                )
1987
1988    def test_worker_init_fn(self):
1989        dataset = SeedDataset(4)
1990        dataloader = self._get_data_loader(
1991            dataset, batch_size=2, num_workers=2, worker_init_fn=init_fn
1992        )
1993        for batch in dataloader:
1994            self.assertEqual(12345, batch[0])
1995            self.assertEqual(12345, batch[1])
1996
1997    def test_get_worker_info(self):
1998        p = ErrorTrackingProcess(target=_test_get_worker_info)
1999        p.start()
2000        p.join(JOIN_TIMEOUT)
2001        try:
2002            self.assertFalse(p.is_alive())
2003            self.assertEqual(p.exitcode, 0)
2004        finally:
2005            p.terminate()
2006
2007    def test_shuffle(self):
2008        self._test_shuffle(self._get_data_loader(self.dataset, shuffle=True))
2009
2010    def test_shuffle_batch_none(self):
2011        self._test_shuffle(DataLoader(self.dataset, batch_size=None, shuffle=True))
2012
2013    def test_shuffle_batch(self):
2014        self._test_shuffle(
2015            self._get_data_loader(self.dataset, batch_size=2, shuffle=True)
2016        )
2017
2018    def test_shuffle_reproducibility(self):
2019        for fn in (
2020            lambda: DataLoader(
2021                self.dataset,
2022                shuffle=True,
2023                num_workers=0,
2024                generator=torch.Generator().manual_seed(42),
2025            ),
2026            lambda: DataLoader(
2027                self.dataset,
2028                shuffle=True,
2029                num_workers=2,
2030                generator=torch.Generator().manual_seed(42),
2031            ),
2032        ):
2033            self.assertEqual(list(fn()), list(fn()))
2034
2035    def test_sequential_workers(self):
2036        self._test_sequential(self._get_data_loader(self.dataset, num_workers=4))
2037
2038    def test_seqential_batch_workers(self):
2039        self._test_sequential(
2040            self._get_data_loader(self.dataset, batch_size=2, num_workers=4)
2041        )
2042
2043    def test_seqential_batch_workers_prefetch(self):
2044        self._test_sequential(
2045            DataLoader(self.dataset, batch_size=2, num_workers=4, prefetch_factor=3)
2046        )
2047
2048    def test_shuffle_workers(self):
2049        self._test_shuffle(
2050            self._get_data_loader(self.dataset, shuffle=True, num_workers=4)
2051        )
2052
2053    def test_shuffle_batch_workers(self):
2054        self._test_shuffle(
2055            self._get_data_loader(
2056                self.dataset, batch_size=2, shuffle=True, num_workers=4
2057            )
2058        )
2059
2060    def test_shuffle_batch_workers_prefetch(self):
2061        self._test_shuffle(
2062            DataLoader(
2063                self.dataset,
2064                batch_size=2,
2065                shuffle=True,
2066                num_workers=4,
2067                prefetch_factor=3,
2068            )
2069        )
2070
2071    def test_random_sampler(self):
2072        from collections import Counter
2073
2074        from torch.utils.data import RandomSampler
2075
2076        def sample_stat(sampler, num_samples):
2077            counts = Counter(sampler)
2078            count_repeated = sum(val > 1 for val in counts.values())
2079            return (
2080                count_repeated,
2081                min(counts.keys()),
2082                max(counts.keys()),
2083                sum(counts.values()),
2084            )
2085
2086        # test sample with replacement
2087        n = len(self.dataset) + 1  # ensure at least one sample is drawn more than once
2088        sampler_with_replacement = RandomSampler(
2089            self.dataset, replacement=True, num_samples=n
2090        )
2091        count_repeated, minval, maxval, count_total = sample_stat(
2092            sampler_with_replacement, n
2093        )
2094        self.assertTrue(count_repeated > 0)
2095        self.assertTrue(minval >= 0)
2096        self.assertTrue(maxval < len(self.dataset))
2097        self.assertTrue(count_total == n)
2098
2099        # test sample without replacement and without specified num_samples
2100        sampler_without_replacement = RandomSampler(self.dataset)
2101        count_repeated, minval, maxval, count_total = sample_stat(
2102            sampler_without_replacement, len(self.dataset)
2103        )
2104        self.assertTrue(count_repeated == 0)
2105        self.assertTrue(minval == 0)
2106        self.assertTrue(maxval == len(self.dataset) - 1)
2107        self.assertTrue(count_total == len(self.dataset))
2108
2109        # test sample without replacement and with specified num_samples
2110        n = len(self.dataset) * 2
2111        sampler_without_replacement = RandomSampler(self.dataset, num_samples=n)
2112        count_repeated, minval, maxval, count_total = sample_stat(
2113            sampler_without_replacement, len(self.dataset)
2114        )
2115        self.assertTrue(count_repeated == len(self.dataset))
2116        self.assertTrue(minval == 0)
2117        self.assertTrue(maxval == len(self.dataset) - 1)
2118        self.assertTrue(count_total == n)
2119
2120        n = len(self.dataset) - 1
2121        sampler_without_replacement = RandomSampler(self.dataset, num_samples=n)
2122        count_repeated, minval, maxval, count_total = sample_stat(
2123            sampler_without_replacement, len(self.dataset)
2124        )
2125        self.assertTrue(count_repeated == 0)
2126        self.assertTrue(minval >= 0)
2127        self.assertTrue(maxval < len(self.dataset))
2128        self.assertTrue(count_total == n)
2129
2130        n = len(self.dataset) + 1
2131        sampler_without_replacement = RandomSampler(self.dataset, num_samples=n)
2132        count_repeated, minval, maxval, count_total = sample_stat(
2133            sampler_without_replacement, len(self.dataset)
2134        )
2135        self.assertTrue(count_repeated == 1)
2136        self.assertTrue(minval == 0)
2137        self.assertTrue(maxval == len(self.dataset) - 1)
2138        self.assertTrue(count_total == n)
2139
2140        # raise error when replacement is non-boolean
2141        with self.assertRaisesRegex(
2142            TypeError, "replacement should be a boolean value, but got replacement=0"
2143        ):
2144            RandomSampler(self.dataset, replacement=0)
2145
2146    def test_random_sampler_len_with_replacement(self):
2147        from torch.utils.data import RandomSampler
2148
2149        # add 5 extra samples
2150        num_samples = len(self.dataset) + 5
2151        sampler = RandomSampler(self.dataset, replacement=True, num_samples=num_samples)
2152        # test len method
2153        self.assertEqual(num_samples, len(sampler))
2154
2155        # test with iteration
2156        count_num_samples = sum(1 for _ in sampler)
2157        self.assertEqual(num_samples, count_num_samples)
2158
2159        # test with dataloader, batch_size = 1
2160        batch_size = 1
2161        count_num_samples_in_data_loader = len(
2162            self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler)
2163        )
2164        self.assertEqual(num_samples, count_num_samples_in_data_loader)
2165
2166        # test with dataloader, batch_size = 6
2167        batch_size = 6
2168        count_num_samples_in_data_loader = len(
2169            self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler)
2170        )
2171        self.assertEqual(
2172            int(math.ceil(float(num_samples) / batch_size)),
2173            count_num_samples_in_data_loader,
2174        )
2175
2176    def test_random_sampler_len_without_replacement(self):
2177        from torch.utils.data import RandomSampler
2178
2179        # add 5 extra samples
2180        num_samples = len(self.dataset) + 5
2181        sampler = RandomSampler(
2182            self.dataset, replacement=False, num_samples=num_samples
2183        )
2184        # test len method
2185        self.assertEqual(num_samples, len(sampler))
2186
2187        # test with iteration
2188        count_num_samples = sum(1 for _ in sampler)
2189        self.assertEqual(num_samples, count_num_samples)
2190
2191        # test with dataloader, batch_size = 1
2192        batch_size = 1
2193        count_num_samples_in_data_loader = len(
2194            self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler)
2195        )
2196        self.assertEqual(num_samples, count_num_samples_in_data_loader)
2197
2198        # test with dataloader, batch_size = 6
2199        batch_size = 6
2200        count_num_samples_in_data_loader = len(
2201            self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler)
2202        )
2203        self.assertEqual(
2204            num_samples // batch_size + (num_samples % batch_size > 0),
2205            count_num_samples_in_data_loader,
2206        )
2207
2208    def test_distributed_sampler_invalid_rank(self):
2209        from torch.utils.data.distributed import DistributedSampler
2210
2211        dataset = torch.IntTensor(range(10))
2212        with self.assertRaisesRegex(ValueError, "Invalid rank"):
2213            sampler = DistributedSampler(dataset, 3, 3)
2214
2215        with self.assertRaisesRegex(ValueError, "Invalid rank"):
2216            sampler = DistributedSampler(dataset, 3, -1)
2217
2218    def test_duplicating_data_with_drop_last(self):
2219        from torch.utils.data.distributed import DistributedSampler
2220
2221        num_processes = 4
2222        num_batches = 9
2223        data_set = torch.IntTensor(range(num_batches))
2224        scanned_data = torch.IntTensor([])
2225        for i in range(num_processes):
2226            s = DistributedSampler(data_set, num_processes, i)
2227            d_loader = self._get_data_loader(
2228                data_set,
2229                batch_size=int(num_batches / num_processes),
2230                drop_last=True,
2231                sampler=s,
2232            )
2233            for data in d_loader:
2234                scanned_data = torch.cat((scanned_data, data), 0)
2235
2236        self.assertEqual(scanned_data.size(), scanned_data.unique().size())
2237
2238    def test_sampler_reproducibility(self):
2239        from torch.utils.data import (
2240            RandomSampler,
2241            SubsetRandomSampler,
2242            WeightedRandomSampler,
2243        )
2244
2245        weights = [0.1, 0.9, 0.4, 0.7, 3.0, 0.6]
2246        for fn in (
2247            lambda: RandomSampler(
2248                self.dataset,
2249                num_samples=5,
2250                replacement=True,
2251                generator=torch.Generator().manual_seed(42),
2252            ),
2253            lambda: RandomSampler(
2254                self.dataset,
2255                replacement=False,
2256                generator=torch.Generator().manual_seed(42),
2257            ),
2258            lambda: WeightedRandomSampler(
2259                weights,
2260                num_samples=5,
2261                replacement=True,
2262                generator=torch.Generator().manual_seed(42),
2263            ),
2264            lambda: WeightedRandomSampler(
2265                weights,
2266                num_samples=5,
2267                replacement=False,
2268                generator=torch.Generator().manual_seed(42),
2269            ),
2270            lambda: SubsetRandomSampler(
2271                range(10), generator=torch.Generator().manual_seed(42)
2272            ),
2273        ):
2274            self.assertEqual(list(fn()), list(fn()))
2275
2276        for sampler in (
2277            RandomSampler(self.dataset, num_samples=5, replacement=True),
2278            RandomSampler(self.dataset, replacement=False),
2279            WeightedRandomSampler(weights, num_samples=5, replacement=True),
2280            WeightedRandomSampler(weights, num_samples=5, replacement=False),
2281            SubsetRandomSampler(range(10)),
2282        ):
2283            torch.manual_seed(0)
2284            l1 = list(sampler) + list(sampler)
2285
2286            torch.manual_seed(0)
2287            l2 = list(sampler) + list(sampler)
2288            self.assertEqual(l1, l2)
2289
2290            its = (iter(sampler), iter(sampler))
2291            ls = ([], [])
2292            for idx in range(len(sampler)):
2293                for i in range(2):
2294                    if idx == 0:
2295                        torch.manual_seed(0)
2296                    ls[i].append(next(its[i]))
2297            self.assertEqual(ls[0], ls[1])
2298
2299    def _test_sampler(self, **kwargs):
2300        indices = range(2, 12)  # using a regular iterable
2301        dl = self._get_data_loader(
2302            self.dataset, sampler=indices, batch_size=2, **kwargs
2303        )
2304        self.assertEqual(len(dl), 5)
2305        for i, (input, _target) in enumerate(dl):
2306            self.assertEqual(len(input), 2)
2307            self.assertEqual(input, self.data[i * 2 + 2 : i * 2 + 4])
2308
2309    def test_sampler(self):
2310        self._test_sampler()
2311        self._test_sampler(num_workers=4)
2312        if not NO_MULTIPROCESSING_SPAWN:
2313            self._test_batch_sampler(num_workers=4, multiprocessing_context="spawn")
2314
2315    def _test_batch_sampler(self, **kwargs):
2316        # [(0, 1), (2, 3, 4), (5, 6), (7, 8, 9), ...]
2317        batches = []  # using a regular iterable
2318        for i in range(0, 20, 5):
2319            batches.append(tuple(range(i, i + 2)))
2320            batches.append(tuple(range(i + 2, i + 5)))
2321
2322        dl = self._get_data_loader(self.dataset, batch_sampler=batches, **kwargs)
2323        self.assertEqual(len(dl), 8)
2324        for i, (input, _target) in enumerate(dl):
2325            if i % 2 == 0:
2326                offset = i * 5 // 2
2327                self.assertEqual(len(input), 2)
2328                self.assertEqual(input, self.data[offset : offset + 2])
2329            else:
2330                offset = i * 5 // 2
2331                self.assertEqual(len(input), 3)
2332                self.assertEqual(input, self.data[offset : offset + 3])
2333
2334    def test_batch_sampler(self):
2335        self._test_batch_sampler()
2336        self._test_batch_sampler(num_workers=4)
2337        if not NO_MULTIPROCESSING_SPAWN:
2338            self._test_batch_sampler(num_workers=4, multiprocessing_context="spawn")
2339
2340    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
2341    def test_shuffle_pin_memory(self):
2342        loader = self._get_data_loader(
2343            self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True
2344        )
2345        for input, target in loader:
2346            self.assertTrue(input.is_pinned())
2347            self.assertTrue(target.is_pinned())
2348
2349    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
2350    def test_numpy(self):
2351        import numpy as np
2352
2353        class TestDataset(torch.utils.data.Dataset):
2354            def __getitem__(self, i):
2355                return np.ones((2, 3, 4)) * i
2356
2357            def __len__(self):
2358                return 1000
2359
2360        loader = self._get_data_loader(TestDataset(), batch_size=12)
2361        batch = next(iter(loader))
2362        self.assertIsInstance(batch, torch.DoubleTensor)
2363        self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4]))
2364
2365    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
2366    def test_numpy_gen_state(self):
2367        from torch.utils.data._utils.worker import _generate_state
2368
2369        # Using NumPy generated states as the reference to test `_generate_state`
2370        # having the same result.
2371        # Test case: ((worker_id, base_seed), expected_state)
2372        test_cases = [
2373            (
2374                (4, 13434589827475259383),
2375                (2884386318, 1088094898, 3523808998, 3860348662),
2376            ),
2377            ((1, 15014285634777110771), (1934848465, 763213760, 2959016433, 179751970)),
2378            (
2379                (10, 978296274032934101),
2380                (1759791917, 3550927336, 1225977135, 1036538043),
2381            ),
2382            (
2383                (12, 11868770762134256968),
2384                (3974661794, 3331131333, 3630387033, 2885815368),
2385            ),
2386            (
2387                (9, 15378787925219019706),
2388                (3815056996, 3162224466, 2735102421, 3190253477),
2389            ),
2390            ((5, 9055612723125076328), (3522565701, 3368424109, 959377806, 621878693)),
2391            (
2392                (15, 14617792358407278405),
2393                (3402479508, 1588702753, 1169536393, 3675067356),
2394            ),
2395            (
2396                (9, 17363320784006640087),
2397                (957989458, 2518334477, 1421725660, 3086155459),
2398            ),
2399            (
2400                (12, 480002904169484764),
2401                (2732851467, 1762620729, 4055801988, 1277640511),
2402            ),
2403            (
2404                (15, 16803975943592702950),
2405                (3479415043, 4022359553, 295994005, 3358606349),
2406            ),
2407            (
2408                (9, 11704776406047813044),
2409                (1968928009, 710113752, 2442656196, 1587420279),
2410            ),
2411            (
2412                (10, 16357891985431864516),
2413                (1271733898, 4197047399, 3727213786, 2338547348),
2414            ),
2415            (
2416                (2, 17423369006318065007),
2417                (544294336, 1911284083, 3299147734, 3231058347),
2418            ),
2419            ((2, 2889492011444113593), (3721591783, 2595811276, 2212881745, 977682627)),
2420            ((0, 8979703111668486195), (4276723937, 2556068849, 2962827292, 233130238)),
2421            (
2422                (6, 6269787272229682235),
2423                (2548857855, 1216457374, 1012973562, 2999759647),
2424            ),
2425        ]
2426
2427        for (worker_id, base_seed), exp in test_cases:
2428            self.assertEqual(exp, _generate_state(base_seed, worker_id))
2429
2430    def test_error(self):
2431        self._test_error(
2432            self._get_data_loader(ErrorDataset(100), batch_size=2, shuffle=True)
2433        )
2434
2435    def test_error_workers(self):
2436        self._test_error(
2437            self._get_data_loader(
2438                ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4
2439            )
2440        )
2441
2442    @unittest.skipIf(IS_WINDOWS, "FIXME: stuck test")
2443    def test_partial_workers(self):
2444        r"""Check that workers exit even if the iterator is not exhausted."""
2445        if TEST_CUDA:
2446            pin_memory_configs = (True, False)
2447        else:
2448            pin_memory_configs = (False,)
2449
2450        for pin_memory in pin_memory_configs:
2451            loader = iter(
2452                self._get_data_loader(
2453                    self.dataset, batch_size=2, num_workers=4, pin_memory=pin_memory
2454                )
2455            )
2456            workers = loader._workers
2457            if pin_memory:
2458                pin_memory_thread = loader._pin_memory_thread
2459            for i, _ in enumerate(loader):
2460                if i == 10:
2461                    break
2462            assert i == 10
2463            del loader
2464            for w in workers:
2465                w.join(JOIN_TIMEOUT)
2466                self.assertFalse(w.is_alive(), "subprocess not terminated")
2467            if pin_memory:
2468                pin_memory_thread.join(JOIN_TIMEOUT)
2469                self.assertFalse(pin_memory_thread.is_alive())
2470
2471    # Takes 2.5min to finish, see https://github.com/pytorch/pytorch/issues/46065
2472    @skipIfRocm
2473    @unittest.skipIf(not HAS_PSUTIL, "psutil not found")
2474    @slowTest
2475    def test_proper_exit(self):
2476        (
2477            r"""There might be ConnectionResetError or leaked semaphore warning """
2478            r"""(due to dirty process exit), but they are all safe to ignore"""
2479        )
2480
2481        # TODO: test the case where the pin_memory_thread triggers an
2482        #       error/fatal signal. I haven't found out how to properly do that.
2483
2484        for (
2485            is_iterable_dataset,
2486            use_workers,
2487            pin_memory,
2488            hold_iter_reference,
2489        ) in itertools.product([True, False], repeat=4):
2490            # `hold_iter_reference` specifies whether we hold a reference to the
2491            # iterator. This is interesting because Python3 error traces holds a
2492            # reference to the frames, which hold references to all the local
2493            # variables including the iterator, and then the iterator dtor may
2494            # not be called before process end. It is important to see that the
2495            # processes still exit in both cases.
2496
2497            if pin_memory and (not TEST_CUDA or NO_MULTIPROCESSING_SPAWN or IS_WINDOWS):
2498                # This test runs in a subprocess, which can only initialize CUDA with spawn.
2499                # DataLoader with pin_memory=True initializes CUDA when its iterator is constructed.
2500                # For windows, pin_memory sometimes causes CUDA oom.
2501                continue
2502
2503            # `exit_method` controls the way the loader process ends.
2504            #   - `*_kill` means that `*` is killed by OS.
2505            #   - `*_error` means that `*` raises an error.
2506            #   - `None` means that no error happens.
2507            # In all cases, all processes should end properly.
2508            if use_workers:
2509                # TODO: Fix test for 'loader_kill' that would cause running out of shared memory.
2510                # Killing loader process would prevent DataLoader iterator clean up all queues
2511                # and worker processes
2512                exit_methods = [None, "loader_error", "worker_error", "worker_kill"]
2513                persistent_workers = self.persistent_workers
2514            else:
2515                exit_methods = [None, "loader_error", "loader_kill"]
2516                persistent_workers = False
2517
2518            for exit_method in exit_methods:
2519                if exit_method == "worker_kill":
2520                    # FIXME: This sometimes hangs. See #16608.
2521                    continue
2522
2523                desc = []
2524                desc.append(f"is_iterable_dataset={is_iterable_dataset}")
2525                desc.append(f"use_workers={use_workers}")
2526                desc.append(f"pin_memory={pin_memory}")
2527                desc.append(f"hold_iter_reference={hold_iter_reference}")
2528                desc.append(f"exit_method={exit_method}")
2529                desc = "test_proper_exit with " + ", ".join(desc)
2530
2531                # Event that the loader process uses to signal testing process
2532                # that various things are setup, including that the worker pids
2533                # are specified in `worker_pids` array.
2534                loader_setup_event = mp.Event()
2535
2536                # Event that this process has finished setting up, and the
2537                # loader process can now proceed to trigger error events or
2538                # finish normally.
2539                tester_setup_event = mp.Event()
2540
2541                loader_p = ErrorTrackingProcess(
2542                    target=_test_proper_exit,
2543                    args=(
2544                        is_iterable_dataset,
2545                        use_workers,
2546                        pin_memory,
2547                        exit_method,
2548                        hold_iter_reference,
2549                        loader_setup_event,
2550                        tester_setup_event,
2551                        persistent_workers,
2552                    ),
2553                    disable_stderr=False,
2554                )
2555                loader_p.start()
2556                loader_psutil_p = psutil.Process(loader_p.pid)
2557
2558                # Wait for loader process to set everything up, e.g., starting
2559                # workers.
2560                loader_setup_event.wait(timeout=JOIN_TIMEOUT)
2561                if not loader_setup_event.is_set():
2562                    fail_msg = (
2563                        desc + ": loader process failed to setup within given time"
2564                    )
2565                    if loader_p.exception is not None:
2566                        fail_msg += f", and had exception {loader_p.exception}"
2567                    elif not loader_p.is_alive():
2568                        fail_msg += f", and exited with code {loader_p.exitcode} but had no exception"
2569                    else:
2570                        fail_msg += ", and is still alive."
2571                    if loader_p.is_alive():
2572                        # this may kill the process, needs to run after the above lines
2573                        loader_p.print_traces_of_all_threads()
2574                    self.fail(fail_msg)
2575
2576                # We are certain that the workers have started now.
2577                worker_psutil_ps = loader_psutil_p.children()
2578
2579                def fail(reason):
2580                    report_psutil_attrs = [
2581                        "pid",
2582                        "name",
2583                        "cpu_times",
2584                        "io_counters",
2585                        "memory_full_info",
2586                        "num_ctx_switches",
2587                        "open_files",
2588                        "threads",
2589                        "status",
2590                        "nice",
2591                        "ionice",
2592                    ]
2593                    if reason is None:
2594                        err_msg = desc
2595                    else:
2596                        err_msg = f"{desc}: {reason}"
2597                    err_msg += "\nLoader info:\n\t"
2598                    if loader_psutil_p.is_running():
2599                        err_msg += str(
2600                            loader_psutil_p.as_dict(attrs=report_psutil_attrs)
2601                        )
2602                        # this may kill the process, needs to run after the above line
2603                        loader_p.print_traces_of_all_threads()
2604                    else:
2605                        err_msg += f"exited with code {loader_p.exitcode}"
2606                    if use_workers:
2607                        err_msg += "\nWorker(s) info:"
2608                        for idx, worker_psutil_p in enumerate(worker_psutil_ps):
2609                            err_msg += f"\n\tWorker {idx}:\n\t\t"
2610                            if worker_psutil_p.is_running():
2611                                err_msg += str(
2612                                    worker_psutil_p.as_dict(attrs=report_psutil_attrs)
2613                                )
2614                                # this may kill the process, needs to run after the above line
2615                                print_traces_of_all_threads(worker_psutil_p.pid)
2616                            else:
2617                                err_msg += "exited with unknown code"
2618                    self.fail(err_msg)
2619
2620                tester_setup_event.set()
2621
2622                try:
2623                    loader_p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
2624                    if loader_p.is_alive():
2625                        fail_reason = "loader process did not terminate"
2626                        if loader_p.exception is not None:
2627                            fail(
2628                                fail_reason
2629                                + f", and had exception {loader_p.exception}"
2630                            )
2631                        else:
2632                            fail(fail_reason + ", and had no exception")
2633                    _, alive = psutil.wait_procs(
2634                        worker_psutil_ps,
2635                        timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT),
2636                    )
2637                    if len(alive) > 0:
2638                        fail(
2639                            "worker process (pid(s) {}) did not terminate".format(
2640                                ", ".join(str(p.pid) for p in alive)
2641                            )
2642                        )
2643                    if exit_method is None:
2644                        if loader_p.exitcode != 0:
2645                            fail(
2646                                f"loader process had nonzero exitcode {loader_p.exitcode}"
2647                            )
2648                    else:
2649                        if loader_p.exitcode == 0:
2650                            fail("loader process had zero exitcode")
2651                        if exit_method == "loader_error":
2652                            if not isinstance(
2653                                loader_p.exception, RuntimeError
2654                            ) or "Loader error" not in str(loader_p.exception):
2655                                fail(
2656                                    f"loader process did not raise expected exception, but had {loader_p.exception}"
2657                                )
2658                        elif exit_method == "worker_kill":
2659                            if isinstance(loader_p.exception, RuntimeError):
2660                                if "DataLoader worker (pid" not in str(
2661                                    loader_p.exception
2662                                ):
2663                                    fail(
2664                                        f"loader process did not raise expected exception, but had {loader_p.exception}"
2665                                    )
2666                            elif isinstance(loader_p.exception, ConnectionRefusedError):
2667                                # Sometimes, when the worker is being killed and is freeing its
2668                                # resources, the unpickling in loader process will be met an
2669                                # a `ConnectionRefusedError` as it can not open a socket to receive
2670                                # resource. In such cases, the worker may not have fully exited,
2671                                # and the loader can't know this via `is_alive` check or `SIGCHLD`
2672                                # handler. So we permit this as an allowed error as well.
2673                                # After all, we are happy as long as it terminates.
2674                                pass
2675                            else:
2676                                fail(
2677                                    f"loader process did not raise expected exception, but had {loader_p.exception}"
2678                                )
2679                        elif exit_method == "worker_error":
2680                            if not isinstance(
2681                                loader_p.exception, RuntimeError
2682                            ) or "Worker error" not in str(loader_p.exception):
2683                                fail(
2684                                    f"loader process did not raise expected exception, but had {loader_p.exception}"
2685                                )
2686                finally:
2687                    loader_p.terminate()
2688
2689    def test_len(self):
2690        def check_len(dl, expected):
2691            self.assertEqual(len(dl), expected)
2692            n = 0
2693            for _ in dl:
2694                n += 1
2695            self.assertEqual(n, expected)
2696
2697        check_len(self.dataset, 100)
2698        check_len(self._get_data_loader(self.dataset, batch_size=2), 50)
2699        check_len(self._get_data_loader(self.dataset, batch_size=3), 34)
2700
2701    def test_iterabledataset_len(self):
2702        class IterableDataset(torch.utils.data.IterableDataset):
2703            def __len__(self):
2704                return 10
2705
2706            def __iter__(self):
2707                return iter(range(10))
2708
2709        iterable_loader = DataLoader(IterableDataset(), batch_size=1)
2710        self.assertEqual(len(iterable_loader), 10)
2711        iterable_loader = DataLoader(IterableDataset(), batch_size=1, drop_last=True)
2712        self.assertEqual(len(iterable_loader), 10)
2713
2714        iterable_loader = DataLoader(IterableDataset(), batch_size=2)
2715        self.assertEqual(len(iterable_loader), 5)
2716        iterable_loader = DataLoader(IterableDataset(), batch_size=2, drop_last=True)
2717        self.assertEqual(len(iterable_loader), 5)
2718
2719        iterable_loader = DataLoader(IterableDataset(), batch_size=3)
2720        self.assertEqual(len(iterable_loader), 4)
2721        iterable_loader = DataLoader(IterableDataset(), batch_size=3, drop_last=True)
2722        self.assertEqual(len(iterable_loader), 3)
2723
2724    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
2725    def test_numpy_scalars(self):
2726        import numpy as np
2727
2728        class ScalarDataset(torch.utils.data.Dataset):
2729            def __init__(self, dtype):
2730                self.dtype = dtype
2731
2732            def __getitem__(self, i):
2733                return self.dtype()
2734
2735            def __len__(self):
2736                return 4
2737
2738        dtypes = {
2739            np.float64: torch.DoubleTensor,
2740            np.float32: torch.FloatTensor,
2741            np.float16: torch.HalfTensor,
2742            np.int64: torch.LongTensor,
2743            np.int32: torch.IntTensor,
2744            np.int16: torch.ShortTensor,
2745            np.int8: torch.CharTensor,
2746            np.uint8: torch.ByteTensor,
2747        }
2748        for dt, tt in dtypes.items():
2749            dset = ScalarDataset(dt)
2750            loader = self._get_data_loader(dset, batch_size=2)
2751            batch = next(iter(loader))
2752            self.assertIsInstance(batch, tt)
2753
2754    def test_default_convert_mapping_keep_type(self):
2755        data = CustomDict({"a": 1, "b": 2})
2756        converted = _utils.collate.default_convert(data)
2757
2758        self.assertEqual(converted, data)
2759
2760    def test_default_convert_sequence_keep_type(self):
2761        data = CustomList([1, 2, 3])
2762        converted = _utils.collate.default_convert(data)
2763
2764        self.assertEqual(converted, data)
2765
2766    def test_default_convert_sequence_dont_keep_type(self):
2767        data = range(2)
2768        converted = _utils.collate.default_convert(data)
2769
2770        self.assertEqual(converted, [0, 1])
2771
2772    def test_default_collate_dtype(self):
2773        arr = [1, 2, -1]
2774        collated = _utils.collate.default_collate(arr)
2775        self.assertEqual(collated, torch.tensor(arr))
2776        self.assertEqual(collated.dtype, torch.int64)
2777
2778        arr = [1.1, 2.3, -0.9]
2779        collated = _utils.collate.default_collate(arr)
2780        self.assertEqual(collated, torch.tensor(arr, dtype=torch.float64))
2781
2782        arr = [True, False]
2783        collated = _utils.collate.default_collate(arr)
2784        self.assertEqual(collated, torch.tensor(arr))
2785        self.assertEqual(collated.dtype, torch.bool)
2786
2787        # Should be a no-op
2788        arr = ["a", "b", "c"]
2789        self.assertEqual(arr, _utils.collate.default_collate(arr))
2790
2791    def test_default_collate_mapping_keep_type(self):
2792        batch = [CustomDict({"a": 1, "b": 2}), CustomDict({"a": 3, "b": 4})]
2793        collated = _utils.collate.default_collate(batch)
2794
2795        expected = CustomDict({"a": torch.tensor([1, 3]), "b": torch.tensor([2, 4])})
2796        self.assertEqual(collated, expected)
2797
2798    def test_default_collate_sequence_keep_type(self):
2799        batch = [CustomList([1, 2, 3]), CustomList([4, 5, 6])]
2800        collated = _utils.collate.default_collate(batch)
2801
2802        expected = CustomList(
2803            [
2804                torch.tensor([1, 4]),
2805                torch.tensor([2, 5]),
2806                torch.tensor([3, 6]),
2807            ]
2808        )
2809        self.assertEqual(collated, expected)
2810
2811    def test_default_collate_sequence_dont_keep_type(self):
2812        batch = [range(2), range(2)]
2813        collated = _utils.collate.default_collate(batch)
2814
2815        self.assertEqual(collated, [torch.tensor([0, 0]), torch.tensor([1, 1])])
2816
2817    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
2818    def test_default_collate_bad_numpy_types(self):
2819        import numpy as np
2820
2821        # Should be a no-op
2822        arr = np.array(["a", "b", "c"])
2823        self.assertEqual(arr, _utils.collate.default_collate(arr))
2824
2825        arr = np.array([[["a", "b", "c"]]])
2826        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
2827
2828        arr = np.array([object(), object(), object()])
2829        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
2830
2831        arr = np.array([[[object(), object(), object()]]])
2832        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
2833
2834    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
2835    def test_default_collate_numpy_memmap(self):
2836        import numpy as np
2837
2838        with tempfile.TemporaryFile() as f:
2839            arr = np.array([[0, 1], [2, 3], [4, 5], [6, 7]])
2840            arr_memmap = np.memmap(f, dtype=arr.dtype, mode="w+", shape=arr.shape)
2841            arr_memmap[:] = arr[:]
2842            arr_new = np.memmap(f, dtype=arr.dtype, mode="r", shape=arr.shape)
2843            tensor = _utils.collate.default_collate(list(arr_new))
2844
2845        self.assertTrue(
2846            (tensor == tensor.new_tensor([[0, 1], [2, 3], [4, 5], [6, 7]])).all().item()
2847        )
2848
2849    def test_default_collate_bad_sequence_type(self):
2850        batch = [["X"], ["X", "X"]]
2851        self.assertRaises(RuntimeError, lambda: _utils.collate.default_collate(batch))
2852        self.assertRaises(
2853            RuntimeError, lambda: _utils.collate.default_collate(batch[::-1])
2854        )
2855
2856    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
2857    def test_default_collate_shared_tensor(self):
2858        import numpy as np
2859
2860        t_in = torch.zeros(1)
2861        n_in = np.zeros(1)
2862
2863        self.assertEqual(t_in.is_shared(), False)
2864
2865        self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), False)
2866        self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), False)
2867
2868        # FIXME: fix the following hack that makes `default_collate` believe
2869        #        that it is in a worker process (since it tests
2870        #        `get_worker_info() != None`), even though it is not.
2871        old = _utils.worker._worker_info
2872        try:
2873            _utils.worker._worker_info = "x"
2874            self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), True)
2875            self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), True)
2876        finally:
2877            _utils.worker._worker_info = old
2878
2879    def test_excessive_thread_creation_warning(self):
2880        with self.assertWarnsRegex(
2881            UserWarning,
2882            r"excessive worker creation might get DataLoader running slow or even freeze",
2883        ):
2884            dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000)
2885
2886
2887class TestDataLoaderDeviceType(TestCase):
2888    @parametrize(
2889        "context",
2890        [ctx for ctx in supported_multiprocessing_contexts if ctx is not None],
2891    )
2892    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
2893    def test_nested_tensor_multiprocessing(self, device, context):
2894        # The 'fork' multiprocessing context doesn't work for CUDA so skip it
2895        if "cuda" in device and context == "fork":
2896            # TODO: Skip this better in a better way when the test framework allows
2897            return
2898
2899        dataset = [
2900            torch.nested.nested_tensor([torch.randn(5)], device=device)
2901            for _ in range(10)
2902        ]
2903
2904        pin_memory_settings = [False]
2905        if device == "cpu" and torch.cuda.is_available():
2906            pin_memory_settings.append(True)
2907
2908        for pin_memory in pin_memory_settings:
2909            loader = torch.utils.data.DataLoader(
2910                dataset,
2911                batch_size=1,
2912                num_workers=4,
2913                collate_fn=_clone_collate,
2914                pin_memory=pin_memory,
2915                multiprocessing_context=context,
2916            )
2917
2918            for i, batch in enumerate(loader):
2919                self.assertEqual(batch[0], dataset[i])
2920
2921        # Error case: default collate_fn doesn't currently support batches of nested tensors.
2922        # Following the current semantics, we'd need to stack them, which isn't possible atm.
2923        with self.assertRaisesRegex(
2924            RuntimeError, "not currently supported by the default collate_fn"
2925        ):
2926            loader = torch.utils.data.DataLoader(
2927                dataset,
2928                batch_size=1,
2929                num_workers=4,
2930                multiprocessing_context=context,
2931            )
2932
2933            next(iter(loader))
2934
2935
2936class IntegrationTestDataLoaderDataPipe(TestCase):
2937    r"""
2938    Verify the behavior of a certain ``DataPipes`` with ``DataLoader``
2939    """
2940
2941    def test_shuffler_iterdatapipe(self):
2942        r"""
2943        Verify ``IterDataPipe.shuffle`` is controlled by ``DataLoader``
2944        to generate different seeds deterministically per epoch.
2945        """
2946        exp = list(range(100))
2947
2948        def _create_dp(buffer_size):
2949            input_ds = dp.iter.IterableWrapper(exp)
2950            return input_ds.shuffle(buffer_size=buffer_size).sharding_filter()
2951
2952        for bs in (5, 20, 33):
2953            # Test Deterministic
2954            for num_workers, pw in itertools.product((0, 1, 2), (True, False)):
2955                if num_workers == 0 and pw:
2956                    continue
2957
2958                shuffle_dp = _create_dp(bs)
2959
2960                mp_ctx = "spawn" if num_workers > 0 else None
2961                dl = DataLoader(
2962                    shuffle_dp,
2963                    num_workers=num_workers,
2964                    shuffle=True,
2965                    multiprocessing_context=mp_ctx,
2966                    persistent_workers=pw,
2967                )
2968
2969                # No seed
2970                dl_res_ns = list(dl)
2971                self.assertEqual(sorted(dl_res_ns), exp)
2972
2973                # Same seeds
2974                dl_res = []
2975                for epoch in range(2):
2976                    torch.manual_seed(123)
2977                    dl_res.append(list(dl))
2978                self.assertEqual(dl_res[0], dl_res[1])
2979                self.assertEqual(sorted(dl_res[0]), exp)
2980
2981                # Different seeds
2982                torch.manual_seed(321)
2983                dl_res.append(list(dl))
2984
2985                self.assertEqual(len(dl_res[0]), len(dl_res[2]))
2986                self.assertNotEqual(dl_res[0], dl_res[2])
2987                self.assertEqual(sorted(dl_res[0]), sorted(dl_res[2]))
2988
2989                if dl._iterator is not None:
2990                    dl._iterator._shutdown_workers()
2991                    dl._iterator = None
2992                del dl
2993
2994
2995class StringDataset(Dataset):
2996    def __init__(self) -> None:
2997        self.s = "12345"
2998
2999    def __len__(self):
3000        return len(self.s)
3001
3002    def __getitem__(self, ndx):
3003        return (self.s[ndx], ndx)
3004
3005
3006@unittest.skipIf(
3007    TEST_WITH_TSAN,
3008    "Fails with TSAN with the following error: starting new threads after multi-threaded "
3009    "fork is not supported. Dying (set die_after_fork=0 to override)",
3010)
3011class TestStringDataLoader(TestCase):
3012    def setUp(self):
3013        super().setUp()
3014        self.dataset = StringDataset()
3015
3016    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
3017    def test_shuffle_pin_memory(self):
3018        loader = DataLoader(
3019            self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True
3020        )
3021        for s, n in loader:
3022            self.assertIsInstance(s[0], str)
3023            self.assertTrue(n.is_pinned())
3024
3025
3026class DictDataset(Dataset):
3027    def __len__(self):
3028        return 4
3029
3030    def __getitem__(self, ndx):
3031        return {
3032            "a_tensor": torch.empty(4, 2).fill_(ndx),
3033            "another_dict": {"a_number": ndx},
3034        }
3035
3036
3037@unittest.skipIf(
3038    TEST_WITH_TSAN,
3039    "Fails with TSAN with the following error: starting new threads after multi-threaded "
3040    "fork is not supported. Dying (set die_after_fork=0 to override)",
3041)
3042class TestDictDataLoader(TestCase):
3043    def setUp(self):
3044        super().setUp()
3045        self.dataset = DictDataset()
3046
3047    def test_sequential_batch(self):
3048        for persistent_workers in (False, True):
3049            if persistent_workers:
3050                loader = DataLoader(
3051                    self.dataset,
3052                    batch_size=2,
3053                    shuffle=False,
3054                    persistent_workers=persistent_workers,
3055                    num_workers=1,
3056                )
3057            else:
3058                loader = DataLoader(
3059                    self.dataset,
3060                    batch_size=2,
3061                    shuffle=False,
3062                    persistent_workers=persistent_workers,
3063                )
3064            batch_size = loader.batch_size
3065            for i, sample in enumerate(loader):
3066                idx = i * batch_size
3067                self.assertEqual(set(sample.keys()), {"a_tensor", "another_dict"})
3068                self.assertEqual(set(sample["another_dict"].keys()), {"a_number"})
3069
3070                t = sample["a_tensor"]
3071                self.assertEqual(t.size(), torch.Size([batch_size, 4, 2]))
3072                self.assertTrue((t[0] == idx).all())
3073                self.assertTrue((t[1] == idx + 1).all())
3074
3075                n = sample["another_dict"]["a_number"]
3076                self.assertEqual(n.size(), torch.Size([batch_size]))
3077                self.assertEqual(n[0], idx)
3078                self.assertEqual(n[1], idx + 1)
3079
3080    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
3081    def test_pin_memory(self):
3082        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
3083        for sample in loader:
3084            self.assertTrue(sample["a_tensor"].is_pinned())
3085            self.assertTrue(sample["another_dict"]["a_number"].is_pinned())
3086
3087    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
3088    def test_pin_memory_device(self):
3089        loader = DataLoader(
3090            self.dataset, batch_size=2, pin_memory=True, pin_memory_device="cuda"
3091        )
3092        for sample in loader:
3093            self.assertTrue(sample["a_tensor"].is_pinned(device="cuda"))
3094            self.assertTrue(sample["another_dict"]["a_number"].is_pinned(device="cuda"))
3095
3096    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
3097    def test_pin_memory_with_only_device(self):
3098        loader = DataLoader(self.dataset, batch_size=2, pin_memory_device="cuda")
3099        for sample in loader:
3100            self.assertFalse(sample["a_tensor"].is_pinned(device="cuda"))
3101            self.assertFalse(
3102                sample["another_dict"]["a_number"].is_pinned(device="cuda")
3103            )
3104
3105
3106class DummyDataset(torch.utils.data.Dataset):
3107    def __init__(self) -> None:
3108        self.data = list(range(10))
3109
3110    def __len__(self):
3111        return len(self.data)
3112
3113    def __getitem__(self, idx):
3114        if torch.is_tensor(idx):
3115            idx = idx.tolist()
3116        # The persistent workers always maintain the original
3117        # dataset through the dataloader lifetime
3118        # so the attributes will remain the same as the
3119        # first time the workers where spawned (dataloader iteration)
3120        assert self.start == 0
3121        return self.data[idx]
3122
3123
3124@unittest.skipIf(
3125    TEST_WITH_TSAN,
3126    "Fails with TSAN with the following error: starting new threads after multi-threaded "
3127    "fork is not supported. Dying (set die_after_fork=0 to override)",
3128)
3129@unittest.skipIf(
3130    TEST_WITH_ASAN,
3131    "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223",
3132)
3133class TestDataLoaderPersistentWorkers(TestDataLoader):
3134    def setUp(self):
3135        super().setUp()
3136        self.persistent_workers = True
3137
3138    @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
3139    @unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
3140    def test_fd_limit_exceeded(self):
3141        # See NOTE [ DataLoader on Linux and open files limit ]
3142        import subprocess
3143
3144        subprocess.check_output(
3145            [
3146                sys.executable,
3147                "-c",
3148                """\
3149import torch
3150import resource
3151from torch.utils.data import DataLoader, IterableDataset
3152
3153class RandomDataset(IterableDataset):
3154    def __init__(self, len, size):
3155        super(RandomDataset).__init__()
3156        self.len = len
3157        self.size = size
3158
3159    def __iter__(self):
3160        return self
3161
3162    def __next__(self):
3163        if self.len <= 0:
3164            raise StopIteration
3165        self.len -= 1
3166        return torch.randn(self.size)
3167
3168try:
3169    keep_fds_alive = []
3170    resource.setrlimit(resource.RLIMIT_NOFILE, (100, 100))
3171    for random_t in DataLoader(RandomDataset(200, (2,2)), multiprocessing_context="fork",
3172                               num_workers=1, persistent_workers=True):
3173      random_t.max(dim=0)
3174      keep_fds_alive.append(random_t)
3175except RuntimeError as e:
3176    assert "ulimit -n" in str(e)
3177    assert "set_sharing_strategy" in str(e)
3178""",
3179            ]
3180        )
3181
3182    def test_dataset_not_reset(self):
3183        dataset = DummyDataset()
3184        pin_memory_configs = [False]
3185        if TEST_CUDA:
3186            pin_memory_configs.append(True)
3187        for pin_memory in pin_memory_configs:
3188            dataloader = self._get_data_loader(
3189                dataset, num_workers=2, pin_memory=pin_memory
3190            )
3191            dataset.start = 0
3192            for i in range(10):
3193                for x in dataloader:
3194                    pass
3195                # Changing the start value here doesn't have any effect in the dataset
3196                # cached by the workers. since they are not recreated between epochs
3197                # and can cache values safely
3198                dataset.start = i
3199
3200    @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
3201    @unittest.skipIf(IS_WINDOWS, "Needs fork")
3202    def test_early_exit(self):
3203        import subprocess
3204
3205        proc = subprocess.check_output(
3206            [
3207                sys.executable,
3208                "-c",
3209                """\
3210import torch
3211from torch.utils.data import DataLoader, IterableDataset
3212
3213class RandomDataset(IterableDataset):
3214    def __init__(self, len, size):
3215        super(RandomDataset).__init__()
3216        self.len = len
3217        self.size = size
3218
3219    def __iter__(self):
3220        return self
3221
3222    def __next__(self):
3223        if self.len <= 0:
3224            raise StopIteration
3225        self.len -= 1
3226        return torch.randn(self.size)
3227
3228if __name__ == '__main__':
3229    dl = DataLoader(
3230        RandomDataset(64, (28, 28)),
3231        batch_size=16,
3232        num_workers=2,
3233        pin_memory=True,
3234        persistent_workers=True,
3235        multiprocessing_context="fork",
3236    )
3237
3238    for _ in dl:
3239        break
3240""",
3241            ]
3242        )
3243
3244
3245class NamedTupleDataset(Dataset):
3246    from collections import namedtuple
3247
3248    Batch = namedtuple("Batch", ["data", "label", "random_tensor"])
3249    Data = namedtuple("Data", ["positive", "negative"])
3250
3251    def __len__(self):
3252        return 4
3253
3254    def __getitem__(self, ndx):
3255        return self.Batch(
3256            data=self.Data(positive=ndx, negative=-ndx),
3257            label=str(ndx),
3258            random_tensor=torch.randn(3),
3259        )
3260
3261
3262@unittest.skipIf(
3263    TEST_WITH_TSAN,
3264    "Fails with TSAN with the following error: starting new threads after multi-threaded "
3265    "fork is not supported. Dying (set die_after_fork=0 to override)",
3266)
3267class TestNamedTupleDataLoader(TestCase):
3268    def setUp(self):
3269        super().setUp()
3270        self.dataset = NamedTupleDataset()
3271
3272    def test_dataloader_with_namedtuple(self):
3273        # auto-collation
3274        loader = DataLoader(self.dataset, batch_size=2, pin_memory=TEST_CUDA)
3275        for batch in loader:
3276            self.assertIsInstance(batch, NamedTupleDataset.Batch)
3277            self.assertEqual(batch.random_tensor.is_pinned(), TEST_CUDA)
3278            self.assertIsInstance(batch.data, NamedTupleDataset.Data)
3279            self.assertIsInstance(batch.data.positive, torch.Tensor)
3280            self.assertEqual(batch.data.positive.is_pinned(), TEST_CUDA)
3281        # no auto-collation
3282        loader = DataLoader(self.dataset, batch_size=None, pin_memory=TEST_CUDA)
3283        for batch in loader:
3284            self.assertIsInstance(batch, NamedTupleDataset.Batch)
3285            self.assertEqual(batch.random_tensor.is_pinned(), TEST_CUDA)
3286            self.assertIsInstance(batch.data, NamedTupleDataset.Data)
3287            self.assertNotIsInstance(batch.data.positive, torch.Tensor)
3288
3289
3290class SimpleCustomBatch:
3291    def __init__(self, data):
3292        transposed_data = list(zip(*data))
3293        self.inp = torch.stack(transposed_data[0], 0)
3294        self.tgt = torch.stack(transposed_data[1], 0)
3295
3296    def pin_memory(self):
3297        self.inp = self.inp.pin_memory()
3298        self.tgt = self.tgt.pin_memory()
3299        return self
3300
3301    def is_pinned(self):
3302        return self.inp.is_pinned() and self.tgt.is_pinned()
3303
3304
3305# Workaround for https://github.com/pytorch/pytorch/issues/50661
3306# Classes from  `__main__` can not be correctly unpickled from spawned module
3307# See https://docs.python.org/3/library/multiprocessing.html#multiprocessing-programming
3308self_module = __import__(os.path.splitext(os.path.basename(__file__))[0])
3309
3310
3311def collate_wrapper(batch):
3312    return self_module.SimpleCustomBatch(batch)
3313
3314
3315def collate_into_packed_sequence(batch):
3316    data = torch.stack([sample[0] for sample in batch], 1)
3317    t, b = data.size()
3318    lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
3319    return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, enforce_sorted=False)
3320
3321
3322def collate_into_packed_sequence_batch_first(batch):
3323    data = torch.stack([sample[0] for sample in batch], 0)
3324    b, t = data.size()
3325    lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
3326    return torch.nn.utils.rnn.pack_padded_sequence(
3327        data, lengths, batch_first=True, enforce_sorted=False
3328    )
3329
3330
3331@unittest.skipIf(
3332    TEST_WITH_TSAN,
3333    "Fails with TSAN with the following error: starting new threads after multi-threaded "
3334    "fork is not supported. Dying (set die_after_fork=0 to override)",
3335)
3336class TestCustomPinFn(TestCase):
3337    def setUp(self):
3338        super().setUp()
3339        inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
3340        tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
3341        self.dataset = TensorDataset(inps, tgts)
3342
3343    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
3344    def test_custom_batch_pin(self):
3345        test_cases = [
3346            (collate_wrapper, self_module.SimpleCustomBatch),
3347            (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
3348            (
3349                collate_into_packed_sequence_batch_first,
3350                torch.nn.utils.rnn.PackedSequence,
3351            ),
3352        ]
3353        for collate_fn, elem_cls in test_cases:
3354            loader = DataLoader(
3355                self.dataset, batch_size=2, collate_fn=collate_fn, pin_memory=True
3356            )
3357            for sample in loader:
3358                self.assertIsInstance(sample, elem_cls)
3359                self.assertTrue(sample.is_pinned())
3360
3361    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
3362    def test_custom_batch_pin_worker(self):
3363        test_cases = [
3364            (collate_wrapper, self_module.SimpleCustomBatch),
3365            (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
3366            (
3367                collate_into_packed_sequence_batch_first,
3368                torch.nn.utils.rnn.PackedSequence,
3369            ),
3370        ]
3371        for collate_fn, elem_cls in test_cases:
3372            loader = DataLoader(
3373                self.dataset,
3374                batch_size=2,
3375                collate_fn=collate_fn,
3376                pin_memory=True,
3377                num_workers=1,
3378            )
3379            for sample in loader:
3380                self.assertIsInstance(sample, elem_cls)
3381                self.assertTrue(sample.is_pinned())
3382
3383
3384class TestWorkerQueueDataset(Dataset):
3385    def __init__(self, data):
3386        self.data = data
3387        self.worker_id = None
3388
3389    def worker_init_fn(self, worker_id):
3390        self.worker_id = worker_id
3391
3392    def __getitem__(self, item):
3393        return self.worker_id, self.data[item]
3394
3395    def __len__(self):
3396        return len(self.data)
3397
3398
3399@unittest.skipIf(
3400    TEST_WITH_TSAN,
3401    "Fails with TSAN with the following error: starting new threads after multi-threaded "
3402    "fork is not supported. Dying (set die_after_fork=0 to override)",
3403)
3404@unittest.skipIf(
3405    TEST_WITH_ASAN,
3406    "Flaky with ASAN, see https://github.com/pytorch/pytorch/issues/65727",
3407)
3408class TestIndividualWorkerQueue(TestCase):
3409    def setUp(self):
3410        super().setUp()
3411        self.dataset = TestWorkerQueueDataset(list(range(128)))
3412
3413    def _run_ind_worker_queue_test(self, batch_size, num_workers):
3414        loader = DataLoader(
3415            self.dataset,
3416            batch_size=batch_size,
3417            shuffle=False,
3418            num_workers=num_workers,
3419            timeout=5,
3420            worker_init_fn=self.dataset.worker_init_fn,
3421        )
3422        current_worker_idx = 0
3423        for i, (worker_ids, sample) in enumerate(loader):
3424            self.assertEqual(worker_ids.tolist(), [current_worker_idx] * batch_size)
3425            self.assertEqual(
3426                sample.tolist(), list(range(i * batch_size, (i + 1) * batch_size))
3427            )
3428            current_worker_idx += 1
3429            if current_worker_idx == num_workers:
3430                current_worker_idx = 0
3431
3432    def test_ind_worker_queue(self):
3433        max_num_workers = None
3434        if hasattr(os, "sched_getaffinity"):
3435            try:
3436                max_num_workers = len(os.sched_getaffinity(0))
3437            except Exception:
3438                pass
3439        if max_num_workers is None:
3440            cpu_count = os.cpu_count()
3441            if cpu_count is not None:
3442                # Use half number of CPUs
3443                max_num_workers = cpu_count // 2
3444
3445        if max_num_workers is None:
3446            max_num_workers = 1
3447
3448        for batch_size in (8, 16, 32, 64):
3449            for num_workers in range(0, min(6, max_num_workers)):
3450                self._run_ind_worker_queue_test(
3451                    batch_size=batch_size, num_workers=num_workers + 1
3452                )
3453
3454
3455class SetAffinityDataset(IterableDataset):
3456    def __iter__(self):
3457        torch.randperm(1)
3458        after = os.sched_getaffinity(0)
3459        return iter(after)
3460
3461
3462@unittest.skipIf(
3463    not hasattr(os, "sched_setaffinity"), "os.sched_setaffinity is not available"
3464)
3465class TestSetAffinity(TestCase):
3466    def test_set_affinity_in_worker_init(self):
3467        # Query the current affinity mask to avoid setting a disallowed one
3468        old_affinity = os.sched_getaffinity(0)
3469        if not old_affinity:
3470            self.skipTest("No affinity information")
3471        # Choose any
3472        expected_affinity = list(old_affinity)[-1]
3473
3474        def worker_set_affinity(_):
3475            os.sched_setaffinity(0, [expected_affinity])
3476
3477        dataset = SetAffinityDataset()
3478
3479        dataloader = torch.utils.data.DataLoader(
3480            dataset, num_workers=2, worker_init_fn=worker_set_affinity
3481        )
3482        for sample in dataloader:
3483            self.assertEqual(sample, [expected_affinity])
3484
3485
3486class ConvDataset(Dataset):
3487    def __init__(self) -> None:
3488        self.x = torch.ones(1, 1, 24000)
3489        # Call convolution on parent process
3490        self[0]
3491
3492    def __len__(self):
3493        return 1
3494
3495    def __getitem__(self, index):
3496        return torch.nn.functional.conv1d(self.x, torch.ones(1, 1, 2))
3497
3498
3499@unittest.skipIf(IS_WINDOWS, "Needs fork")
3500@unittest.skipIf(
3501    TEST_WITH_ASAN,
3502    "This test hangs when running with ASAN, see https://github.com/pytorch/pytorch/issues/75492",
3503)
3504class TestConvAfterFork(TestCase):
3505    # Tests crash reported in https://github.com/pytorch/pytorch/issues/53565
3506    def test_conv_after_fork(self):
3507        loader = DataLoader(ConvDataset(), num_workers=1)
3508        for x in loader:
3509            self.assertEqual(x.shape, (1, 1, 1, 23999))
3510
3511
3512instantiate_device_type_tests(TestDataLoaderDeviceType, globals())
3513
3514
3515if __name__ == "__main__":
3516    run_tests()
3517