# Owner(s): ["module: dataloader"]

import ctypes
import errno
import faulthandler
import functools
import gc
import itertools
import math
import operator
import os
import signal
import sys
import tempfile
import time
import unittest
import warnings

import torch
import torch.utils.data.datapipes as dp
from torch import multiprocessing as mp
from torch._utils import ExceptionWrapper
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import (
    IS_CI,
    IS_JETSON,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
    load_tests,
    NO_MULTIPROCESSING_SPAWN,
    parametrize,
    run_tests,
    skipIfNoDill,
    skipIfRocm,
    slowTest,
    TEST_CUDA,
    TEST_NUMPY,
    TEST_WITH_ASAN,
    TEST_WITH_ROCM,
    TEST_WITH_TSAN,
    TestCase,
)
from torch.utils.data import (
    _utils,
    ChainDataset,
    ConcatDataset,
    DataLoader,
    Dataset,
    IterableDataset,
    IterDataPipe,
    StackDataset,
    Subset,
    TensorDataset,
)
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data.datapipes.iter import IterableWrapper
from torch.utils.data.dataset import random_split


try:
    import psutil

    HAS_PSUTIL = True
except ModuleNotFoundError:
    HAS_PSUTIL = False
    psutil = None
    err_msg = (
        "psutil not found. Some critical data loader tests relying on it "
        "(e.g., TestDataLoader.test_proper_exit) will not run."
    )
    if IS_CI:
        raise ModuleNotFoundError(err_msg) from None
    else:
        warnings.warn(err_msg)


try:
    import numpy as np

    HAS_NUMPY = True
except ModuleNotFoundError:
    HAS_NUMPY = False
    np = None
skipIfNoNumpy = unittest.skipIf(not HAS_NUMPY, "no NumPy")

# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

TEST_CUDA_IPC = (
    torch.cuda.is_available()
    and sys.platform != "darwin"
    and sys.platform != "win32"
    and not IS_JETSON
    and not TEST_WITH_ROCM
)  # https://github.com/pytorch/pytorch/issues/90940

TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1

if not NO_MULTIPROCESSING_SPAWN:
    # We want to use `spawn` if able because some of our tests check that the
    # data loader terminates gracefully. To prevent hanging in the testing
    # process, such data loaders are run in a separate subprocess.
    #
    # We also want to test the `pin_memory=True` configuration; `spawn` is
    # required to launch such processes because they initialize the CUDA
    # context.
    #
    # Mixing different start methods is a recipe for disaster (e.g., using a
    # fork `mp.Event` with a spawn `mp.Process` segfaults), so we set this
    # globally to avoid bugs.
    #
    # Get a multiprocessing context, because some tests / third-party libraries
    # set the start_method when imported, and setting it again triggers a
    # `RuntimeError`.
    mp = mp.get_context(method="spawn")


# 60s of timeout?
# Yes, in environments where physical CPU resources are shared, e.g., CI, the
# time for inter-process communication can vary greatly. With a 15~17s
# timeout, we have observed flakiness in some CI builds (see
# pytorch/pytorch#14501, pytorch/pytorch#16608). We follow the CPython
# multiprocessing setup and set the timeout to 60s here:
#
# https://github.com/python/cpython/blob/e8113f51a8bdf33188ee30a1c038a298329e7bfa/Lib/test/_test_multiprocessing.py#L73
JOIN_TIMEOUT = 60.0  # seconds


supported_multiprocessing_contexts = [None] + list(
    torch.multiprocessing.get_all_start_methods()
)


# collate_fn that returns the batch cloned; defined globally here for pickle purposes.
def _clone_collate(b):
    return [x.clone() for x in b]
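

# A minimal sketch (illustrative only; `_example_pickle_safe_collate` is not
# used by any test) of why `_clone_collate` is defined at module scope: with a
# "spawn" start method, the DataLoader pickles `collate_fn` to ship it to
# worker processes, and module-level functions are picklable while locally
# defined lambdas are not.
def _example_pickle_safe_collate():
    dataset = [torch.randn(2) for _ in range(4)]  # a list works as a map-style dataset
    # `_clone_collate` could be sent to spawned workers; a `lambda b: b` defined
    # inside this function would fail to pickle once `num_workers > 0` under "spawn".
    loader = DataLoader(dataset, batch_size=2, collate_fn=_clone_collate)
    return [batch for batch in loader]  # each batch is a list of cloned tensors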


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestDatasetRandomSplit(TestCase):
    def test_lengths_must_equal_dataset_size(self):
        with self.assertRaises(ValueError):
            random_split([1, 2, 3, 4], [1, 2])

    def test_splits_have_correct_size(self):
        splits = random_split([1, 2, 3, 4, 5, 6], [2, 4])
        self.assertEqual(len(splits), 2)
        self.assertEqual(len(splits[0]), 2)
        self.assertEqual(len(splits[1]), 4)

        splits = random_split([1, 2, 3, 4, 5, 6], [0.5, 0.5])
        self.assertEqual(len(splits), 2)
        self.assertEqual(len(splits[0]), 3)
        self.assertEqual(len(splits[1]), 3)

        # Odd size splits
        self.assertEqual(
            len(
                random_split(
                    range(3), [0.5, 0.5], generator=torch.Generator().manual_seed(1)
                )
            ),
            2,
        )

        # Odd sized round-robin splits
        splits = random_split(
            range(106), [0.1, 0.2, 0.3, 0.4], generator=torch.Generator().manual_seed(1)
        )
        self.assertEqual(len(splits[0]), 11)
        self.assertEqual(len(splits[1]), 22)
        self.assertEqual(len(splits[2]), 31)
        self.assertEqual(len(splits[3]), 42)

    def test_splits_are_mutually_exclusive(self):
        data = [5, 2, 3, 4, 1, 6]
        splits = random_split(data, [2, 4])
        all_values = []
        all_values.extend(list(splits[0]))
        all_values.extend(list(splits[1]))
        data.sort()
        all_values.sort()
        self.assertListEqual(data, all_values)

        splits = random_split(data, [0.33, 0.67])
        all_values = []
        all_values.extend(list(splits[0]))
        all_values.extend(list(splits[1]))
        data.sort()
        all_values.sort()
        self.assertListEqual(data, all_values)

        data = [1, 2, 3, 4]
        splits = random_split(data, [0.25, 0.75])
        all_values = []
        all_values.extend(list(splits[0]))
        all_values.extend(list(splits[1]))
        data.sort()
        all_values.sort()
        self.assertListEqual(data, all_values)

    def test_splits_indexing_type(self):
        r"""Indices generated by random_split
        should be of integer type
        """

        class CustomDataset:
            def __init__(self, test_object, custom_list):
                self.data = custom_list
                self.test_object = test_object

            def __getitem__(self, key):
                self.test_object.assertEqual(type(key), int)
                return self.data[key]

            def __len__(self):
                return len(self.data)

        x = [1, 2, 3, 4, 5]
        dataset = CustomDataset(self, x)
        dataset = random_split(dataset, [5])[0]
        data_loader = DataLoader(dataset)
        for batch in data_loader:
            pass

        # fractional splitting
        dataset = CustomDataset(self, x)
        dataset = random_split(dataset, [1.0])[0]
        data_loader = DataLoader(dataset)
        for batch in data_loader:
            pass

    def test_splits_reproducibility(self):
        self.assertEqual(
            [
                list(x)
                for x in random_split(
                    range(10), [3, 7], generator=torch.Generator().manual_seed(1)
                )
            ],
            [[5, 6, 1], [2, 0, 8, 9, 3, 7, 4]],
        )
        self.assertEqual(
            random_split(
                range(100), [60, 40], generator=torch.Generator().manual_seed(42)
            ),
            random_split(
                range(100), [60, 40], generator=torch.Generator().manual_seed(42)
            ),
        )
        self.assertEqual(
            random_split(
                range(100), [0.5, 0.5], generator=torch.Generator().manual_seed(42)
            ),
            random_split(
                range(100), [0.5, 0.5], generator=torch.Generator().manual_seed(42)
            ),
        )
        self.assertEqual(
            random_split(
                range(100),
                [0.33, 0.33, 0.34],
                generator=torch.Generator().manual_seed(42),
            ),
            random_split(
                range(100),
                [0.33, 0.33, 0.34],
                generator=torch.Generator().manual_seed(42),
            ),
        )

    def test_incomplete_fractional_splits(self):
        with self.assertRaises(ValueError):
            # should raise since the sum of fractions is not 1
            random_split([1, 2, 3, 4], [0.1])

        with self.assertRaises(ValueError):
            # should raise since fraction > 1
            random_split([1, 2, 3, 4], [1.1])

    def test_splits_generator(self):
        # A random_split without a specific generator should affect the default one
        state = torch.get_rng_state()
        a = torch.rand(10)
        torch.set_rng_state(state)
        random_split(range(10), [5, 5])
        b = torch.rand(10)
        self.assertNotEqual(a, b)

        # A random_split with a specific generator should not affect the default one
        state = torch.get_rng_state()
        a = torch.rand(10)
        torch.set_rng_state(state)
        random_split(range(10), [5, 5], generator=torch.Generator().manual_seed(42))
        b = torch.rand(10)
        self.assertEqual(a, b)

    def test_slicing_of_subset_of_dataset(self):
        # Testing slicing a subset initialized with a dataset
        dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5]))
        subset_of_dataset = Subset(dataset, [0, 1, 2, 3, 4])
        self.assertEqual(subset_of_dataset[:], dataset[:])
        self.assertEqual(subset_of_dataset[1:2], dataset[1:2])
        self.assertEqual(subset_of_dataset[0:-1:2], dataset[0:-1:2])
        # Testing slicing of subset from random split
        subset1, subset2 = random_split(dataset, [3, 2])
        self.assertEqual(subset1[:], dataset[subset1.indices[:]])
        self.assertEqual(subset1[0:2], dataset[subset1.indices[0:2]])
        self.assertEqual(subset1[0:-1:2], dataset[subset1.indices[0:-1:2]])

    def test_slicing_of_subset_of_subset(self):
        # Testing slicing a subset initialized with a subset
        dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5]))
        subset_of_dataset = Subset(dataset, [0, 1, 2, 3, 4])
        subset_of_subset = Subset(subset_of_dataset, [0, 1, 2, 3, 4])
        self.assertEqual(subset_of_subset[:], dataset[:])
        self.assertEqual(subset_of_subset[0:2], dataset[0:2])
        self.assertEqual(subset_of_subset[0:-1:2], dataset[0:-1:2])
        # Testing slicing of subset of subset from random split
        subset1, subset2 = random_split(dataset, [4, 1])
        subset_of_subset1, subset_of_subset2 = random_split(subset1, [3, 1])
        idx = [subset1.indices[i] for i in subset_of_subset1.indices]
        self.assertEqual(subset_of_subset1[:], dataset[idx.copy()])
        self.assertEqual(subset_of_subset1[0:2], dataset[idx[0:2]])
        self.assertEqual(subset_of_subset1[0:-1:2], dataset[idx[0:-1:2]])
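

# A worked example (illustrative only; `_example_fractional_split_lengths` is
# not used by any test) of the "odd sized round-robin splits" case asserted in
# test_splits_have_correct_size above: for random_split(range(106),
# [0.1, 0.2, 0.3, 0.4]), the floored sizes are 10 + 21 + 31 + 42 = 104, and the
# 2 leftover items are handed out round-robin to the first splits, giving
# lengths 11, 22, 31, 42.
def _example_fractional_split_lengths():
    n, fractions = 106, [0.1, 0.2, 0.3, 0.4]
    sizes = [int(math.floor(n * f)) for f in fractions]
    for i in range(n - sum(sizes)):  # distribute the remainder round-robin
        sizes[i % len(sizes)] += 1
    assert sizes == [11, 22, 31, 42]
    return sizes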


class CUDACountingDataset(Dataset):
    def __init__(self, n):
        super().__init__()
        self.n = n

    def __getitem__(self, i):
        return torch.as_tensor(i, device="cuda")

    def __len__(self):
        return self.n


class CountingDataset(Dataset):
    def __init__(self, n):
        super().__init__()
        self.n = n

    def __getitem__(self, i):
        return i

    def __len__(self):
        return self.n


class CountingIterableDataset(IterableDataset):
    def __init__(self, n):
        super().__init__()
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

    def __len__(self):
        return self.n


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestTensorDataset(TestCase):
    def test_len(self):
        source = TensorDataset(torch.randn(15, 10, 2, 3, 4, 5), torch.randperm(15))
        self.assertEqual(len(source), 15)

    def test_getitem(self):
        t = torch.randn(15, 10, 2, 3, 4, 5)
        l = torch.randn(15, 10)
        source = TensorDataset(t, l)
        for i in range(15):
            self.assertEqual(t[i], source[i][0])
            self.assertEqual(l[i], source[i][1])

    def test_getitem_1d(self):
        t = torch.randn(15)
        l = torch.randn(15)
        source = TensorDataset(t, l)
        for i in range(15):
            self.assertEqual(t[i], source[i][0])
            self.assertEqual(l[i], source[i][1])

    def test_single_tensor(self):
        t = torch.randn(5, 10)
        source = TensorDataset(t)
        self.assertEqual(len(source), 5)
        for i in range(5):
            self.assertEqual(t[i], source[i][0])

    def test_many_tensors(self):
        t0 = torch.randn(5, 10, 2, 3, 4, 5)
        t1 = torch.randn(5, 10)
        t2 = torch.randn(5, 10, 2, 5)
        t3 = torch.randn(5, 10, 3, 7)
        source = TensorDataset(t0, t1, t2, t3)
        self.assertEqual(len(source), 5)
        for i in range(5):
            self.assertEqual(t0[i], source[i][0])
            self.assertEqual(t1[i], source[i][1])
            self.assertEqual(t2[i], source[i][2])
            self.assertEqual(t3[i], source[i][3])


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestStackDataset(TestCase):
    def test_empty(self):
        with self.assertRaisesRegex(
            ValueError, "At least one dataset should be passed"
        ):
            StackDataset()

    def test_mixed(self):
        with self.assertRaisesRegex(ValueError, "Supported either"):
            StackDataset(
                TensorDataset(torch.randn(15, 10)), a=TensorDataset(torch.randn(10, 15))
            )

    def test_size_mismatch(self):
        with self.assertRaisesRegex(ValueError, "Size mismatch between datasets"):
            StackDataset(
                TensorDataset(torch.randn(15, 10)), TensorDataset(torch.randn(10, 15))
            )
        with self.assertRaisesRegex(ValueError, "Size mismatch between datasets"):
            StackDataset(
                a=TensorDataset(torch.randn(15, 10)),
                b=TensorDataset(torch.randn(10, 15)),
            )

    def test_len(self):
        source = StackDataset(
            TensorDataset(torch.randn(15, 10)), TensorDataset(torch.randn(15))
        )
        self.assertEqual(len(source), 15)
        source = StackDataset(TensorDataset(torch.randn(15, 10)))
        self.assertEqual(len(source), 15)
        source = StackDataset(
            a=TensorDataset(torch.randn(15, 10)), b=TensorDataset(torch.randn(15))
        )
        self.assertEqual(len(source), 15)
        source = StackDataset(a=TensorDataset(torch.randn(15, 10)))
        self.assertEqual(len(source), 15)

    def test_single(self):
        t = TensorDataset(torch.randn(15, 10))
        source = StackDataset(t)
        for i in range(15):
            self.assertEqual(t[i], source[i][0])
        source = StackDataset(a=t)
        for i in range(15):
            self.assertEqual(t[i], source[i]["a"])

    def test_getitem(self):
        t = TensorDataset(torch.randn(15, 10))
        l = TensorDataset(torch.randn(15, 5, 4))
        source = StackDataset(t, l)
        for i in range(15):
            self.assertEqual(t[i], source[i][0])
            self.assertEqual(l[i], source[i][1])
        source = StackDataset(a=t, b=l)
        for i in range(15):
            self.assertEqual(t[i], source[i]["a"])
            self.assertEqual(l[i], source[i]["b"])

    def test_getitems(self):
        class GetItemsDataset(Dataset):
            def __init__(self) -> None:
                self.data = torch.randn(4)

            def __getitem__(self, item):
                return self.data[item]

            def __getitems__(self, items):
                return self.data[items]

            def __len__(self):
                return 4

        t = GetItemsDataset()
        l = [1, 2, 3, 4]

        source = StackDataset(t, l)
        batch = source.__getitems__([0, 1, 2, 3])
        for i in range(4):
            self.assertEqual(t[i], batch[i][0])
            self.assertEqual(l[i], batch[i][1])

        source = StackDataset(t=t, l=l)
        batch = source.__getitems__([0, 1, 2, 3])
        for i in range(4):
            self.assertEqual(t[i], batch[i]["t"])
            self.assertEqual(l[i], batch[i]["l"])

    def test_getitems_raises_index_error(self):
        class GetItemsDataset(Dataset):
            def __init__(self) -> None:
                self.data = torch.randn(4)

            def __getitem__(self, item):
                return self.data[item]

            def __getitems__(self, items):
                return self.data[items]

            def __len__(self):
                return 4

        t = GetItemsDataset()
        l = [1, 2, 3, 4]

        source = StackDataset(t, l)

        with self.assertRaises(IndexError):
            source.__getitems__([0, 4])

    def test_getitems_value_error(self):
        class GetItemsDataset(Dataset):
            def __init__(self) -> None:
                self.data = torch.randn(4)

            def __getitem__(self, item):
                return self.data[item]

            def __getitems__(self, items):
                return self.data[items][:-1]  # return less

            def __len__(self):
                return 4

        t = GetItemsDataset()
        l = [1, 2, 3, 4]

        source = StackDataset(t, l)

        with self.assertRaisesRegex(
            ValueError, "Nested dataset's output size mismatch. Expected 4, got 3"
        ):
            source.__getitems__([0, 1, 2, 3])
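

# A minimal sketch (illustrative only; `_ExampleVectorizedDataset` is not used
# by any test) of the optional `__getitems__` protocol exercised above: when a
# map-style dataset defines it, the DataLoader's fetcher retrieves a whole
# batch of indices in one call instead of looping over `__getitem__`, e.g.
# DataLoader(_ExampleVectorizedDataset(), batch_size=4) would issue a single
# `__getitems__([0, 1, 2, 3])` call per batch.
class _ExampleVectorizedDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(8)

    def __getitem__(self, idx):
        return self.data[idx]

    def __getitems__(self, indices):
        # one vectorized read for the entire batch of indices
        return list(self.data[indices])

    def __len__(self):
        return len(self.data)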


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestConcatDataset(TestCase):
    def test_concat_two_singletons(self):
        result = ConcatDataset([[0], [1]])
        self.assertEqual(2, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(1, result[1])

    def test_concat_two_non_singletons(self):
        result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])

    def test_concat_two_non_singletons_with_empty(self):
        # Adding an empty dataset somewhere is correctly handled
        result = ConcatDataset([[0, 1, 2, 3, 4], [], [5, 6, 7, 8, 9]])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])

    def test_concat_raises_index_error(self):
        result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
        with self.assertRaises(IndexError):
            # this one goes to 11
            result[11]

    def test_add_dataset(self):
        d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        d2 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        d3 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        result = d1 + d2 + d3
        self.assertEqual(21, len(result))
        self.assertEqual(0, (d1[0][0] - result[0][0]).abs().sum())
        self.assertEqual(0, (d2[0][0] - result[7][0]).abs().sum())
        self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum())

    def test_iterable_dataset_err(self):
        d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        it1 = CountingIterableDataset(5)
        it2 = CountingIterableDataset(10)

        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
            ConcatDataset([d1, it2, it1])

        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
            ConcatDataset([it2])

        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
            ConcatDataset([it1, d1])


# takes in dummy var so this can also be used as a `worker_init_fn`
def set_faulthander_if_available(_=None):
    faulthandler.enable(sys.__stderr__)
    if not IS_WINDOWS:
        # windows does not have faulthandler.register
        # chain=False prevents the default behavior of killing the process
        faulthandler.register(signal.SIGUSR1, file=sys.__stderr__, chain=False)


set_faulthander_if_available()


# Process `pid` must have called `set_faulthander_if_available`
def print_traces_of_all_threads(pid):
    if not IS_WINDOWS:
        # use the custom signal if available
        os.kill(pid, signal.SIGUSR1)
    else:
        # otherwise we can still use the handler given by faulthandler.enable()
        # at the cost of killing the process.
        os.kill(pid, signal.SIGSEGV)

    # wait in parent process to give subprocess some time to print
    time.sleep(5)


# The following `ErrorTrackingProcess` stores the first encountered exception in
# its `.exception` attribute.
# Inspired by https://stackoverflow.com/a/33599967
class ErrorTrackingProcess(mp.Process):
    # Why no *args?
    #   py2 doesn't support def fn(x, *args, key=val, **kwargs)
    # Setting disable_stderr=False may generate a lot of unrelated error outputs
    # but could be helpful for debugging.
    def __init__(self, disable_stderr=True, **kwargs):
        super().__init__(**kwargs)
        self._pconn, self._cconn = mp.Pipe()
        self._exception = None
        self.disable_stderr = disable_stderr

    def run(self):
        set_faulthander_if_available()
        if self.disable_stderr:
            # Disable polluting stderr with errors that are supposed to happen.
            with open(os.devnull, "w") as devnull:
                os.dup2(devnull.fileno(), sys.stderr.fileno())
        try:
            super().run()
            self._cconn.send(None)
        except Exception:
            self._cconn.send(ExceptionWrapper(sys.exc_info()))
            raise

    def print_traces_of_all_threads(self):
        assert (
            self.is_alive()
        ), "can only use print_traces_of_all_threads if the process is alive"
        assert (
            not self.disable_stderr
        ), "do not disable stderr if you use print_traces_of_all_threads"
        # On platforms without `SIGUSR1`, `set_faulthander_if_available` sets
        # `faulthandler.enable()`, and `print_traces_of_all_threads` may kill
        # the process. So let's poll the exception first
        _ = self.exception
        print_traces_of_all_threads(self.pid)

    @property
    def exception(self):
        if self._pconn.poll():
            self._exception = self._pconn.recv()
        if self._exception is None:
            return None
        else:
            return self._exception.exc_type(self._exception.exc_msg)

    # ESRCH means that os.kill couldn't find a process with that pid, i.e.,
    # it is no longer alive.
    def send_signal(self, signum, ignore_ESRCH=False):
        try:
            os.kill(self.pid, signum)
        except OSError as e:
            if not ignore_ESRCH or e.errno != errno.ESRCH:
                raise


# Has no __getitem__; fetching an item falls through to the base
# Dataset.__getitem__, which raises NotImplementedError (exercised by
# TestDataLoader._test_error).
class ErrorDataset(Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size


class SegfaultDataset(Dataset):
    def __init__(self, size):
        self.size = size

    def __getitem__(self, idx):
        return ctypes.string_at(0)

    def __len__(self):
        return self.size


class SleepDataset(Dataset):
    def __init__(self, size, sleep_sec):
        self.size = size
        self.sleep_sec = sleep_sec
        self.sleeped = False

    def __getitem__(self, idx):
        if not self.sleeped:
            time.sleep(self.sleep_sec)
            self.sleeped = True
        return idx

    def __len__(self):
        return self.size


class SeedDataset(Dataset):
    def __init__(self, size):
        self.size = size

    def __getitem__(self, idx):
        return torch.initial_seed()

    def __len__(self):
        return self.size


class WorkerSpecificIterableDataset(IterableDataset):
    def __init__(self, sizes_for_all_workers):
        self.sizes_for_all_workers = sizes_for_all_workers

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        assert worker_info is not None
        return iter(range(self.sizes_for_all_workers[worker_info.id]))

    def __len__(self):
        return sum(self.sizes_for_all_workers)
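

# A minimal sketch (illustrative only; `_ExampleShardedRange` is not used by
# any test) of the standard sharding pattern that WorkerSpecificIterableDataset
# side-steps by taking explicit per-worker sizes: each worker strides over the
# index range by `num_workers`, so the workers' streams jointly cover the
# dataset exactly once.
class _ExampleShardedRange(IterableDataset):
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process loading: yield everything
            return iter(range(self.n))
        return iter(range(worker_info.id, self.n, worker_info.num_workers))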


# Inspired by https://stackoverflow.com/a/26703365
# If all workers call `sync_once`, they block until every worker has reached
# the call (i.e., it acts like a barrier).
# This can be used to ensure that each worker processes at least one sample.
class SynchronizedDataset(Dataset):
    def __init__(self, size, batch_size, num_workers):
        assert size >= num_workers * batch_size
        self.count = mp.Value("i", 0, lock=True)
        self.barrier = mp.Semaphore(0)
        self.num_workers = num_workers
        self.size = size

    def sync_once(self):
        with self.count.get_lock():
            self.count.value += 1
            if self.count.value == self.num_workers:
                self.barrier.release()
        self.barrier.acquire()
        self.barrier.release()

    def __getitem__(self, idx):
        raise NotImplementedError

    def __len__(self):
        return self.size


class EmptyTensorDataset(torch.utils.data.Dataset):
    def __init__(self, len):
        self.len = len

    def __len__(self):
        return self.len

    def __getitem__(self, any):
        return torch.empty(0)


class SynchronizedSeedDataset(SynchronizedDataset):
    def __getitem__(self, idx):
        self.sync_once()
        return torch.initial_seed()
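

# A minimal sketch (illustrative only; `_example_turnstile_barrier` is not used
# by any test) of the "turnstile" idiom `sync_once` implements, shown here with
# threads instead of worker processes: the last arrival releases the semaphore
# once, and every waiter re-releases it on the way out, so all participants
# pass only after the last one has arrived.
def _example_turnstile_barrier(num_threads=3):
    import threading

    count = [0]
    lock = threading.Lock()
    turnstile = threading.Semaphore(0)
    arrived = []

    def work(i):
        with lock:
            count[0] += 1
            if count[0] == num_threads:
                turnstile.release()  # last arrival opens the turnstile
        turnstile.acquire()
        turnstile.release()  # let the next waiter through
        arrived.append(i)

    threads = [threading.Thread(target=work, args=(i,)) for i in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    assert len(arrived) == num_threads
    return arrived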


def _test_timeout(persistent_workers):
    dataset = SleepDataset(10, 3)
    dataloader = DataLoader(
        dataset,
        batch_size=2,
        num_workers=2,
        timeout=1,
        persistent_workers=persistent_workers,
    )
    _ = next(iter(dataloader))


def _test_timeout_pin_memory(persistent_workers):
    dataset = SleepDataset(10, 3)
    dataloader = DataLoader(
        dataset,
        batch_size=2,
        num_workers=2,
        timeout=1,
        pin_memory=True,
        persistent_workers=persistent_workers,
    )
    _ = next(iter(dataloader))


def _test_large_sampler_indices(persistent_workers):
    # See
    #   test_large_sampler_indices
    #   https://github.com/pytorch/pytorch/issues/48666

    dataloader = torch.utils.data.DataLoader(
        EmptyTensorDataset(10000000),
        batch_size=40960,
        persistent_workers=persistent_workers,
        num_workers=1,
    )

    it = iter(dataloader)

    for x in it:
        assert x.numel() == 0
        raise RuntimeError("My Error")


def disable_stderr(worker_id):
    r"""
    Avoids printing "ERROR: Unexpected segmentation fault encountered in worker."
    from workers. Since the worker signal handler prints with low-level write(),
    this has to be done at the OS level via dup.

    This is used as worker_init_fn for test_segfault.
    """
    sys.stderr.flush()  # flush library buffers that dup2 knows nothing about
    # A with-block is fine here: dup2 duplicates devnull's fd onto stderr's fd,
    # so the redirection survives devnull being closed when the block exits.
    with open(os.devnull, "w") as devnull:
        os.dup2(devnull.fileno(), sys.stderr.fileno())


def _test_segfault():
    dataset = SegfaultDataset(10)
    dataloader = DataLoader(
        dataset, batch_size=2, num_workers=2, worker_init_fn=disable_stderr
    )
    _ = next(iter(dataloader))


def _test_no_segfault():
    dataset = [1, 2, 3]
    num_threads = torch.get_num_threads()
    if num_threads < 4:
        torch.set_num_threads(4)
    else:
        torch.set_num_threads(num_threads)
    mp_ctx = torch.multiprocessing.get_context(method="fork")
    dataloader = DataLoader(
        dataset,
        num_workers=1,
        worker_init_fn=disable_stderr,
        multiprocessing_context=mp_ctx,
    )
    _ = next(iter(dataloader))


class TestProperExitDataset(Dataset):
    def __init__(self, size, error_event):
        self.size = size
        self.error_event = error_event

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        worker_info = torch.utils.data.get_worker_info()
        if (
            self.error_event is not None
            and self.error_event.is_set()
            and worker_info.id == worker_info.num_workers - 1
        ):
            # only error in the last worker
            raise RuntimeError("Worker error")
        return torch.tensor([idx])


class TestProperExitIterableDataset(IterableDataset):
    def __init__(self, size, error_event):
        self.error_event = error_event
        self.size = size
        self.remaining = size

    def __len__(self):
        return self.size

    def __iter__(self):
        return self

    def __next__(self):
        worker_info = torch.utils.data.get_worker_info()
        if (
            self.error_event is not None
            and self.error_event.is_set()
            and worker_info.id == worker_info.num_workers - 1
        ):
            # only error in the last worker
            raise RuntimeError("Worker error")
        self.remaining -= 1
        if self.remaining < 0:
            raise StopIteration
        return torch.tensor(-1000)


# See TestDataLoader.test_proper_exit for usage
def _test_proper_exit(
    is_iterable_dataset,
    use_workers,
    pin_memory,
    exit_method,
    hold_iter_reference,
    loader_setup_event,
    tester_setup_event,
    persistent_workers,
):
    num_workers = 2 if use_workers else 0

    if exit_method == "worker_error" or exit_method == "worker_kill":
        assert use_workers is True

    if exit_method == "worker_error":
        worker_error_event = mp.Event()
    else:
        worker_error_event = None

    if is_iterable_dataset:
        ds = TestProperExitIterableDataset(7, worker_error_event)
    else:
        ds = TestProperExitDataset(12, worker_error_event)

    loader = DataLoader(
        ds,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        worker_init_fn=set_faulthander_if_available,
        persistent_workers=persistent_workers,
    )

    error_it = 2

    if use_workers:
        # 2 is the magical per-worker prefetch number...
        # FIXME: change this after the number becomes configurable.
        if is_iterable_dataset:
            assert len(ds) * num_workers > (error_it + 2 + 1)
        else:
            assert len(loader) > (error_it + 2 + 1) * num_workers
    else:
        if is_iterable_dataset:
            assert len(ds) > error_it + 1
        else:
            assert len(loader) > error_it + 1

    it = iter(loader)
    if use_workers:
        workers = it._workers

    def kill_pid(pid):
        psutil_p = psutil.Process(pid)
        psutil_p.kill()
        psutil_p.wait(JOIN_TIMEOUT)
        assert not psutil_p.is_running()

    for i, _ in enumerate(it):
        if i == 0:
            if not hold_iter_reference:
                del it
                del loader
            loader_setup_event.set()
            tester_setup_event.wait()
            # ensure that the workers are still alive
            if use_workers:
                for w in workers:
                    assert w.is_alive()
            if worker_error_event is not None:
                worker_error_event.set()

        if i == error_it:
            if exit_method == "loader_error":
                raise RuntimeError("Loader error")
            elif exit_method == "loader_kill":
                kill_pid(os.getpid())
            elif exit_method == "worker_kill":
                kill_pid(workers[-1].pid)  # kill last worker

    if not hold_iter_reference:
        # Tries to trigger the __del__ clean-up rather than the automatic
        # exiting of daemonic children. Technically it should be automatically
        # triggered, but I don't want to rely on the implementation detail of
        # Python gc.
        gc.collect()


class TestWorkerInfoDataset(SynchronizedDataset):
    def __getitem__(self, idx):
        self.sync_once()
        return torch.tensor(self.value)


# Should be used as worker_init_fn with TestWorkerInfoDataset.
# See _test_get_worker_info below for usage.
def _test_worker_info_init_fn(worker_id):
    worker_info = torch.utils.data.get_worker_info()
    assert (
        worker_id == worker_info.id
    ), "worker_init_fn and worker_info should have consistent id"
    assert (
        worker_id < worker_info.num_workers
    ), "worker_init_fn and worker_info should have valid id"
    assert (
        worker_info.seed == torch.initial_seed()
    ), "worker_init_fn and worker_info should have consistent seed"
    dataset = worker_info.dataset
    assert isinstance(
        dataset, TestWorkerInfoDataset
    ), "worker_info should have correct dataset copy"
    assert not hasattr(dataset, "value"), "worker_info should have correct dataset copy"
    # test that WorkerInfo attributes are read-only
    try:
        worker_info.id = 3999
    except RuntimeError as e:
        assert str(e) == "Cannot assign attributes to WorkerInfo objects"
    try:
        worker_info.a = 3
    except RuntimeError as e:
        assert str(e) == "Cannot assign attributes to WorkerInfo objects"
    for k in ["id", "num_workers", "seed", "dataset"]:
        assert f"{k}=" in repr(worker_info)
    dataset.value = [worker_id, os.getpid()]


def _test_get_worker_info():
    # get_worker_info returns None in main proc
    assert torch.utils.data.get_worker_info() is None
    num_workers = 2
    batch_size = 2
    dataset = TestWorkerInfoDataset(6, batch_size, num_workers)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        worker_init_fn=_test_worker_info_init_fn,
    )
    it = iter(dataloader)
    data = []
    for d in it:
        data.append(d)  # noqa: PERF402
    worker_pids = [w.pid for w in it._workers]
    data = torch.cat(data, 0)
    for d in data:
        # each `d` is a [worker_id, worker_pid] pair, which is set in
        # _test_worker_info_init_fn
        assert d[1] == worker_pids[d[0]]
    # get_worker_info returns None in main proc after data loading
    assert torch.utils.data.get_worker_info() is None
    # main proc dataset was never assigned this attribute
    assert not hasattr(dataset, "value")
    try:
        _ = dataset[0]
    except AttributeError:
        return
    raise RuntimeError("Expected AttributeError")


# test custom init function
def init_fn(worker_id):
    torch.manual_seed(12345)


# used with test_error_in_init
class ErrorIterableDataset(IterableDataset):
    def __iter__(self):
        raise RuntimeError("Error in __iter__")


# used with test_error_in_init
def error_worker_init_fn(_):
    raise RuntimeError("Error in worker_init_fn")


class BulkLoadingDataset(Dataset):
    def __init__(self, length):
        self.length = length

    def __getitem__(self, indices):
        assert isinstance(indices, (list, tuple))
        return torch.as_tensor(indices)

    def __len__(self):
        return self.length


class BulkLoadingSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        for x in torch.randperm(len(self.dataset)).split(self.batch_size):
            yield x.tolist()

    def __len__(self):
        return int(math.ceil(len(self.dataset) / float(self.batch_size)))


class TestMultiEpochDataset(IterableDataset):
    def __init__(self, length):
        self.length = length

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        assert worker_info is not None
        worker_id = worker_info.id
        for idx in range(self.length // worker_info.num_workers):
            yield worker_id

    def __len__(self):
        return self.length


class CustomList(list):
    pass


class CustomDict(dict):
    pass


def row_processor(row):
    return np.add(row, 1)


def filter_len(row):
    return len(row) == 4
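

# A minimal sketch (illustrative only; `_example_datapipe_chain` is not used by
# any test, and it requires NumPy since `row_processor` uses np.add) of the
# map/filter DataPipe chain that _test_multiprocessing_iterdatapipe below
# builds, evaluated in-process: [1, 2, 3, 4] maps to [2, 3, 4, 5] and survives
# the length-4 filter, while the 6-element row is dropped.
def _example_datapipe_chain():
    datapipe = IterableWrapper([[1, 2, 3, 4], [1, 2, 3, 4, 5, 6]])
    datapipe = datapipe.map(row_processor).filter(filter_len)
    return list(datapipe)  # a single surviving row: array([2, 3, 4, 5])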


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
@unittest.skipIf(
    TEST_WITH_ASAN,
    "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223",
)
class TestDataLoader(TestCase):
    def setUp(self):
        super().setUp()
        self.data = torch.randn(100, 2, 3, 5)
        self.labels = torch.randperm(50).repeat(2)
        self.dataset = TensorDataset(self.data, self.labels)
        self.persistent_workers = False

    def _get_data_loader(self, dataset, **kwargs):
        persistent_workers = kwargs.get("persistent_workers", self.persistent_workers)
        if persistent_workers and kwargs.get("num_workers", 0) == 0:
            persistent_workers = False
        kwargs["persistent_workers"] = persistent_workers
        return DataLoader(dataset, **kwargs)

    def _test_sequential(self, loader):
        batch_size = loader.batch_size
        if batch_size is None:
            for idx, (sample, target) in enumerate(loader):
                self.assertEqual(sample, self.data[idx])
                self.assertEqual(target, self.labels[idx])
            self.assertEqual(idx, len(self.dataset) - 1)
        else:
            for i, (sample, target) in enumerate(loader):
                idx = i * batch_size
                self.assertEqual(sample, self.data[idx : idx + batch_size])
                self.assertEqual(target, self.labels[idx : idx + batch_size])
            self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

    def _test_shuffle(self, loader):
        found_data = dict.fromkeys(range(self.data.size(0)), 0)
        found_labels = dict.fromkeys(range(self.labels.size(0)), 0)
        batch_size = loader.batch_size
        if batch_size is None:
            for i, (batch_samples, batch_targets) in enumerate(loader):
                sample, target = (batch_samples, batch_targets)
                for data_point_idx, data_point in enumerate(self.data):
                    if data_point.eq(sample).all():
                        self.assertFalse(found_data[data_point_idx])
                        found_data[data_point_idx] += 1
                        break
                self.assertEqual(target, self.labels[data_point_idx])
                found_labels[data_point_idx] += 1
                self.assertEqual(sum(found_data.values()), (i + 1))
                self.assertEqual(sum(found_labels.values()), (i + 1))
            self.assertEqual(i, (len(self.dataset) - 1))
        else:
            for i, (batch_samples, batch_targets) in enumerate(loader):
                for sample, target in zip(batch_samples, batch_targets):
                    for data_point_idx, data_point in enumerate(self.data):
                        if data_point.eq(sample).all():
                            self.assertFalse(found_data[data_point_idx])
                            found_data[data_point_idx] += 1
                            break
                    self.assertEqual(target, self.labels[data_point_idx])
                    found_labels[data_point_idx] += 1
                self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
                self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
            self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

    def _test_error(self, loader):
        it = iter(loader)
        errors = 0
        while True:
            try:
                next(it)
            except NotImplementedError:
                errors += 1
            except StopIteration:
                self.assertEqual(
                    errors, math.ceil(float(len(loader.dataset)) / loader.batch_size)
                )
                return

    def test_error_in_init(self):
        for num_workers in [0, 2]:
            loader = self._get_data_loader(
                ErrorIterableDataset(), num_workers=num_workers
            )
            with self.assertRaisesRegex(RuntimeError, "Error in __iter__"):
                list(iter(loader))

        loader = self._get_data_loader(
            self.dataset, num_workers=2, worker_init_fn=error_worker_init_fn
        )
        with self.assertRaisesRegex(RuntimeError, "Error in worker_init_fn"):
            list(iter(loader))

    def test_typing(self):
        from typing import List

        # Make sure there is no TypeError

        class SomeDatasetClass(Dataset[List[torch.Tensor]]):
            pass

        def _create_dataloader(is_train: bool) -> DataLoader[List[torch.Tensor]]:
            pass

    @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
    @unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
    def test_fd_limit_exceeded(self):
        # See NOTE [ DataLoader on Linux and open files limit ]
        import subprocess

        subprocess.check_output(
            [
                sys.executable,
                "-c",
                """\
import torch
import resource
from torch.utils.data import DataLoader, IterableDataset

class RandomDataset(IterableDataset):
    def __init__(self, len, size):
        super().__init__()
        self.len = len
        self.size = size

    def __iter__(self):
        return self

    def __next__(self):
        if self.len <= 0:
            raise StopIteration
        self.len -= 1
        return torch.randn(self.size)

try:
    keep_fds_alive = []
    resource.setrlimit(resource.RLIMIT_NOFILE, (100, 100))
    for random_t in DataLoader(RandomDataset(200, (2,2)), multiprocessing_context="fork",
                               num_workers=1):
        random_t.max(dim=0)
        keep_fds_alive.append(random_t)
except RuntimeError as e:
    assert "ulimit -n" in str(e)
    assert "set_sharing_strategy" in str(e)
""",
            ]
        )

    def test_invalid_assign_after_init(self):
        dl = self._get_data_loader(self.dataset)
        for attr in ("batch_size", "sampler", "batch_sampler", "drop_last", "dataset"):

            def fn():
                setattr(dl, attr, {})

            self.assertRaises(ValueError, fn)

    def test_sequential_nonbatch(self):
        self._test_sequential(self._get_data_loader(self.dataset, batch_size=None))

    def test_sequential_batch(self):
        self._test_sequential(self._get_data_loader(self.dataset))
        self._test_sequential(self._get_data_loader(self.dataset, batch_size=2))

    def test_bulk_loading_nobatch(self):
        n = 35
        bs = 4
        ds = BulkLoadingDataset(n)
        sampler = BulkLoadingSampler(ds, batch_size=4)

        for num_workers in [0, 4]:
            dl = self._get_data_loader(
                ds,
                num_workers=num_workers,
                batch_size=None,
                sampler=sampler,
                pin_memory=TEST_CUDA,
            )
            self.assertFalse(dl._auto_collation)
            samples = list(dl)
            self.assertEqual(samples[0].is_pinned(), TEST_CUDA)
            self.assertEqual(set(torch.cat(samples, 0).tolist()), set(range(n)))

    def test_growing_dataset(self):
        dataset = [torch.ones(4) for _ in range(4)]
        dataloader_seq = self._get_data_loader(dataset, shuffle=False)
        dataloader_shuffle = self._get_data_loader(dataset, shuffle=True)
        dataset.append(torch.ones(4))
        self.assertEqual(len(dataloader_seq), 5)
        self.assertEqual(len(dataloader_shuffle), 5)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_sequential_pin_memory(self):
        loader = self._get_data_loader(self.dataset, batch_size=2, pin_memory=True)
        for input, target in loader:
            self.assertTrue(input.is_pinned())
            self.assertTrue(target.is_pinned())

    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_multiple_dataloaders(self):
        for multiprocessing_context in supported_multiprocessing_contexts:
            loader1_it = iter(self._get_data_loader(self.dataset, num_workers=1))
            loader2_it = iter(
                self._get_data_loader(
                    self.dataset,
                    num_workers=2,
                    multiprocessing_context=multiprocessing_context,
                )
            )
            next(loader1_it)
            next(loader1_it)
            next(loader2_it)
            next(loader2_it)
            next(loader1_it)
            next(loader2_it)
            del loader1_it
            del loader2_it

    @unittest.skipIf(True, "This test is disabled in pytorch/pytorch")
    def test_segfault(self):
        p = ErrorTrackingProcess(target=_test_segfault)
        p.start()
        p.join(JOIN_TIMEOUT)
        try:
            self.assertFalse(p.is_alive())
            self.assertNotEqual(p.exitcode, 0)
            if IS_WINDOWS:
                self.assertIsInstance(p.exception, OSError)
                self.assertRegex(str(p.exception), r"access violation reading ")
            else:
                self.assertIsInstance(p.exception, RuntimeError)
                self.assertRegex(
                    str(p.exception),
                    r"DataLoader worker \(pid \d+\) is killed by signal: ",
                )
        finally:
            p.terminate()

    # Tests whether the child process forked by the DataLoader segfaults due to
    # having more than 3 threads in the parent process after at least one
    # set_num_threads invocation in the parent process. After forking,
    # set_num_threads(1) in the child process entails handling some inherited
    # data structures of the Caffe2 thread-pool of the parent process,
    # culminating in a segfault.
    # Reference: https://github.com/pytorch/pytorch/issues/54752
    @unittest.skipIf(IS_WINDOWS, "Needs fork")
    def test_no_segfault(self):
        p = ErrorTrackingProcess(target=_test_no_segfault)
        p.start()
        p.join(JOIN_TIMEOUT)
        try:
            self.assertFalse(p.is_alive())
            if p.exception:
                self.assertIsInstance(p.exception, RuntimeError)
                self.assertRegex(
                    str(p.exception),
                    r"DataLoader worker \(pid \d+\) is killed by signal: ",
                )
                self.fail("Segfault occurred in worker process after fork")
        finally:
            p.terminate()

    def test_timeout(self):
        if TEST_CUDA and not NO_MULTIPROCESSING_SPAWN:
            # This test runs in a subprocess, which can only initialize CUDA with spawn.
            # _test_timeout_pin_memory with pin_memory=True initializes CUDA when the
            # iterator is constructed.
            targets = (_test_timeout, _test_timeout_pin_memory)
        else:
            targets = (_test_timeout,)
        for target in targets:
            p = ErrorTrackingProcess(target=target, args=(self.persistent_workers,))
            p.start()
            p.join(JOIN_TIMEOUT)
            try:
                self.assertFalse(p.is_alive())
                self.assertNotEqual(p.exitcode, 0)
                self.assertIsInstance(p.exception, RuntimeError)
                self.assertRegex(
                    str(p.exception), r"DataLoader timed out after \d+ seconds"
                )
            finally:
                p.terminate()

    def test_large_sampler_indices(self):
        # Test that the data loader cleanly exits when the process errors while
        #   1. it holds a reference to the iterator, and
        #   2. it uses a sampler that yields big elements s.t. `_index_queues`
        #      putters block
        #
        # More context: https://github.com/pytorch/pytorch/issues/48666

        p = ErrorTrackingProcess(
            target=_test_large_sampler_indices, args=(self.persistent_workers,)
        )
        p.start()
        p.join(JOIN_TIMEOUT)
        try:
            self.assertFalse(p.is_alive())
            self.assertNotEqual(p.exitcode, 0)
            self.assertIsInstance(p.exception, RuntimeError)
            self.assertRegex(str(p.exception), r"My Error")
        finally:
            p.terminate()

    def test_invalid_ctor_args_combinations(self):
        # general
        with self.assertRaisesRegex(
            ValueError, "num_workers option should be non-negative"
        ):
            self._get_data_loader(self.dataset, num_workers=-1)
        with self.assertRaisesRegex(
            ValueError, "timeout option should be non-negative"
        ):
            self._get_data_loader(self.dataset, timeout=-1)

        # disable auto-batching
        with self.assertRaisesRegex(
            ValueError,
            "batch_size=None option disables auto-batching and is mutually exclusive",
        ):
            self._get_data_loader(self.dataset, batch_size=None, drop_last=True)

        valid_ctx = list(torch.multiprocessing.get_all_start_methods())[-1]
        with self.assertRaisesRegex(
            ValueError, r"multi-process loading \(num_workers > 0\), but got"
        ):
            self._get_data_loader(
                self.dataset, num_workers=0, multiprocessing_context=valid_ctx
            )
        with self.assertRaisesRegex(
            ValueError, "should specify a valid start method in"
        ):
            self._get_data_loader(
                self.dataset, num_workers=1, multiprocessing_context="bad"
            )
        with self.assertRaisesRegex(
            TypeError, "multiprocessing_context option should be a valid context "
        ):
            self._get_data_loader(
                self.dataset, num_workers=1, multiprocessing_context=object()
            )

        # map-style
        sampler = torch.utils.data.SequentialSampler(self.dataset)
        batch_sampler = torch.utils.data.BatchSampler(sampler, 3, False)
        with self.assertRaisesRegex(
            ValueError, "sampler option is mutually exclusive with shuffle"
        ):
            self._get_data_loader(
                self.dataset, batch_size=11, sampler=sampler, shuffle=True
            )
        with self.assertRaisesRegex(
            ValueError, "sampler option is mutually exclusive with shuffle"
        ):
            self._get_data_loader(
                self.dataset, batch_sampler=batch_sampler, sampler=sampler, shuffle=True
            )
        with self.assertRaisesRegex(
            ValueError, "sampler option is mutually exclusive with shuffle"
        ):
            self._get_data_loader(
                self.dataset, batch_sampler=batch_sampler, sampler=sampler, shuffle=3
            )
        with self.assertRaisesRegex(
            ValueError, "batch_sampler option is mutually exclusive with"
        ):
            self._get_data_loader(
                self.dataset, batch_size=11, batch_sampler=batch_sampler
            )
        with self.assertRaisesRegex(
            ValueError, "batch_sampler option is mutually exclusive with"
        ):
            self._get_data_loader(
                self.dataset, shuffle=True, batch_sampler=batch_sampler
            )
        with self.assertRaisesRegex(
            ValueError, "batch_sampler option is mutually exclusive with"
        ):
            self._get_data_loader(
                self.dataset, drop_last=True, batch_sampler=batch_sampler
            )
        with self.assertRaisesRegex(
            ValueError, "batch_sampler option is mutually exclusive with"
        ):
            self._get_data_loader(
                self.dataset, drop_last=3, batch_sampler=batch_sampler
            )

        # iterable-style
        dataset = CountingIterableDataset(20)
        with self.assertRaisesRegex(
            ValueError, "DataLoader with IterableDataset: expected unspecified shuffle"
        ):
            self._get_data_loader(dataset, shuffle=True)
        with self.assertRaisesRegex(
            ValueError, "DataLoader with IterableDataset: expected unspecified shuffle"
        ):
            self._get_data_loader(dataset, shuffle=3)
        with self.assertRaisesRegex(
            ValueError, "DataLoader with IterableDataset: expected unspecified sampler"
        ):
            self._get_data_loader(
                dataset, sampler=torch.utils.data.SequentialSampler(dataset)
            )
        with self.assertRaisesRegex(
            ValueError, "DataLoader with IterableDataset: expected unspecified sampler"
        ):
            self._get_data_loader(dataset, sampler=3)
        with self.assertRaisesRegex(
            ValueError,
            "DataLoader with IterableDataset: expected unspecified batch_sampler",
        ):
            self._get_data_loader(
                dataset,
                batch_sampler=torch.utils.data.BatchSampler(
                    torch.utils.data.SequentialSampler(dataset), 3, False
                ),
            )
        with self.assertRaisesRegex(
            ValueError,
            "DataLoader with IterableDataset: expected unspecified batch_sampler",
        ):
            self._get_data_loader(dataset, batch_sampler=3)

    def test_builtin_collection_conversion(self):
        for coll_ty in (list, tuple):
            for num_workers in (0, 1):
                # map-style dataset
                dataset = CountingDataset(20)
                # no auto-batching
                fetched = coll_ty(
                    self._get_data_loader(
                        dataset, batch_size=None, num_workers=num_workers
                    )
                )
                self.assertEqual(fetched, coll_ty(range(20)))
                # auto-batching
                fetched = coll_ty(
                    self._get_data_loader(
                        dataset, batch_size=2, num_workers=num_workers
                    )
                )
                self.assertEqual(
                    fetched, coll_ty(torch.tensor([i, i + 1]) for i in range(0, 20, 2))
                )

                # iterable-style dataset
                dataset = CountingIterableDataset(20)
                # no auto-batching
                fetched = coll_ty(
                    self._get_data_loader(
                        dataset, batch_size=None, num_workers=num_workers
                    )
                )
                self.assertEqual(fetched, coll_ty(range(20)))
                # auto-batching
                # this IterableDataset isn't configured for each worker, so for
                # the equality test below to be valid, we cannot have more than
                # 1 worker.
                assert num_workers in [0, 1], "invalid test"
                fetched = coll_ty(
                    self._get_data_loader(
                        dataset, batch_size=2, num_workers=num_workers
                    )
                )
                self.assertEqual(
                    fetched, coll_ty(torch.tensor([i, i + 1]) for i in range(0, 20, 2))
                )

    def test_iterable_style_dataset(self):
        # [no auto-batching] single process loading
        dataset = CountingIterableDataset(20)
        dataloader = self._get_data_loader(dataset, batch_size=None)
        fetched = list(dataloader)
        self.assertEqual(len(fetched), 20)
        for i, d in enumerate(fetched):
            # non-batched should not convert ints into tensors
            self.assertIsInstance(d, int)
            self.assertEqual(d, i)
        # DataLoader should match len of the iterable-style dataset (if implemented)
        self.assertEqual(len(dataloader), len(dataset))

        # [no auto-batching] multiprocessing loading
        num_workers = 3
        sizes_for_all_workers = [0, 4, 20]
        expected = sorted(
            functools.reduce(
                operator.iadd, (list(range(s)) for s in sizes_for_all_workers), []
            )
        )
        assert len(sizes_for_all_workers) == num_workers, "invalid test case"
        for prefetch_factor in [2, 3, 4]:
            dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
            dataloader = self._get_data_loader(
                dataset,
                num_workers=num_workers,
                batch_size=None,
                worker_init_fn=set_faulthander_if_available,
                prefetch_factor=prefetch_factor,
            )
            dataloader_iter = iter(dataloader)
            fetched = sorted(dataloader_iter)
            for a, b in zip(fetched, expected):
                # non-batched should not convert ints into tensors
                self.assertIsInstance(a, int)
                self.assertEqual(a, b)
            # DataLoader should match len of the iterable-style dataset (if implemented)
            self.assertEqual(len(dataloader), len(dataset))
            # When loading more than len(dataset) data, after accessing len(dataloader),
            # we should get a warning. See NOTE [ IterableDataset and __len__ ].
            dataset = CountingIterableDataset(20)
            dataloader = self._get_data_loader(
                dataset,
                num_workers=num_workers,
                worker_init_fn=set_faulthander_if_available,
                prefetch_factor=prefetch_factor,
            )
            it = iter(dataloader)
            for _ in range(40):
                self.assertNotWarn(
                    lambda: next(it), "Should not warn before accessing len(dataloader)"
                )
            self.assertEqual(len(dataloader), len(dataset))
            self.assertEqual(len(dataloader), 20)
            it = iter(dataloader)
            for _ in range(20):
                self.assertNotWarn(
                    lambda: next(it), "Should not warn before exceeding length"
                )
            for _ in range(3):
                with self.assertWarnsRegex(
                    UserWarning,
                    r"but [0-9]+ samples have been fetched\. For multiprocessing data-loading, this",
                    msg="Should always warn after exceeding length",
                ):
                    next(it)
            # [no auto-batching] test that workers exit gracefully
            workers = dataloader_iter._workers
            del dataloader_iter
            del dataloader
            try:
                for w in workers:
                    w.join(JOIN_TIMEOUT)
                    self.assertFalse(w.is_alive())
                    self.assertEqual(w.exitcode, 0)
            finally:
                for w in workers:
                    w.terminate()

        # [auto-batching] single process loading
        dataset = CountingIterableDataset(20)
        fetched = list(self._get_data_loader(dataset, batch_size=7))
        self.assertEqual(len(fetched), 3)
        self.assertEqual(fetched[0].tolist(), list(range(7)))
        self.assertEqual(fetched[1].tolist(), list(range(7, 14)))
        self.assertEqual(fetched[2].tolist(), list(range(14, 20)))

        # [auto-batching] multiprocessing loading
        num_workers = 3
        sizes_for_all_workers = [0, 4, 20]
        expected = sorted(
            functools.reduce(
                operator.iadd, (list(range(s)) for s in sizes_for_all_workers), []
            )
        )
        assert len(sizes_for_all_workers) == num_workers, "invalid test case"
        for prefetch_factor in [2, 3, 4]:
            dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
            # worker 0 should return 0 batches
            # worker 1 should return 1 batch
            # worker 2 should return 3 batches
            dataloader = self._get_data_loader(
                dataset,
                num_workers=num_workers,
                batch_size=7,
                prefetch_factor=prefetch_factor,
            )
            dataloader_iter = iter(dataloader)
            fetched = list(dataloader_iter)
            self.assertEqual(len(fetched), 4)
            fetched = {tuple(t.tolist()) for t in fetched}
            self.assertEqual(
                fetched,
                {
                    tuple(range(4)),
                    tuple(range(7)),
                    tuple(range(7, 14)),
                    tuple(range(14, 20)),
                },
            )

            # [auto-batching] test that workers exit gracefully
            workers = dataloader_iter._workers
            del dataloader_iter
            del dataloader
            try:
                for w in workers:
                    w.join(JOIN_TIMEOUT)
                    self.assertFalse(w.is_alive())
                    self.assertEqual(w.exitcode, 0)
            finally:
                for w in workers:
                    w.terminate()
        # [auto-batching & drop_last] single process loading
        dataset = CountingIterableDataset(20)
        fetched = list(self._get_data_loader(dataset, batch_size=7, drop_last=True))
        self.assertEqual(len(fetched), 2)
        self.assertEqual(fetched[0].tolist(), list(range(7)))
        self.assertEqual(fetched[1].tolist(), list(range(7, 14)))

        # [auto-batching & drop_last] multiprocessing loading
        num_workers = 3
        sizes_for_all_workers = [0, 4, 20]
        expected = sorted(
            functools.reduce(
                operator.iadd, (list(range(s)) for s in sizes_for_all_workers), []
            )
        )
        assert len(sizes_for_all_workers) == num_workers, "invalid test case"
        for prefetch_factor in [2, 3, 4]:
            dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
            # worker 0 should return 0 batches
            # worker 1 should return 1 batch
            # worker 2 should return 3 batches
            dataloader = self._get_data_loader(
                dataset,
                num_workers=num_workers,
                batch_size=7,
                drop_last=True,
                worker_init_fn=set_faulthander_if_available,
                prefetch_factor=prefetch_factor,
            )
            dataloader_iter = iter(dataloader)
            fetched = list(dataloader_iter)
            self.assertEqual(len(fetched), 2)
            fetched = {tuple(t.tolist()) for t in fetched}
            self.assertEqual(fetched, {tuple(range(7)), tuple(range(7, 14))})
1801 1802 # [auto-batching & drop_last] test that workers exit gracefully 1803 workers = dataloader_iter._workers 1804 del dataloader_iter 1805 del dataloader 1806 try: 1807 for w in workers: 1808 w.join(JOIN_TIMEOUT) 1809 self.assertFalse(w.is_alive()) 1810 self.assertEqual(w.exitcode, 0) 1811 finally: 1812 for w in workers: 1813 w.terminate() 1814 1815 def test_chain_iterable_style_dataset(self): 1816 # chaining (concatenation) 1817 dataset1 = CountingIterableDataset(20) 1818 dataset2 = CountingIterableDataset(15) 1819 expected = list(range(20)) + list(range(15)) 1820 for num_workers in [0, 1]: 1821 for chained_dataset in [ 1822 dataset1 + dataset2, 1823 ChainDataset([dataset1, dataset2]), 1824 ]: 1825 fetched = list( 1826 self._get_data_loader(chained_dataset, num_workers=num_workers) 1827 ) 1828 self.assertEqual(len(fetched), len(expected)) 1829 for e, d in zip(expected, fetched): 1830 self.assertIsInstance(d, torch.Tensor) 1831 self.assertEqual(e, d) 1832 1833 with self.assertRaisesRegex( 1834 AssertionError, "ChainDataset only supports IterableDataset" 1835 ): 1836 list(iter(dataset1 + self.dataset)) 1837 1838 with self.assertRaisesRegex( 1839 AssertionError, "ChainDataset only supports IterableDataset" 1840 ): 1841 list(iter(ChainDataset([dataset1, self.dataset]))) 1842 1843 @unittest.skipIf(IS_MACOS, "Not working on macos") 1844 @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 1845 @skipIfRocm # https://github.com/pytorch/pytorch/issues/90940 1846 def test_multiprocessing_contexts(self): 1847 reference = [ 1848 torch.arange(3), 1849 torch.arange(3, 6), 1850 torch.arange(6, 9), 1851 torch.arange(9, 11), 1852 ] 1853 counting_ds_n = 11 1854 dl_common_args = dict(num_workers=3, batch_size=3, pin_memory=(not TEST_CUDA)) 1855 for ctx in supported_multiprocessing_contexts: 1856 # windows and jetson devices don't support sharing cuda tensor; ROCm does not yet fully support IPC 1857 if ( 1858 ctx in ["spawn", "forkserver"] 1859 and TEST_CUDA 1860 and not IS_WINDOWS 1861 and not IS_JETSON 1862 ): 1863 ds_cls = CUDACountingDataset 1864 else: 1865 ds_cls = CountingDataset 1866 self.assertEqual( 1867 reference, 1868 list( 1869 self._get_data_loader( 1870 ds_cls(counting_ds_n), 1871 multiprocessing_context=ctx, 1872 **dl_common_args, 1873 ) 1874 ), 1875 ) 1876 if ctx is not None: 1877 # test ctx object 1878 ctx = mp.get_context(ctx) 1879 self.assertEqual( 1880 reference, 1881 list( 1882 self._get_data_loader( 1883 ds_cls(counting_ds_n), 1884 multiprocessing_context=ctx, 1885 **dl_common_args, 1886 ) 1887 ), 1888 ) 1889 1890 def _test_multiprocessing_iterdatapipe(self, with_dill): 1891 # Testing to make sure that function from global scope (e.g. 
imported from library) can be serialized 1892 # and used with multiprocess DataLoader 1893 1894 reference = [ 1895 torch.as_tensor([[2, 3, 4, 5]], dtype=torch.int64), 1896 torch.as_tensor([[2, 3, 4, 5]], dtype=torch.int64), 1897 ] 1898 datapipe: IterDataPipe = IterableWrapper([[1, 2, 3, 4], [1, 2, 3, 4, 5, 6]]) 1899 datapipe = datapipe.map(row_processor) 1900 datapipe = ( 1901 datapipe.filter(lambda row: len(row) == 4) 1902 if with_dill 1903 else datapipe.filter(filter_len) 1904 ) 1905 1906 dl_common_args = dict( 1907 num_workers=2, batch_size=2, shuffle=True, pin_memory=(not TEST_CUDA) 1908 ) 1909 for ctx in supported_multiprocessing_contexts: 1910 self.assertEqual( 1911 reference, 1912 [ 1913 t.type(torch.int64) 1914 for t in self._get_data_loader( 1915 datapipe, multiprocessing_context=ctx, **dl_common_args 1916 ) 1917 ], 1918 ) 1919 if ctx is not None: 1920 # test ctx object 1921 ctx = mp.get_context(ctx) 1922 self.assertEqual( 1923 reference, 1924 [ 1925 t.type(torch.int64) 1926 for t in self._get_data_loader( 1927 datapipe, multiprocessing_context=ctx, **dl_common_args 1928 ) 1929 ], 1930 ) 1931 1932 @skipIfNoNumpy 1933 @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 1934 def test_multiprocessing_iterdatapipe(self): 1935 self._test_multiprocessing_iterdatapipe(with_dill=False) 1936 1937 @unittest.expectedFailure 1938 @skipIfNoNumpy 1939 @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 1940 @skipIfNoDill 1941 def test_multiprocessing_iterdatapipe_with_dill(self): 1942 self._test_multiprocessing_iterdatapipe(with_dill=True) 1943 1944 def test_worker_seed(self): 1945 num_workers = 6 1946 batch_size = 1 1947 dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers) 1948 dataloader = self._get_data_loader( 1949 dataset, batch_size=batch_size, num_workers=num_workers 1950 ) 1951 seeds = set() 1952 seeds.update(batch[0] for batch in dataloader) 1953 self.assertEqual(len(seeds), num_workers) 1954 1955 def test_worker_seed_reproducibility(self): 1956 def get_dataloader(): 1957 return DataLoader( 1958 dataset, 1959 batch_size=batch_size, 1960 num_workers=num_workers, 1961 generator=torch.Generator().manual_seed(42), 1962 ) 1963 1964 num_workers = 6 1965 batch_size = 1 1966 dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers) 1967 self.assertEqual( 1968 {int(batch) for batch in get_dataloader()}, 1969 {int(batch) for batch in get_dataloader()}, 1970 ) 1971 1972 def test_multi_epochs_reproducibility(self): 1973 num_workers = 2 1974 batch_size = 10 1975 num_epochs = 3 1976 1977 dataset = TestMultiEpochDataset(batch_size * num_workers) 1978 dataloader = self._get_data_loader( 1979 dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers 1980 ) 1981 1982 for ind in range(num_epochs): 1983 for batch_idx, sample in enumerate(dataloader): 1984 self.assertEqual( 1985 sample.tolist(), [batch_idx % num_workers] * batch_size 1986 ) 1987 1988 def test_worker_init_fn(self): 1989 dataset = SeedDataset(4) 1990 dataloader = self._get_data_loader( 1991 dataset, batch_size=2, num_workers=2, worker_init_fn=init_fn 1992 ) 1993 for batch in dataloader: 1994 self.assertEqual(12345, batch[0]) 1995 self.assertEqual(12345, batch[1]) 1996 1997 def test_get_worker_info(self): 1998 p = ErrorTrackingProcess(target=_test_get_worker_info) 1999 p.start() 2000 p.join(JOIN_TIMEOUT) 2001 try: 2002 self.assertFalse(p.is_alive()) 2003 self.assertEqual(p.exitcode, 0) 2004 finally: 2005 p.terminate() 2006 2007 def test_shuffle(self): 2008 
self._test_shuffle(self._get_data_loader(self.dataset, shuffle=True)) 2009 2010 def test_shuffle_batch_none(self): 2011 self._test_shuffle(DataLoader(self.dataset, batch_size=None, shuffle=True)) 2012 2013 def test_shuffle_batch(self): 2014 self._test_shuffle( 2015 self._get_data_loader(self.dataset, batch_size=2, shuffle=True) 2016 ) 2017 2018 def test_shuffle_reproducibility(self): 2019 for fn in ( 2020 lambda: DataLoader( 2021 self.dataset, 2022 shuffle=True, 2023 num_workers=0, 2024 generator=torch.Generator().manual_seed(42), 2025 ), 2026 lambda: DataLoader( 2027 self.dataset, 2028 shuffle=True, 2029 num_workers=2, 2030 generator=torch.Generator().manual_seed(42), 2031 ), 2032 ): 2033 self.assertEqual(list(fn()), list(fn())) 2034 2035 def test_sequential_workers(self): 2036 self._test_sequential(self._get_data_loader(self.dataset, num_workers=4)) 2037 2038 def test_sequential_batch_workers(self): 2039 self._test_sequential( 2040 self._get_data_loader(self.dataset, batch_size=2, num_workers=4) 2041 ) 2042 2043 def test_sequential_batch_workers_prefetch(self): 2044 self._test_sequential( 2045 DataLoader(self.dataset, batch_size=2, num_workers=4, prefetch_factor=3) 2046 ) 2047 2048 def test_shuffle_workers(self): 2049 self._test_shuffle( 2050 self._get_data_loader(self.dataset, shuffle=True, num_workers=4) 2051 ) 2052 2053 def test_shuffle_batch_workers(self): 2054 self._test_shuffle( 2055 self._get_data_loader( 2056 self.dataset, batch_size=2, shuffle=True, num_workers=4 2057 ) 2058 ) 2059 2060 def test_shuffle_batch_workers_prefetch(self): 2061 self._test_shuffle( 2062 DataLoader( 2063 self.dataset, 2064 batch_size=2, 2065 shuffle=True, 2066 num_workers=4, 2067 prefetch_factor=3, 2068 ) 2069 ) 2070 2071 def test_random_sampler(self): 2072 from collections import Counter 2073 2074 from torch.utils.data import RandomSampler 2075 2076 def sample_stat(sampler, num_samples): 2077 counts = Counter(sampler) 2078 count_repeated = sum(val > 1 for val in counts.values()) 2079 return ( 2080 count_repeated, 2081 min(counts.keys()), 2082 max(counts.keys()), 2083 sum(counts.values()), 2084 ) 2085 2086 # test sample with replacement 2087 n = len(self.dataset) + 1 # ensure at least one sample is drawn more than once 2088 sampler_with_replacement = RandomSampler( 2089 self.dataset, replacement=True, num_samples=n 2090 ) 2091 count_repeated, minval, maxval, count_total = sample_stat( 2092 sampler_with_replacement, n 2093 ) 2094 self.assertTrue(count_repeated > 0) 2095 self.assertTrue(minval >= 0) 2096 self.assertTrue(maxval < len(self.dataset)) 2097 self.assertTrue(count_total == n) 2098 2099 # test sample without replacement and without specified num_samples 2100 sampler_without_replacement = RandomSampler(self.dataset) 2101 count_repeated, minval, maxval, count_total = sample_stat( 2102 sampler_without_replacement, len(self.dataset) 2103 ) 2104 self.assertTrue(count_repeated == 0) 2105 self.assertTrue(minval == 0) 2106 self.assertTrue(maxval == len(self.dataset) - 1) 2107 self.assertTrue(count_total == len(self.dataset)) 2108 2109 # test sample without replacement and with specified num_samples 2110 n = len(self.dataset) * 2 2111 sampler_without_replacement = RandomSampler(self.dataset, num_samples=n) 2112 count_repeated, minval, maxval, count_total = sample_stat( 2113 sampler_without_replacement, len(self.dataset) 2114 ) 2115 self.assertTrue(count_repeated == len(self.dataset)) 2116 self.assertTrue(minval == 0) 2117 self.assertTrue(maxval == len(self.dataset) - 1) 2118 self.assertTrue(count_total == n)
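        # Note on the `num_samples > len(dataset)` case above: with
        # `replacement=False`, the sampler keeps drawing fresh permutations of the
        # index range until `num_samples` indices have been produced, so with
        # n == 2 * len(dataset) every index is repeated exactly twice. A rough
        # sketch of that behavior (illustrative only, not the actual implementation):
        #
        #   >>> import torch
        #   >>> def oversample(n, num_samples, g=None):
        #   ...     out = []
        #   ...     while len(out) < num_samples:
        #   ...         out += torch.randperm(n, generator=g).tolist()
        #   ...     return out[:num_samples]
        #   >>> sorted(oversample(4, 8))
        #   [0, 0, 1, 1, 2, 2, 3, 3]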
2119 2120 n = len(self.dataset) - 1 2121 sampler_without_replacement = RandomSampler(self.dataset, num_samples=n) 2122 count_repeated, minval, maxval, count_total = sample_stat( 2123 sampler_without_replacement, len(self.dataset) 2124 ) 2125 self.assertTrue(count_repeated == 0) 2126 self.assertTrue(minval >= 0) 2127 self.assertTrue(maxval < len(self.dataset)) 2128 self.assertTrue(count_total == n) 2129 2130 n = len(self.dataset) + 1 2131 sampler_without_replacement = RandomSampler(self.dataset, num_samples=n) 2132 count_repeated, minval, maxval, count_total = sample_stat( 2133 sampler_without_replacement, len(self.dataset) 2134 ) 2135 self.assertTrue(count_repeated == 1) 2136 self.assertTrue(minval == 0) 2137 self.assertTrue(maxval == len(self.dataset) - 1) 2138 self.assertTrue(count_total == n) 2139 2140 # raise error when replacement is non-boolean 2141 with self.assertRaisesRegex( 2142 TypeError, "replacement should be a boolean value, but got replacement=0" 2143 ): 2144 RandomSampler(self.dataset, replacement=0) 2145 2146 def test_random_sampler_len_with_replacement(self): 2147 from torch.utils.data import RandomSampler 2148 2149 # add 5 extra samples 2150 num_samples = len(self.dataset) + 5 2151 sampler = RandomSampler(self.dataset, replacement=True, num_samples=num_samples) 2152 # test len method 2153 self.assertEqual(num_samples, len(sampler)) 2154 2155 # test with iteration 2156 count_num_samples = sum(1 for _ in sampler) 2157 self.assertEqual(num_samples, count_num_samples) 2158 2159 # test with dataloader, batch_size = 1 2160 batch_size = 1 2161 count_num_samples_in_data_loader = len( 2162 self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler) 2163 ) 2164 self.assertEqual(num_samples, count_num_samples_in_data_loader) 2165 2166 # test with dataloader, batch_size = 6 2167 batch_size = 6 2168 count_num_samples_in_data_loader = len( 2169 self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler) 2170 ) 2171 self.assertEqual( 2172 int(math.ceil(float(num_samples) / batch_size)), 2173 count_num_samples_in_data_loader, 2174 ) 2175 2176 def test_random_sampler_len_without_replacement(self): 2177 from torch.utils.data import RandomSampler 2178 2179 # add 5 extra samples 2180 num_samples = len(self.dataset) + 5 2181 sampler = RandomSampler( 2182 self.dataset, replacement=False, num_samples=num_samples 2183 ) 2184 # test len method 2185 self.assertEqual(num_samples, len(sampler)) 2186 2187 # test with iteration 2188 count_num_samples = sum(1 for _ in sampler) 2189 self.assertEqual(num_samples, count_num_samples) 2190 2191 # test with dataloader, batch_size = 1 2192 batch_size = 1 2193 count_num_samples_in_data_loader = len( 2194 self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler) 2195 ) 2196 self.assertEqual(num_samples, count_num_samples_in_data_loader) 2197 2198 # test with dataloader, batch_size = 6 2199 batch_size = 6 2200 count_num_samples_in_data_loader = len( 2201 self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler) 2202 ) 2203 self.assertEqual( 2204 num_samples // batch_size + (num_samples % batch_size > 0), 2205 count_num_samples_in_data_loader, 2206 ) 2207 2208 def test_distributed_sampler_invalid_rank(self): 2209 from torch.utils.data.distributed import DistributedSampler 2210 2211 dataset = torch.IntTensor(range(10)) 2212 with self.assertRaisesRegex(ValueError, "Invalid rank"): 2213 sampler = DistributedSampler(dataset, 3, 3) 2214 2215 with self.assertRaisesRegex(ValueError, "Invalid rank"): 
sampler = DistributedSampler(dataset, 3, -1) 2217 2218 def test_duplicating_data_with_drop_last(self): 2219 from torch.utils.data.distributed import DistributedSampler 2220 2221 num_processes = 4 2222 num_batches = 9 2223 data_set = torch.IntTensor(range(num_batches)) 2224 scanned_data = torch.IntTensor([]) 2225 for i in range(num_processes): 2226 s = DistributedSampler(data_set, num_processes, i) 2227 d_loader = self._get_data_loader( 2228 data_set, 2229 batch_size=int(num_batches / num_processes), 2230 drop_last=True, 2231 sampler=s, 2232 ) 2233 for data in d_loader: 2234 scanned_data = torch.cat((scanned_data, data), 0) 2235 2236 self.assertEqual(scanned_data.size(), scanned_data.unique().size()) 2237 2238 def test_sampler_reproducibility(self): 2239 from torch.utils.data import ( 2240 RandomSampler, 2241 SubsetRandomSampler, 2242 WeightedRandomSampler, 2243 ) 2244 2245 weights = [0.1, 0.9, 0.4, 0.7, 3.0, 0.6] 2246 for fn in ( 2247 lambda: RandomSampler( 2248 self.dataset, 2249 num_samples=5, 2250 replacement=True, 2251 generator=torch.Generator().manual_seed(42), 2252 ), 2253 lambda: RandomSampler( 2254 self.dataset, 2255 replacement=False, 2256 generator=torch.Generator().manual_seed(42), 2257 ), 2258 lambda: WeightedRandomSampler( 2259 weights, 2260 num_samples=5, 2261 replacement=True, 2262 generator=torch.Generator().manual_seed(42), 2263 ), 2264 lambda: WeightedRandomSampler( 2265 weights, 2266 num_samples=5, 2267 replacement=False, 2268 generator=torch.Generator().manual_seed(42), 2269 ), 2270 lambda: SubsetRandomSampler( 2271 range(10), generator=torch.Generator().manual_seed(42) 2272 ), 2273 ): 2274 self.assertEqual(list(fn()), list(fn())) 2275 2276 for sampler in ( 2277 RandomSampler(self.dataset, num_samples=5, replacement=True), 2278 RandomSampler(self.dataset, replacement=False), 2279 WeightedRandomSampler(weights, num_samples=5, replacement=True), 2280 WeightedRandomSampler(weights, num_samples=5, replacement=False), 2281 SubsetRandomSampler(range(10)), 2282 ): 2283 torch.manual_seed(0) 2284 l1 = list(sampler) + list(sampler) 2285 2286 torch.manual_seed(0) 2287 l2 = list(sampler) + list(sampler) 2288 self.assertEqual(l1, l2) 2289 2290 its = (iter(sampler), iter(sampler)) 2291 ls = ([], []) 2292 for idx in range(len(sampler)): 2293 for i in range(2): 2294 if idx == 0: 2295 torch.manual_seed(0) 2296 ls[i].append(next(its[i])) 2297 self.assertEqual(ls[0], ls[1]) 2298 2299 def _test_sampler(self, **kwargs): 2300 indices = range(2, 12) # using a regular iterable 2301 dl = self._get_data_loader( 2302 self.dataset, sampler=indices, batch_size=2, **kwargs 2303 ) 2304 self.assertEqual(len(dl), 5) 2305 for i, (input, _target) in enumerate(dl): 2306 self.assertEqual(len(input), 2) 2307 self.assertEqual(input, self.data[i * 2 + 2 : i * 2 + 4]) 2308 2309 def test_sampler(self): 2310 self._test_sampler() 2311 self._test_sampler(num_workers=4) 2312 if not NO_MULTIPROCESSING_SPAWN: 2313 self._test_sampler(num_workers=4, multiprocessing_context="spawn") 2314 2315 def _test_batch_sampler(self, **kwargs): 2316 # [(0, 1), (2, 3, 4), (5, 6), (7, 8, 9), ...]
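        # The loop below builds that pattern: for each 5-index chunk it emits one
        # 2-element batch and one 3-element batch, so 20 indices yield 8 batches.
        # Illustrative doctest of the construction:
        #
        #   >>> b = []
        #   >>> for i in range(0, 20, 5):
        #   ...     b.append(tuple(range(i, i + 2)))
        #   ...     b.append(tuple(range(i + 2, i + 5)))
        #   >>> b[:4]
        #   [(0, 1), (2, 3, 4), (5, 6), (7, 8, 9)]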
batches = [] # using a regular iterable 2318 for i in range(0, 20, 5): 2319 batches.append(tuple(range(i, i + 2))) 2320 batches.append(tuple(range(i + 2, i + 5))) 2321 2322 dl = self._get_data_loader(self.dataset, batch_sampler=batches, **kwargs) 2323 self.assertEqual(len(dl), 8) 2324 for i, (input, _target) in enumerate(dl): 2325 if i % 2 == 0: 2326 offset = i * 5 // 2 2327 self.assertEqual(len(input), 2) 2328 self.assertEqual(input, self.data[offset : offset + 2]) 2329 else: 2330 offset = i * 5 // 2 2331 self.assertEqual(len(input), 3) 2332 self.assertEqual(input, self.data[offset : offset + 3]) 2333 2334 def test_batch_sampler(self): 2335 self._test_batch_sampler() 2336 self._test_batch_sampler(num_workers=4) 2337 if not NO_MULTIPROCESSING_SPAWN: 2338 self._test_batch_sampler(num_workers=4, multiprocessing_context="spawn") 2339 2340 @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") 2341 def test_shuffle_pin_memory(self): 2342 loader = self._get_data_loader( 2343 self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True 2344 ) 2345 for input, target in loader: 2346 self.assertTrue(input.is_pinned()) 2347 self.assertTrue(target.is_pinned()) 2348 2349 @unittest.skipIf(not TEST_NUMPY, "numpy unavailable") 2350 def test_numpy(self): 2351 import numpy as np 2352 2353 class TestDataset(torch.utils.data.Dataset): 2354 def __getitem__(self, i): 2355 return np.ones((2, 3, 4)) * i 2356 2357 def __len__(self): 2358 return 1000 2359 2360 loader = self._get_data_loader(TestDataset(), batch_size=12) 2361 batch = next(iter(loader)) 2362 self.assertIsInstance(batch, torch.DoubleTensor) 2363 self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4])) 2364 2365 @unittest.skipIf(not TEST_NUMPY, "numpy unavailable") 2366 def test_numpy_gen_state(self): 2367 from torch.utils.data._utils.worker import _generate_state 2368 2369 # Using NumPy-generated states as the reference to test that 2370 # `_generate_state` produces the same result.
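        # `_generate_state` hashes (base_seed, worker_id) into four 32-bit words,
        # similar in spirit to NumPy's SeedSequence, so each worker gets a
        # distinct, reproducible RNG state. The first case below can be checked
        # directly (illustrative doctest):
        #
        #   >>> from torch.utils.data._utils.worker import _generate_state
        #   >>> list(_generate_state(13434589827475259383, 4))
        #   [2884386318, 1088094898, 3523808998, 3860348662]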
2371 # Test case: ((worker_id, base_seed), expected_state) 2372 test_cases = [ 2373 ( 2374 (4, 13434589827475259383), 2375 (2884386318, 1088094898, 3523808998, 3860348662), 2376 ), 2377 ((1, 15014285634777110771), (1934848465, 763213760, 2959016433, 179751970)), 2378 ( 2379 (10, 978296274032934101), 2380 (1759791917, 3550927336, 1225977135, 1036538043), 2381 ), 2382 ( 2383 (12, 11868770762134256968), 2384 (3974661794, 3331131333, 3630387033, 2885815368), 2385 ), 2386 ( 2387 (9, 15378787925219019706), 2388 (3815056996, 3162224466, 2735102421, 3190253477), 2389 ), 2390 ((5, 9055612723125076328), (3522565701, 3368424109, 959377806, 621878693)), 2391 ( 2392 (15, 14617792358407278405), 2393 (3402479508, 1588702753, 1169536393, 3675067356), 2394 ), 2395 ( 2396 (9, 17363320784006640087), 2397 (957989458, 2518334477, 1421725660, 3086155459), 2398 ), 2399 ( 2400 (12, 480002904169484764), 2401 (2732851467, 1762620729, 4055801988, 1277640511), 2402 ), 2403 ( 2404 (15, 16803975943592702950), 2405 (3479415043, 4022359553, 295994005, 3358606349), 2406 ), 2407 ( 2408 (9, 11704776406047813044), 2409 (1968928009, 710113752, 2442656196, 1587420279), 2410 ), 2411 ( 2412 (10, 16357891985431864516), 2413 (1271733898, 4197047399, 3727213786, 2338547348), 2414 ), 2415 ( 2416 (2, 17423369006318065007), 2417 (544294336, 1911284083, 3299147734, 3231058347), 2418 ), 2419 ((2, 2889492011444113593), (3721591783, 2595811276, 2212881745, 977682627)), 2420 ((0, 8979703111668486195), (4276723937, 2556068849, 2962827292, 233130238)), 2421 ( 2422 (6, 6269787272229682235), 2423 (2548857855, 1216457374, 1012973562, 2999759647), 2424 ), 2425 ] 2426 2427 for (worker_id, base_seed), exp in test_cases: 2428 self.assertEqual(exp, _generate_state(base_seed, worker_id)) 2429 2430 def test_error(self): 2431 self._test_error( 2432 self._get_data_loader(ErrorDataset(100), batch_size=2, shuffle=True) 2433 ) 2434 2435 def test_error_workers(self): 2436 self._test_error( 2437 self._get_data_loader( 2438 ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4 2439 ) 2440 ) 2441 2442 @unittest.skipIf(IS_WINDOWS, "FIXME: stuck test") 2443 def test_partial_workers(self): 2444 r"""Check that workers exit even if the iterator is not exhausted.""" 2445 if TEST_CUDA: 2446 pin_memory_configs = (True, False) 2447 else: 2448 pin_memory_configs = (False,) 2449 2450 for pin_memory in pin_memory_configs: 2451 loader = iter( 2452 self._get_data_loader( 2453 self.dataset, batch_size=2, num_workers=4, pin_memory=pin_memory 2454 ) 2455 ) 2456 workers = loader._workers 2457 if pin_memory: 2458 pin_memory_thread = loader._pin_memory_thread 2459 for i, _ in enumerate(loader): 2460 if i == 10: 2461 break 2462 assert i == 10 2463 del loader 2464 for w in workers: 2465 w.join(JOIN_TIMEOUT) 2466 self.assertFalse(w.is_alive(), "subprocess not terminated") 2467 if pin_memory: 2468 pin_memory_thread.join(JOIN_TIMEOUT) 2469 self.assertFalse(pin_memory_thread.is_alive()) 2470 2471 # Takes 2.5min to finish, see https://github.com/pytorch/pytorch/issues/46065 2472 @skipIfRocm 2473 @unittest.skipIf(not HAS_PSUTIL, "psutil not found") 2474 @slowTest 2475 def test_proper_exit(self): 2476 ( 2477 r"""There might be ConnectionResetError or leaked semaphore warning """ 2478 r"""(due to dirty process exit), but they are all safe to ignore""" 2479 ) 2480 2481 # TODO: test the case where the pin_memory_thread triggers an 2482 # error/fatal signal. I haven't found out how to properly do that. 
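        # The loop below sweeps every combination of the four boolean knobs
        # (iterable-style dataset, workers, pin_memory, holding an iterator
        # reference); itertools.product([True, False], repeat=4) yields
        # 2 ** 4 == 16 configurations:
        #
        #   >>> import itertools
        #   >>> len(list(itertools.product([True, False], repeat=4)))
        #   16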
for ( 2485 is_iterable_dataset, 2486 use_workers, 2487 pin_memory, 2488 hold_iter_reference, 2489 ) in itertools.product([True, False], repeat=4): 2490 # `hold_iter_reference` specifies whether we hold a reference to the 2491 # iterator. This is interesting because Python error traces hold a 2492 # reference to the frames, which hold references to all the local 2493 # variables, including the iterator, so the iterator destructor may 2494 # not be called before the process ends. It is important to see that the 2495 # processes still exit in both cases. 2496 2497 if pin_memory and (not TEST_CUDA or NO_MULTIPROCESSING_SPAWN or IS_WINDOWS): 2498 # This test runs in a subprocess, which can only initialize CUDA with spawn. 2499 # DataLoader with pin_memory=True initializes CUDA when its iterator is constructed. 2500 # On Windows, pin_memory sometimes causes CUDA OOM. 2501 continue 2502 2503 # `exit_method` controls the way the loader process ends. 2504 # - `*_kill` means that `*` is killed by the OS. 2505 # - `*_error` means that `*` raises an error. 2506 # - `None` means that no error happens. 2507 # In all cases, all processes should end properly. 2508 if use_workers: 2509 # TODO: Fix the 'loader_kill' case, which can run out of shared memory: 2510 # killing the loader process prevents the DataLoader iterator from cleaning 2511 # up all queues and worker processes 2512 exit_methods = [None, "loader_error", "worker_error", "worker_kill"] 2513 persistent_workers = self.persistent_workers 2514 else: 2515 exit_methods = [None, "loader_error", "loader_kill"] 2516 persistent_workers = False 2517 2518 for exit_method in exit_methods: 2519 if exit_method == "worker_kill": 2520 # FIXME: This sometimes hangs. See #16608. 2521 continue 2522 2523 desc = [] 2524 desc.append(f"is_iterable_dataset={is_iterable_dataset}") 2525 desc.append(f"use_workers={use_workers}") 2526 desc.append(f"pin_memory={pin_memory}") 2527 desc.append(f"hold_iter_reference={hold_iter_reference}") 2528 desc.append(f"exit_method={exit_method}") 2529 desc = "test_proper_exit with " + ", ".join(desc) 2530 2531 # Event that the loader process uses to signal the testing process 2532 # that various things are set up, including that the worker pids 2533 # are specified in the `worker_pids` array. 2534 loader_setup_event = mp.Event() 2535 2536 # Event signaling that this process has finished setting up, and that the 2537 # loader process can now proceed to trigger error events or 2538 # finish normally. 2539 tester_setup_event = mp.Event() 2540 2541 loader_p = ErrorTrackingProcess( 2542 target=_test_proper_exit, 2543 args=( 2544 is_iterable_dataset, 2545 use_workers, 2546 pin_memory, 2547 exit_method, 2548 hold_iter_reference, 2549 loader_setup_event, 2550 tester_setup_event, 2551 persistent_workers, 2552 ), 2553 disable_stderr=False, 2554 ) 2555 loader_p.start() 2556 loader_psutil_p = psutil.Process(loader_p.pid) 2557 2558 # Wait for the loader process to set everything up, e.g., starting 2559 # workers. 2560 loader_setup_event.wait(timeout=JOIN_TIMEOUT) 2561 if not loader_setup_event.is_set(): 2562 fail_msg = ( 2563 desc + ": loader process failed to set up within the given time" 2564 ) 2565 if loader_p.exception is not None: 2566 fail_msg += f", and had exception {loader_p.exception}" 2567 elif not loader_p.is_alive(): 2568 fail_msg += f", and exited with code {loader_p.exitcode} but had no exception" 2569 else: 2570 fail_msg += ", and is still alive."
if loader_p.is_alive(): 2572 # this may kill the process, needs to run after the above lines 2573 loader_p.print_traces_of_all_threads() 2574 self.fail(fail_msg) 2575 2576 # We are certain that the workers have started now. 2577 worker_psutil_ps = loader_psutil_p.children() 2578 2579 def fail(reason): 2580 report_psutil_attrs = [ 2581 "pid", 2582 "name", 2583 "cpu_times", 2584 "io_counters", 2585 "memory_full_info", 2586 "num_ctx_switches", 2587 "open_files", 2588 "threads", 2589 "status", 2590 "nice", 2591 "ionice", 2592 ] 2593 if reason is None: 2594 err_msg = desc 2595 else: 2596 err_msg = f"{desc}: {reason}" 2597 err_msg += "\nLoader info:\n\t" 2598 if loader_psutil_p.is_running(): 2599 err_msg += str( 2600 loader_psutil_p.as_dict(attrs=report_psutil_attrs) 2601 ) 2602 # this may kill the process, needs to run after the above line 2603 loader_p.print_traces_of_all_threads() 2604 else: 2605 err_msg += f"exited with code {loader_p.exitcode}" 2606 if use_workers: 2607 err_msg += "\nWorker(s) info:" 2608 for idx, worker_psutil_p in enumerate(worker_psutil_ps): 2609 err_msg += f"\n\tWorker {idx}:\n\t\t" 2610 if worker_psutil_p.is_running(): 2611 err_msg += str( 2612 worker_psutil_p.as_dict(attrs=report_psutil_attrs) 2613 ) 2614 # this may kill the process, needs to run after the above line 2615 print_traces_of_all_threads(worker_psutil_p.pid) 2616 else: 2617 err_msg += "exited with unknown code" 2618 self.fail(err_msg) 2619 2620 tester_setup_event.set() 2621 2622 try: 2623 loader_p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL) 2624 if loader_p.is_alive(): 2625 fail_reason = "loader process did not terminate" 2626 if loader_p.exception is not None: 2627 fail( 2628 fail_reason 2629 + f", and had exception {loader_p.exception}" 2630 ) 2631 else: 2632 fail(fail_reason + ", and had no exception") 2633 _, alive = psutil.wait_procs( 2634 worker_psutil_ps, 2635 timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT), 2636 ) 2637 if len(alive) > 0: 2638 fail( 2639 "worker process (pid(s) {}) did not terminate".format( 2640 ", ".join(str(p.pid) for p in alive) 2641 ) 2642 ) 2643 if exit_method is None: 2644 if loader_p.exitcode != 0: 2645 fail( 2646 f"loader process had nonzero exitcode {loader_p.exitcode}" 2647 ) 2648 else: 2649 if loader_p.exitcode == 0: 2650 fail("loader process had zero exitcode") 2651 if exit_method == "loader_error": 2652 if not isinstance( 2653 loader_p.exception, RuntimeError 2654 ) or "Loader error" not in str(loader_p.exception): 2655 fail( 2656 f"loader process did not raise expected exception, but had {loader_p.exception}" 2657 ) 2658 elif exit_method == "worker_kill": 2659 if isinstance(loader_p.exception, RuntimeError): 2660 if "DataLoader worker (pid" not in str( 2661 loader_p.exception 2662 ): 2663 fail( 2664 f"loader process did not raise expected exception, but had {loader_p.exception}" 2665 ) 2666 elif isinstance(loader_p.exception, ConnectionRefusedError): 2667 # Sometimes, when the worker is being killed and is freeing its 2668 # resources, the unpickling in the loader process may hit 2669 # a `ConnectionRefusedError` because it cannot open a socket to receive 2670 # the resource. In such cases, the worker may not have fully exited, 2671 # and the loader can't know this via the `is_alive` check or the `SIGCHLD` 2672 # handler. So we permit this as an allowed error as well. 2673 # After all, we are happy as long as it terminates.
2674 pass 2675 else: 2676 fail( 2677 f"loader process did not raise expected exception, but had {loader_p.exception}" 2678 ) 2679 elif exit_method == "worker_error": 2680 if not isinstance( 2681 loader_p.exception, RuntimeError 2682 ) or "Worker error" not in str(loader_p.exception): 2683 fail( 2684 f"loader process did not raise expected exception, but had {loader_p.exception}" 2685 ) 2686 finally: 2687 loader_p.terminate() 2688 2689 def test_len(self): 2690 def check_len(dl, expected): 2691 self.assertEqual(len(dl), expected) 2692 n = 0 2693 for _ in dl: 2694 n += 1 2695 self.assertEqual(n, expected) 2696 2697 check_len(self.dataset, 100) 2698 check_len(self._get_data_loader(self.dataset, batch_size=2), 50) 2699 check_len(self._get_data_loader(self.dataset, batch_size=3), 34) 2700 2701 def test_iterabledataset_len(self): 2702 class IterableDataset(torch.utils.data.IterableDataset): 2703 def __len__(self): 2704 return 10 2705 2706 def __iter__(self): 2707 return iter(range(10)) 2708 2709 iterable_loader = DataLoader(IterableDataset(), batch_size=1) 2710 self.assertEqual(len(iterable_loader), 10) 2711 iterable_loader = DataLoader(IterableDataset(), batch_size=1, drop_last=True) 2712 self.assertEqual(len(iterable_loader), 10) 2713 2714 iterable_loader = DataLoader(IterableDataset(), batch_size=2) 2715 self.assertEqual(len(iterable_loader), 5) 2716 iterable_loader = DataLoader(IterableDataset(), batch_size=2, drop_last=True) 2717 self.assertEqual(len(iterable_loader), 5) 2718 2719 iterable_loader = DataLoader(IterableDataset(), batch_size=3) 2720 self.assertEqual(len(iterable_loader), 4) 2721 iterable_loader = DataLoader(IterableDataset(), batch_size=3, drop_last=True) 2722 self.assertEqual(len(iterable_loader), 3) 2723 2724 @unittest.skipIf(not TEST_NUMPY, "numpy unavailable") 2725 def test_numpy_scalars(self): 2726 import numpy as np 2727 2728 class ScalarDataset(torch.utils.data.Dataset): 2729 def __init__(self, dtype): 2730 self.dtype = dtype 2731 2732 def __getitem__(self, i): 2733 return self.dtype() 2734 2735 def __len__(self): 2736 return 4 2737 2738 dtypes = { 2739 np.float64: torch.DoubleTensor, 2740 np.float32: torch.FloatTensor, 2741 np.float16: torch.HalfTensor, 2742 np.int64: torch.LongTensor, 2743 np.int32: torch.IntTensor, 2744 np.int16: torch.ShortTensor, 2745 np.int8: torch.CharTensor, 2746 np.uint8: torch.ByteTensor, 2747 } 2748 for dt, tt in dtypes.items(): 2749 dset = ScalarDataset(dt) 2750 loader = self._get_data_loader(dset, batch_size=2) 2751 batch = next(iter(loader)) 2752 self.assertIsInstance(batch, tt) 2753 2754 def test_default_convert_mapping_keep_type(self): 2755 data = CustomDict({"a": 1, "b": 2}) 2756 converted = _utils.collate.default_convert(data) 2757 2758 self.assertEqual(converted, data) 2759 2760 def test_default_convert_sequence_keep_type(self): 2761 data = CustomList([1, 2, 3]) 2762 converted = _utils.collate.default_convert(data) 2763 2764 self.assertEqual(converted, data) 2765 2766 def test_default_convert_sequence_dont_keep_type(self): 2767 data = range(2) 2768 converted = _utils.collate.default_convert(data) 2769 2770 self.assertEqual(converted, [0, 1]) 2771 2772 def test_default_collate_dtype(self): 2773 arr = [1, 2, -1] 2774 collated = _utils.collate.default_collate(arr) 2775 self.assertEqual(collated, torch.tensor(arr)) 2776 self.assertEqual(collated.dtype, torch.int64) 2777 2778 arr = [1.1, 2.3, -0.9] 2779 collated = _utils.collate.default_collate(arr) 2780 self.assertEqual(collated, torch.tensor(arr, dtype=torch.float64)) 2781 2782 arr = 
[True, False] 2783 collated = _utils.collate.default_collate(arr) 2784 self.assertEqual(collated, torch.tensor(arr)) 2785 self.assertEqual(collated.dtype, torch.bool) 2786 2787 # Should be a no-op 2788 arr = ["a", "b", "c"] 2789 self.assertEqual(arr, _utils.collate.default_collate(arr)) 2790 2791 def test_default_collate_mapping_keep_type(self): 2792 batch = [CustomDict({"a": 1, "b": 2}), CustomDict({"a": 3, "b": 4})] 2793 collated = _utils.collate.default_collate(batch) 2794 2795 expected = CustomDict({"a": torch.tensor([1, 3]), "b": torch.tensor([2, 4])}) 2796 self.assertEqual(collated, expected) 2797 2798 def test_default_collate_sequence_keep_type(self): 2799 batch = [CustomList([1, 2, 3]), CustomList([4, 5, 6])] 2800 collated = _utils.collate.default_collate(batch) 2801 2802 expected = CustomList( 2803 [ 2804 torch.tensor([1, 4]), 2805 torch.tensor([2, 5]), 2806 torch.tensor([3, 6]), 2807 ] 2808 ) 2809 self.assertEqual(collated, expected) 2810 2811 def test_default_collate_sequence_dont_keep_type(self): 2812 batch = [range(2), range(2)] 2813 collated = _utils.collate.default_collate(batch) 2814 2815 self.assertEqual(collated, [torch.tensor([0, 0]), torch.tensor([1, 1])]) 2816 2817 @unittest.skipIf(not TEST_NUMPY, "numpy unavailable") 2818 def test_default_collate_bad_numpy_types(self): 2819 import numpy as np 2820 2821 # Should be a no-op 2822 arr = np.array(["a", "b", "c"]) 2823 self.assertEqual(arr, _utils.collate.default_collate(arr)) 2824 2825 arr = np.array([[["a", "b", "c"]]]) 2826 self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr)) 2827 2828 arr = np.array([object(), object(), object()]) 2829 self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr)) 2830 2831 arr = np.array([[[object(), object(), object()]]]) 2832 self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr)) 2833 2834 @unittest.skipIf(not TEST_NUMPY, "numpy unavailable") 2835 def test_default_collate_numpy_memmap(self): 2836 import numpy as np 2837 2838 with tempfile.TemporaryFile() as f: 2839 arr = np.array([[0, 1], [2, 3], [4, 5], [6, 7]]) 2840 arr_memmap = np.memmap(f, dtype=arr.dtype, mode="w+", shape=arr.shape) 2841 arr_memmap[:] = arr[:] 2842 arr_new = np.memmap(f, dtype=arr.dtype, mode="r", shape=arr.shape) 2843 tensor = _utils.collate.default_collate(list(arr_new)) 2844 2845 self.assertTrue( 2846 (tensor == tensor.new_tensor([[0, 1], [2, 3], [4, 5], [6, 7]])).all().item() 2847 ) 2848 2849 def test_default_collate_bad_sequence_type(self): 2850 batch = [["X"], ["X", "X"]] 2851 self.assertRaises(RuntimeError, lambda: _utils.collate.default_collate(batch)) 2852 self.assertRaises( 2853 RuntimeError, lambda: _utils.collate.default_collate(batch[::-1]) 2854 ) 2855 2856 @unittest.skipIf(not TEST_NUMPY, "numpy unavailable") 2857 def test_default_collate_shared_tensor(self): 2858 import numpy as np 2859 2860 t_in = torch.zeros(1) 2861 n_in = np.zeros(1) 2862 2863 self.assertEqual(t_in.is_shared(), False) 2864 2865 self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), False) 2866 self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), False) 2867 2868 # FIXME: fix the following hack that makes `default_collate` believe 2869 # that it is in a worker process (since it tests 2870 # `get_worker_info() != None`), even though it is not. 
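        # Context for the hack below: `default_collate` only allocates the
        # collated batch in shared memory when it believes it is running inside
        # a worker, which lets the batch travel back to the main process without
        # an extra copy. Simplified sketch of that branch (paraphrased, not a
        # verbatim excerpt of the collate code; `elem` stands for `batch[0]`):
        #
        #   if torch.utils.data.get_worker_info() is not None:
        #       numel = sum(x.numel() for x in batch)
        #       storage = elem._typed_storage()._new_shared(numel, device=elem.device)
        #       out = elem.new(storage).resize_(len(batch), *list(elem.size()))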
old = _utils.worker._worker_info 2872 try: 2873 _utils.worker._worker_info = "x" 2874 self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), True) 2875 self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), True) 2876 finally: 2877 _utils.worker._worker_info = old 2878 2879 def test_excessive_thread_creation_warning(self): 2880 with self.assertWarnsRegex( 2881 UserWarning, 2882 r"excessive worker creation might get DataLoader running slow or even freeze", 2883 ): 2884 dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000) 2885 2886 2887class TestDataLoaderDeviceType(TestCase): 2888 @parametrize( 2889 "context", 2890 [ctx for ctx in supported_multiprocessing_contexts if ctx is not None], 2891 ) 2892 @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") 2893 def test_nested_tensor_multiprocessing(self, device, context): 2894 # The 'fork' multiprocessing context doesn't work for CUDA so skip it 2895 if "cuda" in device and context == "fork": 2896 # TODO: Skip this in a better way when the test framework allows 2897 return 2898 2899 dataset = [ 2900 torch.nested.nested_tensor([torch.randn(5)], device=device) 2901 for _ in range(10) 2902 ] 2903 2904 pin_memory_settings = [False] 2905 if device == "cpu" and torch.cuda.is_available(): 2906 pin_memory_settings.append(True) 2907 2908 for pin_memory in pin_memory_settings: 2909 loader = torch.utils.data.DataLoader( 2910 dataset, 2911 batch_size=1, 2912 num_workers=4, 2913 collate_fn=_clone_collate, 2914 pin_memory=pin_memory, 2915 multiprocessing_context=context, 2916 ) 2917 2918 for i, batch in enumerate(loader): 2919 self.assertEqual(batch[0], dataset[i]) 2920 2921 # Error case: default collate_fn doesn't currently support batches of nested tensors. 2922 # Following the current semantics, we'd need to stack them, which isn't possible at the moment. 2923 with self.assertRaisesRegex( 2924 RuntimeError, "not currently supported by the default collate_fn" 2925 ): 2926 loader = torch.utils.data.DataLoader( 2927 dataset, 2928 batch_size=1, 2929 num_workers=4, 2930 multiprocessing_context=context, 2931 ) 2932 2933 next(iter(loader)) 2934 2935 2936class IntegrationTestDataLoaderDataPipe(TestCase): 2937 r""" 2938 Verify the behavior of certain ``DataPipes`` with ``DataLoader`` 2939 """ 2940 2941 def test_shuffler_iterdatapipe(self): 2942 r""" 2943 Verify ``IterDataPipe.shuffle`` is controlled by ``DataLoader`` 2944 to generate different seeds deterministically per epoch.
2945 """ 2946 exp = list(range(100)) 2947 2948 def _create_dp(buffer_size): 2949 input_ds = dp.iter.IterableWrapper(exp) 2950 return input_ds.shuffle(buffer_size=buffer_size).sharding_filter() 2951 2952 for bs in (5, 20, 33): 2953 # Test Deterministic 2954 for num_workers, pw in itertools.product((0, 1, 2), (True, False)): 2955 if num_workers == 0 and pw: 2956 continue 2957 2958 shuffle_dp = _create_dp(bs) 2959 2960 mp_ctx = "spawn" if num_workers > 0 else None 2961 dl = DataLoader( 2962 shuffle_dp, 2963 num_workers=num_workers, 2964 shuffle=True, 2965 multiprocessing_context=mp_ctx, 2966 persistent_workers=pw, 2967 ) 2968 2969 # No seed 2970 dl_res_ns = list(dl) 2971 self.assertEqual(sorted(dl_res_ns), exp) 2972 2973 # Same seeds 2974 dl_res = [] 2975 for epoch in range(2): 2976 torch.manual_seed(123) 2977 dl_res.append(list(dl)) 2978 self.assertEqual(dl_res[0], dl_res[1]) 2979 self.assertEqual(sorted(dl_res[0]), exp) 2980 2981 # Different seeds 2982 torch.manual_seed(321) 2983 dl_res.append(list(dl)) 2984 2985 self.assertEqual(len(dl_res[0]), len(dl_res[2])) 2986 self.assertNotEqual(dl_res[0], dl_res[2]) 2987 self.assertEqual(sorted(dl_res[0]), sorted(dl_res[2])) 2988 2989 if dl._iterator is not None: 2990 dl._iterator._shutdown_workers() 2991 dl._iterator = None 2992 del dl 2993 2994 2995class StringDataset(Dataset): 2996 def __init__(self) -> None: 2997 self.s = "12345" 2998 2999 def __len__(self): 3000 return len(self.s) 3001 3002 def __getitem__(self, ndx): 3003 return (self.s[ndx], ndx) 3004 3005 3006@unittest.skipIf( 3007 TEST_WITH_TSAN, 3008 "Fails with TSAN with the following error: starting new threads after multi-threaded " 3009 "fork is not supported. Dying (set die_after_fork=0 to override)", 3010) 3011class TestStringDataLoader(TestCase): 3012 def setUp(self): 3013 super().setUp() 3014 self.dataset = StringDataset() 3015 3016 @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") 3017 def test_shuffle_pin_memory(self): 3018 loader = DataLoader( 3019 self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True 3020 ) 3021 for s, n in loader: 3022 self.assertIsInstance(s[0], str) 3023 self.assertTrue(n.is_pinned()) 3024 3025 3026class DictDataset(Dataset): 3027 def __len__(self): 3028 return 4 3029 3030 def __getitem__(self, ndx): 3031 return { 3032 "a_tensor": torch.empty(4, 2).fill_(ndx), 3033 "another_dict": {"a_number": ndx}, 3034 } 3035 3036 3037@unittest.skipIf( 3038 TEST_WITH_TSAN, 3039 "Fails with TSAN with the following error: starting new threads after multi-threaded " 3040 "fork is not supported. 
Dying (set die_after_fork=0 to override)", 3041) 3042class TestDictDataLoader(TestCase): 3043 def setUp(self): 3044 super().setUp() 3045 self.dataset = DictDataset() 3046 3047 def test_sequential_batch(self): 3048 for persistent_workers in (False, True): 3049 if persistent_workers: 3050 loader = DataLoader( 3051 self.dataset, 3052 batch_size=2, 3053 shuffle=False, 3054 persistent_workers=persistent_workers, 3055 num_workers=1, 3056 ) 3057 else: 3058 loader = DataLoader( 3059 self.dataset, 3060 batch_size=2, 3061 shuffle=False, 3062 persistent_workers=persistent_workers, 3063 ) 3064 batch_size = loader.batch_size 3065 for i, sample in enumerate(loader): 3066 idx = i * batch_size 3067 self.assertEqual(set(sample.keys()), {"a_tensor", "another_dict"}) 3068 self.assertEqual(set(sample["another_dict"].keys()), {"a_number"}) 3069 3070 t = sample["a_tensor"] 3071 self.assertEqual(t.size(), torch.Size([batch_size, 4, 2])) 3072 self.assertTrue((t[0] == idx).all()) 3073 self.assertTrue((t[1] == idx + 1).all()) 3074 3075 n = sample["another_dict"]["a_number"] 3076 self.assertEqual(n.size(), torch.Size([batch_size])) 3077 self.assertEqual(n[0], idx) 3078 self.assertEqual(n[1], idx + 1) 3079 3080 @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") 3081 def test_pin_memory(self): 3082 loader = DataLoader(self.dataset, batch_size=2, pin_memory=True) 3083 for sample in loader: 3084 self.assertTrue(sample["a_tensor"].is_pinned()) 3085 self.assertTrue(sample["another_dict"]["a_number"].is_pinned()) 3086 3087 @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") 3088 def test_pin_memory_device(self): 3089 loader = DataLoader( 3090 self.dataset, batch_size=2, pin_memory=True, pin_memory_device="cuda" 3091 ) 3092 for sample in loader: 3093 self.assertTrue(sample["a_tensor"].is_pinned(device="cuda")) 3094 self.assertTrue(sample["another_dict"]["a_number"].is_pinned(device="cuda")) 3095 3096 @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") 3097 def test_pin_memory_with_only_device(self): 3098 loader = DataLoader(self.dataset, batch_size=2, pin_memory_device="cuda") 3099 for sample in loader: 3100 self.assertFalse(sample["a_tensor"].is_pinned(device="cuda")) 3101 self.assertFalse( 3102 sample["another_dict"]["a_number"].is_pinned(device="cuda") 3103 ) 3104 3105 3106class DummyDataset(torch.utils.data.Dataset): 3107 def __init__(self) -> None: 3108 self.data = list(range(10)) 3109 3110 def __len__(self): 3111 return len(self.data) 3112 3113 def __getitem__(self, idx): 3114 if torch.is_tensor(idx): 3115 idx = idx.tolist() 3116 # The persistent workers always maintain the original 3117 # dataset through the dataloader lifetime 3118 # so the attributes will remain the same as the 3119 # first time the workers where spawned (dataloader iteration) 3120 assert self.start == 0 3121 return self.data[idx] 3122 3123 3124@unittest.skipIf( 3125 TEST_WITH_TSAN, 3126 "Fails with TSAN with the following error: starting new threads after multi-threaded " 3127 "fork is not supported. 
Dying (set die_after_fork=0 to override)", 3128) 3129@unittest.skipIf( 3130 TEST_WITH_ASAN, 3131 "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223", 3132) 3133class TestDataLoaderPersistentWorkers(TestDataLoader): 3134 def setUp(self): 3135 super().setUp() 3136 self.persistent_workers = True 3137 3138 @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI") 3139 @unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows") 3140 def test_fd_limit_exceeded(self): 3141 # See NOTE [ DataLoader on Linux and open files limit ] 3142 import subprocess 3143 3144 subprocess.check_output( 3145 [ 3146 sys.executable, 3147 "-c", 3148 """\ 3149import torch 3150import resource 3151from torch.utils.data import DataLoader, IterableDataset 3152 3153class RandomDataset(IterableDataset): 3154 def __init__(self, len, size): 3155 super(RandomDataset).__init__() 3156 self.len = len 3157 self.size = size 3158 3159 def __iter__(self): 3160 return self 3161 3162 def __next__(self): 3163 if self.len <= 0: 3164 raise StopIteration 3165 self.len -= 1 3166 return torch.randn(self.size) 3167 3168try: 3169 keep_fds_alive = [] 3170 resource.setrlimit(resource.RLIMIT_NOFILE, (100, 100)) 3171 for random_t in DataLoader(RandomDataset(200, (2,2)), multiprocessing_context="fork", 3172 num_workers=1, persistent_workers=True): 3173 random_t.max(dim=0) 3174 keep_fds_alive.append(random_t) 3175except RuntimeError as e: 3176 assert "ulimit -n" in str(e) 3177 assert "set_sharing_strategy" in str(e) 3178""", 3179 ] 3180 ) 3181 3182 def test_dataset_not_reset(self): 3183 dataset = DummyDataset() 3184 pin_memory_configs = [False] 3185 if TEST_CUDA: 3186 pin_memory_configs.append(True) 3187 for pin_memory in pin_memory_configs: 3188 dataloader = self._get_data_loader( 3189 dataset, num_workers=2, pin_memory=pin_memory 3190 ) 3191 dataset.start = 0 3192 for i in range(10): 3193 for x in dataloader: 3194 pass 3195 # Changing the start value here doesn't have any effect in the dataset 3196 # cached by the workers. 
since they are not recreated between epochs 3197 # and can cache values safely 3198 dataset.start = i 3199 3200 @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI") 3201 @unittest.skipIf(IS_WINDOWS, "Needs fork") 3202 def test_early_exit(self): 3203 import subprocess 3204 3205 proc = subprocess.check_output( 3206 [ 3207 sys.executable, 3208 "-c", 3209 """\ 3210import torch 3211from torch.utils.data import DataLoader, IterableDataset 3212 3213class RandomDataset(IterableDataset): 3214 def __init__(self, len, size): 3215 super().__init__() 3216 self.len = len 3217 self.size = size 3218 3219 def __iter__(self): 3220 return self 3221 3222 def __next__(self): 3223 if self.len <= 0: 3224 raise StopIteration 3225 self.len -= 1 3226 return torch.randn(self.size) 3227 3228if __name__ == '__main__': 3229 dl = DataLoader( 3230 RandomDataset(64, (28, 28)), 3231 batch_size=16, 3232 num_workers=2, 3233 pin_memory=True, 3234 persistent_workers=True, 3235 multiprocessing_context="fork", 3236 ) 3237 3238 for _ in dl: 3239 break 3240""", 3241 ] 3242 ) 3243 3244 3245class NamedTupleDataset(Dataset): 3246 from collections import namedtuple 3247 3248 Batch = namedtuple("Batch", ["data", "label", "random_tensor"]) 3249 Data = namedtuple("Data", ["positive", "negative"]) 3250 3251 def __len__(self): 3252 return 4 3253 3254 def __getitem__(self, ndx): 3255 return self.Batch( 3256 data=self.Data(positive=ndx, negative=-ndx), 3257 label=str(ndx), 3258 random_tensor=torch.randn(3), 3259 ) 3260 3261 3262@unittest.skipIf( 3263 TEST_WITH_TSAN, 3264 "Fails with TSAN with the following error: starting new threads after multi-threaded " 3265 "fork is not supported. Dying (set die_after_fork=0 to override)", 3266) 3267class TestNamedTupleDataLoader(TestCase): 3268 def setUp(self): 3269 super().setUp() 3270 self.dataset = NamedTupleDataset() 3271 3272 def test_dataloader_with_namedtuple(self): 3273 # auto-collation 3274 loader = DataLoader(self.dataset, batch_size=2, pin_memory=TEST_CUDA) 3275 for batch in loader: 3276 self.assertIsInstance(batch, NamedTupleDataset.Batch) 3277 self.assertEqual(batch.random_tensor.is_pinned(), TEST_CUDA) 3278 self.assertIsInstance(batch.data, NamedTupleDataset.Data) 3279 self.assertIsInstance(batch.data.positive, torch.Tensor) 3280 self.assertEqual(batch.data.positive.is_pinned(), TEST_CUDA) 3281 # no auto-collation 3282 loader = DataLoader(self.dataset, batch_size=None, pin_memory=TEST_CUDA) 3283 for batch in loader: 3284 self.assertIsInstance(batch, NamedTupleDataset.Batch) 3285 self.assertEqual(batch.random_tensor.is_pinned(), TEST_CUDA) 3286 self.assertIsInstance(batch.data, NamedTupleDataset.Data) 3287 self.assertNotIsInstance(batch.data.positive, torch.Tensor) 3288 3289 3290class SimpleCustomBatch: 3291 def __init__(self, data): 3292 transposed_data = list(zip(*data)) 3293 self.inp = torch.stack(transposed_data[0], 0) 3294 self.tgt = torch.stack(transposed_data[1], 0) 3295 3296 def pin_memory(self): 3297 self.inp = self.inp.pin_memory() 3298 self.tgt = self.tgt.pin_memory() 3299 return self 3300 3301 def is_pinned(self): 3302 return self.inp.is_pinned() and self.tgt.is_pinned() 3303 3304 3305# Workaround for https://github.com/pytorch/pytorch/issues/50661 3306# Classes from `__main__` cannot be correctly unpickled from a spawned module 3307# See https://docs.python.org/3/library/multiprocessing.html#multiprocessing-programming 3308self_module = __import__(os.path.splitext(os.path.basename(__file__))[0]) 3309 3310 3311def collate_wrapper(batch): 3312
return self_module.SimpleCustomBatch(batch) 3313 3314 3315def collate_into_packed_sequence(batch): 3316 data = torch.stack([sample[0] for sample in batch], 1) 3317 t, b = data.size() 3318 lengths = torch.randint(1, t, size=(b,), dtype=torch.int64) 3319 return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, enforce_sorted=False) 3320 3321 3322def collate_into_packed_sequence_batch_first(batch): 3323 data = torch.stack([sample[0] for sample in batch], 0) 3324 b, t = data.size() 3325 lengths = torch.randint(1, t, size=(b,), dtype=torch.int64) 3326 return torch.nn.utils.rnn.pack_padded_sequence( 3327 data, lengths, batch_first=True, enforce_sorted=False 3328 ) 3329 3330 3331@unittest.skipIf( 3332 TEST_WITH_TSAN, 3333 "Fails with TSAN with the following error: starting new threads after multi-threaded " 3334 "fork is not supported. Dying (set die_after_fork=0 to override)", 3335) 3336class TestCustomPinFn(TestCase): 3337 def setUp(self): 3338 super().setUp() 3339 inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5) 3340 tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5) 3341 self.dataset = TensorDataset(inps, tgts) 3342 3343 @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") 3344 def test_custom_batch_pin(self): 3345 test_cases = [ 3346 (collate_wrapper, self_module.SimpleCustomBatch), 3347 (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence), 3348 ( 3349 collate_into_packed_sequence_batch_first, 3350 torch.nn.utils.rnn.PackedSequence, 3351 ), 3352 ] 3353 for collate_fn, elem_cls in test_cases: 3354 loader = DataLoader( 3355 self.dataset, batch_size=2, collate_fn=collate_fn, pin_memory=True 3356 ) 3357 for sample in loader: 3358 self.assertIsInstance(sample, elem_cls) 3359 self.assertTrue(sample.is_pinned()) 3360 3361 @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") 3362 def test_custom_batch_pin_worker(self): 3363 test_cases = [ 3364 (collate_wrapper, self_module.SimpleCustomBatch), 3365 (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence), 3366 ( 3367 collate_into_packed_sequence_batch_first, 3368 torch.nn.utils.rnn.PackedSequence, 3369 ), 3370 ] 3371 for collate_fn, elem_cls in test_cases: 3372 loader = DataLoader( 3373 self.dataset, 3374 batch_size=2, 3375 collate_fn=collate_fn, 3376 pin_memory=True, 3377 num_workers=1, 3378 ) 3379 for sample in loader: 3380 self.assertIsInstance(sample, elem_cls) 3381 self.assertTrue(sample.is_pinned()) 3382 3383 3384class TestWorkerQueueDataset(Dataset): 3385 def __init__(self, data): 3386 self.data = data 3387 self.worker_id = None 3388 3389 def worker_init_fn(self, worker_id): 3390 self.worker_id = worker_id 3391 3392 def __getitem__(self, item): 3393 return self.worker_id, self.data[item] 3394 3395 def __len__(self): 3396 return len(self.data) 3397 3398 3399@unittest.skipIf( 3400 TEST_WITH_TSAN, 3401 "Fails with TSAN with the following error: starting new threads after multi-threaded " 3402 "fork is not supported. 
Dying (set die_after_fork=0 to override)", 3403) 3404@unittest.skipIf( 3405 TEST_WITH_ASAN, 3406 "Flaky with ASAN, see https://github.com/pytorch/pytorch/issues/65727", 3407) 3408class TestIndividualWorkerQueue(TestCase): 3409 def setUp(self): 3410 super().setUp() 3411 self.dataset = TestWorkerQueueDataset(list(range(128))) 3412 3413 def _run_ind_worker_queue_test(self, batch_size, num_workers): 3414 loader = DataLoader( 3415 self.dataset, 3416 batch_size=batch_size, 3417 shuffle=False, 3418 num_workers=num_workers, 3419 timeout=5, 3420 worker_init_fn=self.dataset.worker_init_fn, 3421 ) 3422 current_worker_idx = 0 3423 for i, (worker_ids, sample) in enumerate(loader): 3424 self.assertEqual(worker_ids.tolist(), [current_worker_idx] * batch_size) 3425 self.assertEqual( 3426 sample.tolist(), list(range(i * batch_size, (i + 1) * batch_size)) 3427 ) 3428 current_worker_idx += 1 3429 if current_worker_idx == num_workers: 3430 current_worker_idx = 0 3431 3432 def test_ind_worker_queue(self): 3433 max_num_workers = None 3434 if hasattr(os, "sched_getaffinity"): 3435 try: 3436 max_num_workers = len(os.sched_getaffinity(0)) 3437 except Exception: 3438 pass 3439 if max_num_workers is None: 3440 cpu_count = os.cpu_count() 3441 if cpu_count is not None: 3442 # Use half number of CPUs 3443 max_num_workers = cpu_count // 2 3444 3445 if max_num_workers is None: 3446 max_num_workers = 1 3447 3448 for batch_size in (8, 16, 32, 64): 3449 for num_workers in range(0, min(6, max_num_workers)): 3450 self._run_ind_worker_queue_test( 3451 batch_size=batch_size, num_workers=num_workers + 1 3452 ) 3453 3454 3455class SetAffinityDataset(IterableDataset): 3456 def __iter__(self): 3457 torch.randperm(1) 3458 after = os.sched_getaffinity(0) 3459 return iter(after) 3460 3461 3462@unittest.skipIf( 3463 not hasattr(os, "sched_setaffinity"), "os.sched_setaffinity is not available" 3464) 3465class TestSetAffinity(TestCase): 3466 def test_set_affinity_in_worker_init(self): 3467 # Query the current affinity mask to avoid setting a disallowed one 3468 old_affinity = os.sched_getaffinity(0) 3469 if not old_affinity: 3470 self.skipTest("No affinity information") 3471 # Choose any 3472 expected_affinity = list(old_affinity)[-1] 3473 3474 def worker_set_affinity(_): 3475 os.sched_setaffinity(0, [expected_affinity]) 3476 3477 dataset = SetAffinityDataset() 3478 3479 dataloader = torch.utils.data.DataLoader( 3480 dataset, num_workers=2, worker_init_fn=worker_set_affinity 3481 ) 3482 for sample in dataloader: 3483 self.assertEqual(sample, [expected_affinity]) 3484 3485 3486class ConvDataset(Dataset): 3487 def __init__(self) -> None: 3488 self.x = torch.ones(1, 1, 24000) 3489 # Call convolution on parent process 3490 self[0] 3491 3492 def __len__(self): 3493 return 1 3494 3495 def __getitem__(self, index): 3496 return torch.nn.functional.conv1d(self.x, torch.ones(1, 1, 2)) 3497 3498 3499@unittest.skipIf(IS_WINDOWS, "Needs fork") 3500@unittest.skipIf( 3501 TEST_WITH_ASAN, 3502 "This test hangs when running with ASAN, see https://github.com/pytorch/pytorch/issues/75492", 3503) 3504class TestConvAfterFork(TestCase): 3505 # Tests crash reported in https://github.com/pytorch/pytorch/issues/53565 3506 def test_conv_after_fork(self): 3507 loader = DataLoader(ConvDataset(), num_workers=1) 3508 for x in loader: 3509 self.assertEqual(x.shape, (1, 1, 1, 23999)) 3510 3511 3512instantiate_device_type_tests(TestDataLoaderDeviceType, globals()) 3513 3514 3515if __name__ == "__main__": 3516 run_tests() 3517
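# Note: `instantiate_device_type_tests` above generates per-device copies of
# `TestDataLoaderDeviceType` (e.g. CPU and, when available, CUDA variants),
# filling in the `device` argument of `test_nested_tensor_multiprocessing`
# automatically. Hypothetical generated names, for illustration only:
#
#   TestDataLoaderDeviceTypeCPU.test_nested_tensor_multiprocessing_cpu
#   TestDataLoaderDeviceTypeCUDA.test_nested_tensor_multiprocessing_cuda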