# mypy: ignore-errors

# Owner(s): ["module: dataloader"]

import copy
import itertools
import os
import os.path
import pickle
import pydoc
import random
import sys
import tempfile
import warnings
from functools import partial
from typing import (
    Any,
    Awaitable,
    Dict,
    Generic,
    Iterator,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    TYPE_CHECKING,
    TypeVar,
    Union,
)

if not TYPE_CHECKING:
    # pyre isn't treating this the same as a typing.NamedTuple
    from typing_extensions import NamedTuple
else:
    from typing import NamedTuple

import operator
from unittest import skipIf

import numpy as np

import torch
import torch.nn as nn
import torch.utils.data.datapipes as dp
import torch.utils.data.graph
import torch.utils.data.graph_settings
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfNoDill,
    skipIfTorchDynamo,
    suppress_warnings,
    TEST_DILL,
    TestCase,
)
from torch.utils._import_utils import import_dill
from torch.utils.data import (
    argument_validation,
    DataChunk,
    DataLoader,
    IterDataPipe,
    MapDataPipe,
    RandomSampler,
    runtime_validation,
    runtime_validation_disabled,
)
from torch.utils.data.datapipes.dataframe import (
    CaptureDataFrame,
    dataframe_wrapper as df_wrapper,
)
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torch.utils.data.datapipes.utils.common import StreamWrapper
from torch.utils.data.datapipes.utils.decoder import (
    basichandlers as decoder_basichandlers,
)
from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration
from torch.utils.data.graph import traverse_dps

dill = import_dill()
HAS_DILL = TEST_DILL

try:
    import pandas  # type: ignore[import] # noqa: F401 F403

    HAS_PANDAS = True
except ImportError:
    HAS_PANDAS = False
skipIfNoDataFrames = skipIf(not HAS_PANDAS, "no dataframes (pandas)")

skipTyping = skipIf(True, "TODO: Fix typing bug")
T_co = TypeVar("T_co", covariant=True)


def create_temp_dir_and_files():
    # The temp dir and files within it will be released and deleted in tearDown().
    # Adding `noqa: P201` to avoid mypy's warning on not releasing the dir handle within this function.
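    # For reference, the layout produced below is:
    #   temp_dir/           one .txt, one .byte, and one empty .empty file
    #   temp_dir/<sub_dir>/ one .txt and one .byte file
    # The non-empty files are all written with "0123456789abcdef".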
    temp_dir = tempfile.TemporaryDirectory()  # noqa: P201
    temp_dir_path = temp_dir.name
    with tempfile.NamedTemporaryFile(
        dir=temp_dir_path, delete=False, suffix=".txt"
    ) as f:
        temp_file1_name = f.name
    with tempfile.NamedTemporaryFile(
        dir=temp_dir_path, delete=False, suffix=".byte"
    ) as f:
        temp_file2_name = f.name
    with tempfile.NamedTemporaryFile(
        dir=temp_dir_path, delete=False, suffix=".empty"
    ) as f:
        temp_file3_name = f.name

    with open(temp_file1_name, "w") as f1:
        f1.write("0123456789abcdef")
    with open(temp_file2_name, "wb") as f2:
        f2.write(b"0123456789abcdef")

    temp_sub_dir = tempfile.TemporaryDirectory(dir=temp_dir_path)  # noqa: P201
    temp_sub_dir_path = temp_sub_dir.name
    with tempfile.NamedTemporaryFile(
        dir=temp_sub_dir_path, delete=False, suffix=".txt"
    ) as f:
        temp_sub_file1_name = f.name
    with tempfile.NamedTemporaryFile(
        dir=temp_sub_dir_path, delete=False, suffix=".byte"
    ) as f:
        temp_sub_file2_name = f.name

    with open(temp_sub_file1_name, "w") as f1:
        f1.write("0123456789abcdef")
    with open(temp_sub_file2_name, "wb") as f2:
        f2.write(b"0123456789abcdef")

    return [
        (temp_dir, temp_file1_name, temp_file2_name, temp_file3_name),
        (temp_sub_dir, temp_sub_file1_name, temp_sub_file2_name),
    ]


def reset_after_n_next_calls(
    datapipe: Union[IterDataPipe[T_co], MapDataPipe[T_co]], n: int
) -> Tuple[List[T_co], List[T_co]]:
    """
    Given a DataPipe and integer n, iterate the DataPipe for n elements and store the elements into a list
    Then, reset the DataPipe and return a tuple of two lists
    1. A list of elements yielded before the reset
    2. A list of all elements of the DataPipe after the reset
    """
    it = iter(datapipe)
    res_before_reset = []
    for _ in range(n):
        res_before_reset.append(next(it))
    return res_before_reset, list(datapipe)


def odd_or_even(x: int) -> int:
    return x % 2


class TestDataChunk(TestCase):
    def setUp(self):
        self.elements = list(range(10))
        random.shuffle(self.elements)
        self.chunk: DataChunk[int] = DataChunk(self.elements)

    def test_getitem(self):
        for i in range(10):
            self.assertEqual(self.elements[i], self.chunk[i])

    def test_iter(self):
        for ele, dc in zip(self.elements, iter(self.chunk)):
            self.assertEqual(ele, dc)

    def test_len(self):
        self.assertEqual(len(self.elements), len(self.chunk))

    def test_as_string(self):
        self.assertEqual(str(self.chunk), str(self.elements))

        batch = [self.elements] * 3
        chunks: List[DataChunk[int]] = [DataChunk(self.elements)] * 3
        self.assertEqual(str(batch), str(chunks))

    def test_sort(self):
        chunk: DataChunk[int] = DataChunk(self.elements)
        chunk.sort()
        self.assertTrue(isinstance(chunk, DataChunk))
        for i, d in enumerate(chunk):
            self.assertEqual(i, d)

    def test_reverse(self):
        chunk: DataChunk[int] = DataChunk(self.elements)
        chunk.reverse()
        self.assertTrue(isinstance(chunk, DataChunk))
        for i in range(10):
            self.assertEqual(chunk[i], self.elements[9 - i])

    def test_random_shuffle(self):
        elements = list(range(10))
        chunk: DataChunk[int] = DataChunk(elements)

        rng = random.Random(0)
        rng.shuffle(chunk)

        rng = random.Random(0)
        rng.shuffle(elements)

        self.assertEqual(chunk, elements)


class TestStreamWrapper(TestCase):
    class _FakeFD:
        def __init__(self, filepath):
            self.filepath = filepath
            self.opened = False
            self.closed = False

        def open(self):
            self.opened = True

        def read(self):
            if self.opened:
                return "".join(self)
            else:
                raise OSError("Cannot read from un-opened file descriptor")

        def __iter__(self):
            for i in range(5):
                yield str(i)

        def close(self):
            if self.opened:
                self.opened = False
                self.closed = True

        def __repr__(self):
            return "FakeFD"

    def test_dir(self):
        fd = TestStreamWrapper._FakeFD("")
        wrap_fd = StreamWrapper(fd)

        s = set(dir(wrap_fd))
        for api in ["open", "read", "close"]:
            self.assertTrue(api in s)

    @skipIfTorchDynamo()
    def test_api(self):
        fd = TestStreamWrapper._FakeFD("")
        wrap_fd = StreamWrapper(fd)

        self.assertFalse(fd.opened)
        self.assertFalse(fd.closed)
        with self.assertRaisesRegex(IOError, "Cannot read from"):
            wrap_fd.read()

        wrap_fd.open()
        self.assertTrue(fd.opened)
        self.assertEqual("01234", wrap_fd.read())

        del wrap_fd
        self.assertFalse(fd.opened)
        self.assertTrue(fd.closed)

    def test_pickle(self):
        with tempfile.TemporaryFile() as f:
            with self.assertRaises(TypeError) as ctx1:
                pickle.dumps(f)

            wrap_f = StreamWrapper(f)
            with self.assertRaises(TypeError) as ctx2:
                pickle.dumps(wrap_f)

            # Same exception when pickling
            self.assertEqual(str(ctx1.exception), str(ctx2.exception))

        fd = TestStreamWrapper._FakeFD("")
        wrap_fd = StreamWrapper(fd)
        _ = pickle.loads(pickle.dumps(wrap_fd))

    def test_repr(self):
        fd = TestStreamWrapper._FakeFD("")
        wrap_fd = StreamWrapper(fd)
        self.assertEqual(str(wrap_fd), "StreamWrapper<FakeFD>")

        with tempfile.TemporaryFile() as f:
            wrap_f = StreamWrapper(f)
            self.assertEqual(str(wrap_f), "StreamWrapper<" + str(f) + ">")


class TestIterableDataPipeBasic(TestCase):
    def setUp(self):
        ret = create_temp_dir_and_files()
        self.temp_dir = ret[0][0]
        self.temp_files = ret[0][1:]
        self.temp_sub_dir = ret[1][0]
        self.temp_sub_files = ret[1][1:]

    def tearDown(self):
        try:
            self.temp_sub_dir.cleanup()
            self.temp_dir.cleanup()
        except Exception as e:
            warnings.warn(
                f"TestIterableDataPipeBasic was not able to cleanup temp dir due to {str(e)}"
            )

    def test_listdirfiles_iterable_datapipe(self):
        temp_dir = self.temp_dir.name
        datapipe: IterDataPipe = dp.iter.FileLister(temp_dir, "")

        count = 0
        for pathname in datapipe:
            count = count + 1
            self.assertTrue(pathname in self.temp_files)
        self.assertEqual(count, len(self.temp_files))

        count = 0
        datapipe = dp.iter.FileLister(temp_dir, "", recursive=True)
        for pathname in datapipe:
            count = count + 1
            self.assertTrue(
                (pathname in self.temp_files) or (pathname in self.temp_sub_files)
            )
        self.assertEqual(count, len(self.temp_files) + len(self.temp_sub_files))

        temp_files = self.temp_files
        datapipe = dp.iter.FileLister([temp_dir, *temp_files])
        count = 0
        for pathname in datapipe:
            count += 1
            self.assertTrue(pathname in self.temp_files)
        self.assertEqual(count, 2 * len(self.temp_files))

        # test functional API
        datapipe = datapipe.list_files()
        count = 0
        for pathname in datapipe:
            count += 1
            self.assertTrue(pathname in self.temp_files)
        self.assertEqual(count, 2 * len(self.temp_files))

    def test_listdirfilesdeterministic_iterable_datapipe(self):
        temp_dir = self.temp_dir.name

        datapipe = dp.iter.FileLister(temp_dir, "")
        # The output order should be always the same.
        self.assertEqual(list(datapipe), list(datapipe))

        datapipe = dp.iter.FileLister(temp_dir, "", recursive=True)
        # The output order should be always the same.
        self.assertEqual(list(datapipe), list(datapipe))

    def test_openfilesfromdisk_iterable_datapipe(self):
        # test import datapipe class directly
        from torch.utils.data.datapipes.iter import FileLister, FileOpener

        temp_dir = self.temp_dir.name
        datapipe1 = FileLister(temp_dir, "")
        datapipe2 = FileOpener(datapipe1, mode="b")

        count = 0
        for rec in datapipe2:
            count = count + 1
            self.assertTrue(rec[0] in self.temp_files)
            with open(rec[0], "rb") as f:
                self.assertEqual(rec[1].read(), f.read())
            rec[1].close()
        self.assertEqual(count, len(self.temp_files))

        # functional API
        datapipe3 = datapipe1.open_files(mode="b")

        count = 0
        for rec in datapipe3:
            count = count + 1
            self.assertTrue(rec[0] in self.temp_files)
            with open(rec[0], "rb") as f:
                self.assertEqual(rec[1].read(), f.read())
            rec[1].close()
        self.assertEqual(count, len(self.temp_files))

        # __len__ Test
        with self.assertRaises(TypeError):
            len(datapipe3)

    def test_routeddecoder_iterable_datapipe(self):
        temp_dir = self.temp_dir.name
        temp_pngfile_pathname = os.path.join(temp_dir, "test_png.png")
        png_data = np.array(
            [[[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]], [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]],
            dtype=np.single,
        )
        np.save(temp_pngfile_pathname, png_data)
        datapipe1 = dp.iter.FileLister(temp_dir, ["*.png", "*.txt"])
        datapipe2 = dp.iter.FileOpener(datapipe1, mode="b")

        def _png_decoder(extension, data):
            if extension != "png":
                return None
            return np.load(data)

        def _helper(prior_dp, dp, channel_first=False):
            # Byte stream is not closed
            for inp in prior_dp:
                self.assertFalse(inp[1].closed)
            for inp, rec in zip(prior_dp, dp):
                ext = os.path.splitext(rec[0])[1]
                if ext == ".png":
                    expected = np.array(
                        [
                            [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
                            [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
                        ],
                        dtype=np.single,
                    )
                    if channel_first:
                        expected = expected.transpose(2, 0, 1)
                    self.assertEqual(rec[1], expected)
                else:
                    with open(rec[0], "rb") as f:
                        self.assertEqual(rec[1], f.read().decode("utf-8"))
                # Corresponding byte stream is closed by Decoder
                self.assertTrue(inp[1].closed)

        cached = list(datapipe2)
        with warnings.catch_warnings(record=True) as wa:
            datapipe3 = dp.iter.RoutedDecoder(cached, _png_decoder)
            datapipe3.add_handler(decoder_basichandlers)
            _helper(cached, datapipe3)

        cached = list(datapipe2)
        with warnings.catch_warnings(record=True) as wa:
            datapipe4 = dp.iter.RoutedDecoder(cached, decoder_basichandlers)
            datapipe4.add_handler(_png_decoder)
            _helper(cached, datapipe4, channel_first=True)

    def test_groupby_iterable_datapipe(self):
        file_list = [
            "a.png",
            "b.png",
            "c.json",
            "a.json",
            "c.png",
            "b.json",
            "d.png",
            "d.json",
            "e.png",
            "f.json",
            "g.png",
            "f.png",
            "g.json",
            "e.json",
            "h.txt",
            "h.json",
        ]

        import io

        datapipe1 = dp.iter.IterableWrapper(
            [(filename, io.BytesIO(b"12345abcde")) for filename in file_list]
        )

        def group_fn(data):
            filepath, _ = data
            return os.path.basename(filepath).split(".")[0]

        datapipe2 = dp.iter.Grouper(datapipe1, group_key_fn=group_fn, group_size=2)

        def order_fn(data):
            data.sort(key=lambda f: f[0], reverse=True)
            return data

        datapipe3 = dp.iter.Mapper(datapipe2, fn=order_fn)  # type: ignore[var-annotated]

        expected_result = [
            ("a.png", "a.json"),
            ("c.png", "c.json"),
            ("b.png", "b.json"),
            ("d.png", "d.json"),
            ("f.png", "f.json"),
            ("g.png", "g.json"),
            ("e.png", "e.json"),
            ("h.txt", "h.json"),
        ]

        count = 0
        for rec, expected in zip(datapipe3, expected_result):
            count = count + 1
            self.assertEqual(os.path.basename(rec[0][0]), expected[0])
            self.assertEqual(os.path.basename(rec[1][0]), expected[1])
            for i in [0, 1]:
                self.assertEqual(rec[i][1].read(), b"12345abcde")
                rec[i][1].close()
        self.assertEqual(count, 8)

        # testing the keep_key option
        datapipe4 = dp.iter.Grouper(
            datapipe1, group_key_fn=group_fn, keep_key=True, group_size=2
        )

        def order_fn(data):
            data[1].sort(key=lambda f: f[0], reverse=True)
            return data

        datapipe5 = dp.iter.Mapper(datapipe4, fn=order_fn)  # type: ignore[var-annotated]

        expected_result = [
            ("a", ("a.png", "a.json")),
            ("c", ("c.png", "c.json")),
            ("b", ("b.png", "b.json")),
            ("d", ("d.png", "d.json")),
            ("f", ("f.png", "f.json")),
            ("g", ("g.png", "g.json")),
            ("e", ("e.png", "e.json")),
            ("h", ("h.txt", "h.json")),
        ]

        count = 0
        for rec, expected in zip(datapipe5, expected_result):
            count = count + 1
            self.assertEqual(rec[0], expected[0])
            self.assertEqual(rec[1][0][0], expected[1][0])
            self.assertEqual(rec[1][1][0], expected[1][1])
            for i in [0, 1]:
                self.assertEqual(rec[1][i][1].read(), b"12345abcde")
                rec[1][i][1].close()
        self.assertEqual(count, 8)

    def test_demux_mux_datapipe(self):
        numbers = NumbersDataset(10)
        n1, n2 = numbers.demux(2, lambda x: x % 2)
        self.assertEqual([0, 2, 4, 6, 8], list(n1))
        self.assertEqual([1, 3, 5, 7, 9], list(n2))

        # Functional Test: demux and mux work sequentially as expected
        numbers = NumbersDataset(10)
        n1, n2, n3 = numbers.demux(3, lambda x: x % 3)
        n = n1.mux(n2, n3)
        self.assertEqual(list(range(9)), list(n))

        # Functional Test: Uneven DataPipes
        source_numbers = list(range(0, 10)) + [10, 12]
        numbers_dp = dp.iter.IterableWrapper(source_numbers)
        n1, n2 = numbers_dp.demux(2, lambda x: x % 2)
        self.assertEqual([0, 2, 4, 6, 8, 10, 12], list(n1))
        self.assertEqual([1, 3, 5, 7, 9], list(n2))
        n = n1.mux(n2)
        self.assertEqual(list(range(10)), list(n))

    @suppress_warnings  # Suppress warning for lambda fn
    def test_map_with_col_file_handle_datapipe(self):
        temp_dir = self.temp_dir.name
        datapipe1 = dp.iter.FileLister(temp_dir, "")
        datapipe2 = dp.iter.FileOpener(datapipe1)

        def _helper(datapipe):
            dp1 = datapipe.map(lambda x: x.read(), input_col=1)
            dp2 = datapipe.map(lambda x: (x[0], x[1].read()))
            self.assertEqual(list(dp1), list(dp2))

        # tuple
        _helper(datapipe2)
        # list
        datapipe3 = datapipe2.map(lambda x: list(x))
        _helper(datapipe3)


@skipIfNoDataFrames
class TestCaptureDataFrame(TestCase):
    def get_new_df(self):
        return df_wrapper.create_dataframe([[1, 2]], columns=["a", "b"])

    def compare_capture_and_eager(self, operations):
        cdf = CaptureDataFrame()
        cdf = operations(cdf)
        df = self.get_new_df()
        cdf = cdf.apply_ops(df)

        df = self.get_new_df()
        df = operations(df)

        self.assertTrue(df.equals(cdf))

    def test_basic_capture(self):
        def operations(df):
            df["c"] = df.b + df["a"] * 7
            # somehow swallows pandas UserWarning when `df.c = df.b + df['a'] * 7`
            return df

        self.compare_capture_and_eager(operations)


class TestDataFramesPipes(TestCase):
    """
    Most of these tests will fail if pandas is installed but dill is not available.
    Need to rework them to avoid multiple skips.
    """

    def _get_datapipe(self, range=10, dataframe_size=7):
        return NumbersDataset(range).map(lambda i: (i, i % 3))

    def _get_dataframes_pipe(self, range=10, dataframe_size=7):
        return (
            NumbersDataset(range)
            .map(lambda i: (i, i % 3))
            ._to_dataframes_pipe(columns=["i", "j"], dataframe_size=dataframe_size)
        )

    @skipIfNoDataFrames
    @skipIfNoDill  # TODO(VitalyFedyunin): Decouple tests from dill by avoiding lambdas in map
    def test_capture(self):
        dp_numbers = self._get_datapipe().map(lambda x: (x[0], x[1], x[1] + 3 * x[0]))
        df_numbers = self._get_dataframes_pipe()
        df_numbers["k"] = df_numbers["j"] + df_numbers.i * 3
        expected = list(dp_numbers)
        actual = list(df_numbers)
        self.assertEqual(expected, actual)

    @skipIfNoDataFrames
    @skipIfNoDill
    def test_shuffle(self):
        # With non-zero (but extremely low) probability (when shuffle does nothing),
        # this test fails, so feel free to restart
        df_numbers = self._get_dataframes_pipe(range=1000).shuffle()
        dp_numbers = self._get_datapipe(range=1000)
        df_result = [tuple(item) for item in df_numbers]
        self.assertNotEqual(list(dp_numbers), df_result)
        self.assertEqual(list(dp_numbers), sorted(df_result))

    @skipIfNoDataFrames
    @skipIfNoDill
    def test_batch(self):
        df_numbers = self._get_dataframes_pipe(range=100).batch(8)
        df_numbers_list = list(df_numbers)
        last_batch = df_numbers_list[-1]
        self.assertEqual(4, len(last_batch))
        unpacked_batch = [tuple(row) for row in last_batch]
        self.assertEqual([(96, 0), (97, 1), (98, 2), (99, 0)], unpacked_batch)

    @skipIfNoDataFrames
    @skipIfNoDill
    def test_unbatch(self):
        df_numbers = self._get_dataframes_pipe(range=100).batch(8).batch(3)
        dp_numbers = self._get_datapipe(range=100)
        self.assertEqual(list(dp_numbers), list(df_numbers.unbatch(2)))

    @skipIfNoDataFrames
    @skipIfNoDill
    def test_filter(self):
        df_numbers = self._get_dataframes_pipe(range=10).filter(lambda x: x.i > 5)
        actual = list(df_numbers)
        self.assertEqual([(6, 0), (7, 1), (8, 2), (9, 0)], actual)

    @skipIfNoDataFrames
    @skipIfNoDill
    def test_collate(self):
        def collate_i(column):
            return column.sum()

        def collate_j(column):
            return column.prod()

        df_numbers = self._get_dataframes_pipe(range=30).batch(3)
        df_numbers = df_numbers.collate({"j": collate_j, "i": collate_i})

        expected_i = [
            3,
            12,
            21,
            30,
            39,
            48,
            57,
            66,
            75,
            84,
        ]

        actual_i = []
        for i, j in df_numbers:
            actual_i.append(i)
        self.assertEqual(expected_i, actual_i)

        actual_i = []
        for item in df_numbers:
            actual_i.append(item.i)
        self.assertEqual(expected_i, actual_i)
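
# NOTE: `IDP_NoLen` below intentionally omits `__len__`; several tests rely on this
# to check that length-dependent DataPipes (e.g. `concat`, `mux`, `collate`, `batch`)
# raise TypeError when their source has no valid length.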


class IDP_NoLen(IterDataPipe):
    def __init__(self, input_dp):
        super().__init__()
        self.input_dp = input_dp

    # Prevent in-place modification
    def __iter__(self):
        input_dp = (
            self.input_dp
            if isinstance(self.input_dp, IterDataPipe)
            else copy.deepcopy(self.input_dp)
        )
        yield from input_dp


def _fake_fn(data):
    return data


def _fake_add(constant, data):
    return constant + data


def _fake_filter_fn(data):
    return True


def _simple_filter_fn(data):
    return data >= 5


def _fake_filter_fn_constant(constant, data):
    return data >= constant


def _mul_10(x):
    return x * 10


def _mod_3_test(x):
    return x % 3 == 1


def _to_list(x):
    return [x]


lambda_fn1 = lambda x: x  # noqa: E731
lambda_fn2 = lambda x: x % 2  # noqa: E731
lambda_fn3 = lambda x: x >= 5  # noqa: E731


class Add1Module(nn.Module):
    def forward(self, x):
        return x + 1


class Add1Callable:
    def __call__(self, x):
        return x + 1


class TestFunctionalIterDataPipe(TestCase):
    def _serialization_test_helper(self, datapipe, use_dill):
        if use_dill:
            serialized_dp = dill.dumps(datapipe)
            deserialized_dp = dill.loads(serialized_dp)
        else:
            serialized_dp = pickle.dumps(datapipe)
            deserialized_dp = pickle.loads(serialized_dp)
        try:
            self.assertEqual(list(datapipe), list(deserialized_dp))
        except AssertionError as e:
            print(f"{datapipe} is failing.")
            raise e

    def _serialization_test_for_single_dp(self, dp, use_dill=False):
        # 1. Testing for serialization before any iteration starts
        self._serialization_test_helper(dp, use_dill)
        # 2. Testing for serialization after DataPipe is partially read
        it = iter(dp)
        _ = next(it)
        self._serialization_test_helper(dp, use_dill)
        # 3. Testing for serialization after DataPipe is fully read
        it = iter(dp)
        _ = list(it)
        self._serialization_test_helper(dp, use_dill)

    def _serialization_test_for_dp_with_children(self, dp1, dp2, use_dill=False):
        # 1. Testing for serialization before any iteration starts
        self._serialization_test_helper(dp1, use_dill)
        self._serialization_test_helper(dp2, use_dill)

        # 2. Testing for serialization after DataPipe is partially read
        it1, it2 = iter(dp1), iter(dp2)
        _, _ = next(it1), next(it2)
        # Catch `fork`, `demux` "some child DataPipes are not exhausted" warning
        with warnings.catch_warnings(record=True) as wa:
            self._serialization_test_helper(dp1, use_dill)
            self._serialization_test_helper(dp2, use_dill)

        # 2.5. Testing for serialization after one child DataPipe is fully read
        # (Only for DataPipes with children DataPipes)
        it1 = iter(dp1)
        _ = list(it1)  # fully read one child
        # Catch `fork`, `demux` "some child DataPipes are not exhausted" warning
        with warnings.catch_warnings(record=True) as wa:
            self._serialization_test_helper(dp1, use_dill)
            self._serialization_test_helper(dp2, use_dill)

        # 3. Testing for serialization after DataPipe is fully read
        it2 = iter(dp2)
        _ = list(it2)  # fully read the other child
        self._serialization_test_helper(dp1, use_dill)
        self._serialization_test_helper(dp2, use_dill)

    def test_serializable(self):
        picklable_datapipes: List = [
            (
                dp.iter.Batcher,
                None,
                (
                    3,
                    True,
                ),
                {},
            ),
            (dp.iter.Collator, None, (_fake_fn,), {}),
            (dp.iter.Concater, None, (dp.iter.IterableWrapper(range(5)),), {}),
            (dp.iter.Demultiplexer, None, (2, _simple_filter_fn), {}),
            (dp.iter.FileLister, ".", (), {}),
            (dp.iter.FileOpener, None, (), {}),
            (dp.iter.Filter, None, (_fake_filter_fn,), {}),
            (dp.iter.Filter, None, (partial(_fake_filter_fn_constant, 5),), {}),
            (dp.iter.Forker, None, (2,), {}),
            (dp.iter.Forker, None, (2,), {"copy": "shallow"}),
            (dp.iter.Grouper, None, (_fake_filter_fn,), {"group_size": 2}),
            (dp.iter.IterableWrapper, range(10), (), {}),
            (dp.iter.Mapper, None, (_fake_fn,), {}),
            (dp.iter.Mapper, None, (partial(_fake_add, 1),), {}),
            (dp.iter.Multiplexer, None, (dp.iter.IterableWrapper(range(10)),), {}),
            (dp.iter.Sampler, None, (), {}),
            (dp.iter.Shuffler, dp.iter.IterableWrapper([0] * 10), (), {}),
            (dp.iter.StreamReader, None, (), {}),
            (dp.iter.UnBatcher, None, (0,), {}),
            (dp.iter.Zipper, None, (dp.iter.IterableWrapper(range(10)),), {}),
        ]
        # Skipping comparison for these DataPipes
        dp_skip_comparison = {dp.iter.FileOpener, dp.iter.StreamReader}
        # These DataPipes produce multiple DataPipes as outputs and those should be compared
        dp_compare_children = {dp.iter.Demultiplexer, dp.iter.Forker}

        for dpipe, custom_input, dp_args, dp_kwargs in picklable_datapipes:
            if custom_input is None:
                custom_input = dp.iter.IterableWrapper(range(10))
            if (
                dpipe in dp_skip_comparison
            ):  # Merely make sure they are picklable and loadable (no value comparison)
                datapipe = dpipe(custom_input, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                serialized_dp = pickle.dumps(datapipe)
                _ = pickle.loads(serialized_dp)
            elif dpipe in dp_compare_children:  # DataPipes that have children
                dp1, dp2 = dpipe(custom_input, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                self._serialization_test_for_dp_with_children(dp1, dp2)
            else:  # Single DataPipe that requires comparison
                datapipe = dpipe(custom_input, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                self._serialization_test_for_single_dp(datapipe)

    @skipIfTorchDynamo("Dict with function as keys")
    def test_serializable_with_dill(self):
        """Only for DataPipes that take in a function as argument"""
        input_dp = dp.iter.IterableWrapper(range(10))

        datapipes_with_lambda_fn: List[
            Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]
        ] = [
            (dp.iter.Collator, (lambda_fn1,), {}),
            (
                dp.iter.Demultiplexer,
                (
                    2,
                    lambda_fn2,
                ),
                {},
            ),
            (dp.iter.Filter, (lambda_fn3,), {}),
            (dp.iter.Grouper, (lambda_fn3,), {}),
            (dp.iter.Mapper, (lambda_fn1,), {}),
        ]

        def _local_fns():
            def _fn1(x):
                return x

            def _fn2(x):
                return x % 2

            def _fn3(x):
                return x >= 5

            return _fn1, _fn2, _fn3

        fn1, fn2, fn3 = _local_fns()

        datapipes_with_local_fn: List[
            Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]
        ] = [
            (dp.iter.Collator, (fn1,), {}),
            (
                dp.iter.Demultiplexer,
                (
                    2,
                    fn2,
                ),
                {},
            ),
            (dp.iter.Filter, (fn3,), {}),
            (dp.iter.Grouper, (fn3,), {}),
            (dp.iter.Mapper, (fn1,), {}),
        ]

        dp_compare_children = {dp.iter.Demultiplexer}

        if HAS_DILL:
            for dpipe, dp_args, dp_kwargs in (
                datapipes_with_lambda_fn + datapipes_with_local_fn
            ):
                if dpipe in dp_compare_children:
                    dp1, dp2 = dpipe(input_dp, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                    self._serialization_test_for_dp_with_children(
                        dp1, dp2, use_dill=True
                    )
                else:
                    datapipe = dpipe(input_dp, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                    self._serialization_test_for_single_dp(datapipe, use_dill=True)
        else:
            msgs = (
                r"^Lambda function is not supported by pickle",
                r"^Local function is not supported by pickle",
            )
            for dps, msg in zip(
                (datapipes_with_lambda_fn, datapipes_with_local_fn), msgs
            ):
                for dpipe, dp_args, dp_kwargs in dps:
                    with self.assertWarnsRegex(UserWarning, msg):
                        datapipe = dpipe(input_dp, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                    with self.assertRaises((pickle.PicklingError, AttributeError)):
                        pickle.dumps(datapipe)

    def test_docstring(self):
        """
        Ensure functional form of IterDataPipe has the correct docstring from
        the class form.

        Regression test for https://github.com/pytorch/data/issues/792.
        """
        input_dp = dp.iter.IterableWrapper(range(10))

        for dp_funcname in [
            "batch",
            "collate",
            "concat",
            "demux",
            "filter",
            "fork",
            "map",
            "mux",
            "read_from_stream",
            # "sampler",
            "shuffle",
            "unbatch",
            "zip",
        ]:
            if sys.version_info >= (3, 9):
                docstring = pydoc.render_doc(
                    thing=getattr(input_dp, dp_funcname), forceload=True
                )
            elif sys.version_info < (3, 9):
                # pydoc works differently on Python 3.8, see
                # https://docs.python.org/3/whatsnew/3.9.html#pydoc
                docstring = getattr(input_dp, dp_funcname).__doc__

            assert f"(functional name: ``{dp_funcname}``)" in docstring
            assert "Args:" in docstring
            assert "Example:" in docstring or "Examples:" in docstring

    def test_iterable_wrapper_datapipe(self):
        input_ls = list(range(10))
        input_dp = dp.iter.IterableWrapper(input_ls)

        # Functional Test: values are unchanged and in the same order
        self.assertEqual(input_ls, list(input_dp))

        # Functional Test: deep copy by default when an iterator is initialized (first element is read)
        it = iter(input_dp)
        self.assertEqual(
            0, next(it)
        )  # The deep copy only happens when the first element is read
        input_ls.append(50)
        self.assertEqual(list(range(1, 10)), list(it))

        # Functional Test: shallow copy
        input_ls2 = [1, 2, 3]
        input_dp_shallow = dp.iter.IterableWrapper(input_ls2, deepcopy=False)
        input_ls2.append(10)
        self.assertEqual([1, 2, 3, 10], list(input_dp_shallow))

        # Reset Test: reset the DataPipe
        input_ls = list(range(10))
        input_dp = dp.iter.IterableWrapper(input_ls)
        n_elements_before_reset = 5
        res_before_reset, res_after_reset = reset_after_n_next_calls(
            input_dp, n_elements_before_reset
        )
        self.assertEqual(input_ls[:n_elements_before_reset], res_before_reset)
        self.assertEqual(input_ls, res_after_reset)

        # __len__ Test: inherits length from sequence
        self.assertEqual(len(input_ls), len(input_dp))

    def test_concat_iterdatapipe(self):
        input_dp1 = dp.iter.IterableWrapper(range(10))
        input_dp2 = dp.iter.IterableWrapper(range(5))

        # Functional Test: Raises exception for empty input
        with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
            dp.iter.Concater()

        # Functional Test: Raises exception for non-IterDataPipe input
        with self.assertRaisesRegex(
            TypeError, r"Expected all inputs to be `IterDataPipe`"
        ):
            dp.iter.Concater(input_dp1, ())  # type: ignore[arg-type]

        # Functional Test: Concatenate DataPipes as expected
        concat_dp = input_dp1.concat(input_dp2)
        self.assertEqual(len(concat_dp), 15)
        self.assertEqual(list(concat_dp), list(range(10)) + list(range(5)))

        # Reset Test: reset the DataPipe
        n_elements_before_reset = 5
        res_before_reset, res_after_reset = reset_after_n_next_calls(
            concat_dp, n_elements_before_reset
        )
        self.assertEqual(list(range(5)), res_before_reset)
        self.assertEqual(list(range(10)) + list(range(5)), res_after_reset)

        # __len__ Test: inherits length from source DataPipe
        input_dp_nl = IDP_NoLen(range(5))
        concat_dp = input_dp1.concat(input_dp_nl)
        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
            len(concat_dp)

        self.assertEqual(list(concat_dp), list(range(10)) + list(range(5)))

    def test_fork_iterdatapipe(self):
        input_dp = dp.iter.IterableWrapper(range(10))

        with self.assertRaises(ValueError):
            input_dp.fork(num_instances=0)

        dp0 = input_dp.fork(num_instances=1, buffer_size=0)
        self.assertEqual(dp0, input_dp)

        # Functional Test: making sure all child DataPipes share the same reference
        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
        self.assertTrue(all(n1 is n2 and n1 is n3 for n1, n2, n3 in zip(dp1, dp2, dp3)))

        # Functional Test: one child DataPipe yields all values at a time
        output1, output2, output3 = list(dp1), list(dp2), list(dp3)
        self.assertEqual(list(range(10)), output1)
        self.assertEqual(list(range(10)), output2)
        self.assertEqual(list(range(10)), output3)

        # Functional Test: two child DataPipes yield values together
        dp1, dp2 = input_dp.fork(num_instances=2)
        output = []
        for n1, n2 in zip(dp1, dp2):
            output.append((n1, n2))
        self.assertEqual([(i, i) for i in range(10)], output)

        # Functional Test: one child DataPipe yields all values first, but buffer_size = 5 is too small
        dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=4)
        it1 = iter(dp1)
        for _ in range(4):
            next(it1)
        with self.assertRaises(BufferError):
            next(it1)
        with self.assertRaises(BufferError):
            list(dp2)

        dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=5)
        with self.assertRaises(BufferError):
            list(dp2)

        # Functional Test: one child DataPipe yields all values first with unlimited buffer
        with warnings.catch_warnings(record=True) as wa:
            dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=-1)
            self.assertEqual(len(wa), 1)
            self.assertRegex(str(wa[0].message), r"Unlimited buffer size is set")
        l1, l2 = list(dp1), list(dp2)
        for d1, d2 in zip(l1, l2):
            self.assertEqual(d1, d2)

        # Functional Test: two child DataPipes yield values together with buffer size 1
        dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=1)
        output = []
        for n1, n2 in zip(dp1, dp2):
            output.append((n1, n2))
        self.assertEqual([(i, i) for i in range(10)], output)

        # Functional Test: two child DataPipes yield shallow copies with copy equals shallow
        dp1, dp2 = input_dp.map(_to_list).fork(num_instances=2, copy="shallow")
        for n1, n2 in zip(dp1, dp2):
            self.assertIsNot(n1, n2)
            self.assertEqual(n1, n2)

        # Functional Test: two child DataPipes yield deep copies with copy equals deep
        dp1, dp2 = (
            input_dp.map(_to_list).map(_to_list).fork(num_instances=2, copy="deep")
        )
        for n1, n2 in zip(dp1, dp2):
            self.assertIsNot(n1[0], n2[0])
            self.assertEqual(n1, n2)

        # Functional Test: fork DataPipe raises error for unknown copy method
        with self.assertRaises(ValueError):
            input_dp.fork(num_instances=2, copy="unknown")

        # Functional Test: make sure logic related to slowest_ptr is working properly
        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
        output1, output2, output3 = [], [], []
        for i, (n1, n2) in enumerate(zip(dp1, dp2)):
            output1.append(n1)
            output2.append(n2)
            if i == 4:  # yield all of dp3 when halfway through dp1, dp2
                output3 = list(dp3)
                break
        self.assertEqual(list(range(5)), output1)
        self.assertEqual(list(range(5)), output2)
        self.assertEqual(list(range(10)), output3)

        # Reset Test: DataPipe resets when a new iterator is created, even if this datapipe hasn't been read
        dp1, dp2 = input_dp.fork(num_instances=2)
        _ = iter(dp1)
        output2 = []
        with self.assertRaisesRegex(RuntimeError, r"iterator has been invalidated"):
            for i, n2 in enumerate(dp2):
                output2.append(n2)
                if i == 4:
                    with warnings.catch_warnings(record=True) as wa:
                        _ = iter(dp1)  # This will reset all child DataPipes
                        self.assertEqual(len(wa), 1)
                        self.assertRegex(
                            str(wa[0].message), r"child DataPipes are not exhausted"
                        )
        self.assertEqual(list(range(5)), output2)

        # Reset Test: DataPipe resets when some of it has been read
        dp1, dp2 = input_dp.fork(num_instances=2)
        output1, output2 = [], []
        for i, (n1, n2) in enumerate(zip(dp1, dp2)):
            output1.append(n1)
            output2.append(n2)
            if i == 4:
                with warnings.catch_warnings(record=True) as wa:
                    _ = iter(dp1)  # Reset all child DataPipes
                    self.assertEqual(len(wa), 1)
                    self.assertRegex(
                        str(wa[0].message), r"Some child DataPipes are not exhausted"
                    )
                break
        with warnings.catch_warnings(record=True) as wa:
            for i, (n1, n2) in enumerate(zip(dp1, dp2)):
                output1.append(n1)
                output2.append(n2)
        self.assertEqual(len(wa), 1)
        self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")
        self.assertEqual(list(range(5)) + list(range(10)), output1)
        self.assertEqual(list(range(5)) + list(range(10)), output2)

        # Reset Test: DataPipe reset, even when some other child DataPipes are not read
        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
        output1, output2 = list(dp1), list(dp2)
        self.assertEqual(list(range(10)), output1)
        self.assertEqual(list(range(10)), output2)
        with warnings.catch_warnings(record=True) as wa:
            self.assertEqual(
                list(range(10)), list(dp1)
            )  # Resets even though dp3 has not been read
            self.assertEqual(len(wa), 1)
            self.assertRegex(
                str(wa[0].message), r"Some child DataPipes are not exhausted"
            )
        output3 = []
        for i, n3 in enumerate(dp3):
            output3.append(n3)
            if i == 4:
                with warnings.catch_warnings(record=True) as wa:
                    output1 = list(dp1)  # Resets even though dp3 is only partially read
                    self.assertEqual(len(wa), 1)
                    self.assertRegex(
                        str(wa[0].message), r"Some child DataPipes are not exhausted"
                    )
                self.assertEqual(list(range(5)), output3)
                self.assertEqual(list(range(10)), output1)
                break
        self.assertEqual(
            list(range(10)), list(dp3)
        )  # dp3 has to read from the start again

        # __len__ Test: Each DataPipe inherits the source datapipe's length
        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
        self.assertEqual(len(input_dp), len(dp1))
        self.assertEqual(len(input_dp), len(dp2))
        self.assertEqual(len(input_dp), len(dp3))

        # Pickle Test:
        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
        traverse_dps(dp1)  # This should not raise any error
        for _ in zip(dp1, dp2, dp3):
            pass
        traverse_dps(dp2)  # This should not raise any error either

    def test_mux_iterdatapipe(self):
        # Functional Test: Elements are yielded one at a time from each DataPipe, until they are all exhausted
        input_dp1 = dp.iter.IterableWrapper(range(4))
        input_dp2 = dp.iter.IterableWrapper(range(4, 8))
        input_dp3 = dp.iter.IterableWrapper(range(8, 12))
        output_dp = input_dp1.mux(input_dp2, input_dp3)
        expected_output = [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
        self.assertEqual(len(expected_output), len(output_dp))
        self.assertEqual(expected_output, list(output_dp))

        # Functional Test: Uneven input Data Pipes
        input_dp1 = dp.iter.IterableWrapper([1, 2, 3, 4])
        input_dp2 = dp.iter.IterableWrapper([10])
        input_dp3 = dp.iter.IterableWrapper([100, 200, 300])
        output_dp = input_dp1.mux(input_dp2, input_dp3)
        expected_output = [1, 10, 100]
        self.assertEqual(len(expected_output), len(output_dp))
        self.assertEqual(expected_output, list(output_dp))

        # Functional Test: Empty Data Pipe
        input_dp1 = dp.iter.IterableWrapper([0, 1, 2, 3])
        input_dp2 = dp.iter.IterableWrapper([])
        output_dp = input_dp1.mux(input_dp2)
        self.assertEqual(len(input_dp2), len(output_dp))
        self.assertEqual(list(input_dp2), list(output_dp))

        # __len__ Test: raises TypeError when __len__ is called and an input doesn't have __len__
        input_dp1 = dp.iter.IterableWrapper(range(10))
        input_dp_no_len = IDP_NoLen(range(10))
        output_dp = input_dp1.mux(input_dp_no_len)
        with self.assertRaises(TypeError):
            len(output_dp)

    def test_demux_iterdatapipe(self):
        input_dp = dp.iter.IterableWrapper(range(10))

        with self.assertRaises(ValueError):
            input_dp.demux(num_instances=0, classifier_fn=lambda x: 0)

        # Functional Test: split into 2 DataPipes and output them one at a time
        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
        output1, output2 = list(dp1), list(dp2)
        self.assertEqual(list(range(0, 10, 2)), output1)
        self.assertEqual(list(range(1, 10, 2)), output2)

        # Functional Test: split into 2 DataPipes and output them together
        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
        output = []
        for n1, n2 in zip(dp1, dp2):
            output.append((n1, n2))
        self.assertEqual([(i, i + 1) for i in range(0, 10, 2)], output)

        # Functional Test: values of the same classification are lumped together, and buffer_size = 4 is too small
        dp1, dp2 = input_dp.demux(
            num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=4
        )
        it1 = iter(dp1)
        with self.assertRaises(BufferError):
            next(
                it1
            )  # Buffer raises because first 5 elements all belong to a different child
        with self.assertRaises(BufferError):
            list(dp2)

        # Functional Test: values of the same classification are lumped together, and buffer_size = 5 is just enough
        dp1, dp2 = input_dp.demux(
            num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=5
        )
        output1, output2 = list(dp1), list(dp2)
        self.assertEqual(list(range(5, 10)), output1)
        self.assertEqual(list(range(0, 5)), output2)

        # Functional Test: values of the same classification are lumped together, and unlimited buffer
        with warnings.catch_warnings(record=True) as wa:
            dp1, dp2 = input_dp.demux(
                num_instances=2,
                classifier_fn=lambda x: 0 if x >= 5 else 1,
                buffer_size=-1,
            )
            exp_l = 1 if HAS_DILL else 2
            self.assertEqual(len(wa), exp_l)
            self.assertRegex(str(wa[-1].message), r"Unlimited buffer size is set")
        output1, output2 = list(dp1), list(dp2)
        self.assertEqual(list(range(5, 10)), output1)
        self.assertEqual(list(range(0, 5)), output2)

        # Functional Test: classifier returns a value outside of [0, num_instance - 1]
        dp0 = input_dp.demux(num_instances=1, classifier_fn=lambda x: x % 2)
        it = iter(dp0[0])
        with self.assertRaises(ValueError):
            next(it)
            next(it)

        # Reset Test: DataPipe resets when a new iterator is created, even if this datapipe hasn't been read
        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
        _ = iter(dp1)
        output2 = []
        with self.assertRaisesRegex(RuntimeError, r"iterator has been invalidated"):
            for i, n2 in enumerate(dp2):
                output2.append(n2)
                if i == 4:
                    with warnings.catch_warnings(record=True) as wa:
                        _ = iter(dp1)  # This will reset all child DataPipes
                        self.assertEqual(len(wa), 1)
                        self.assertRegex(
                            str(wa[0].message), r"child DataPipes are not exhausted"
                        )
        self.assertEqual(list(range(1, 10, 2)), output2)

        # Reset Test: DataPipe resets when some of it has been read
        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
        output1, output2 = [], []
        for n1, n2 in zip(dp1, dp2):
            output1.append(n1)
            output2.append(n2)
            if n1 == 4:
                break
        with warnings.catch_warnings(record=True) as wa:
            i1 = iter(dp1)  # Reset all child DataPipes
            self.assertEqual(len(wa), 1)
            self.assertRegex(
                str(wa[0].message), r"Some child DataPipes are not exhausted"
            )
        for n1, n2 in zip(dp1, dp2):
            output1.append(n1)
            output2.append(n2)
        self.assertEqual([0, 2, 4] + list(range(0, 10, 2)), output1)
        self.assertEqual([1, 3, 5] + list(range(1, 10, 2)), output2)
        self.assertEqual(len(wa), 1)
        self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")

        # Reset Test: DataPipe reset, even when not all child DataPipes are exhausted
        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
        output1 = list(dp1)
        self.assertEqual(list(range(0, 10, 2)), output1)
        with warnings.catch_warnings(record=True) as wa:
            self.assertEqual(
                list(range(0, 10, 2)), list(dp1)
            )  # Reset even when dp2 is not read
            self.assertEqual(len(wa), 1)
            self.assertRegex(
                str(wa[0].message), r"Some child DataPipes are not exhausted"
            )
        output2 = []
        for i, n2 in enumerate(dp2):
            output2.append(n2)
            if i == 1:
                self.assertEqual(list(range(1, 5, 2)), output2)
                with warnings.catch_warnings(record=True) as wa:
                    self.assertEqual(
                        list(range(0, 10, 2)), list(dp1)
                    )  # Can reset even when dp2 is partially read
                    self.assertEqual(len(wa), 1)
                    self.assertRegex(
                        str(wa[0].message), r"Some child DataPipes are not exhausted"
                    )
                break
        output2 = list(dp2)  # output2 has to read from the beginning again
        self.assertEqual(list(range(1, 10, 2)), output2)

        # Functional Test: drop_none = True
        dp1, dp2 = input_dp.demux(
            num_instances=2,
            classifier_fn=lambda x: x % 2 if x % 5 != 0 else None,
            drop_none=True,
        )
        self.assertEqual([2, 4, 6, 8], list(dp1))
        self.assertEqual([1, 3, 7, 9], list(dp2))

        # Functional Test: drop_none = False
        dp1, dp2 = input_dp.demux(
            num_instances=2,
            classifier_fn=lambda x: x % 2 if x % 5 != 0 else None,
            drop_none=False,
        )
        it1 = iter(dp1)
        with self.assertRaises(ValueError):
            next(it1)

        # __len__ Test: __len__ not implemented
        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
        with self.assertRaises(TypeError):
            len(
                dp1
            )  # It is not implemented as we do not know length for each child in advance
        with self.assertRaises(TypeError):
            len(dp2)

        # Pickle Test:
        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=odd_or_even)
        traverse_dps(dp1)  # This should not raise any error
        for _ in zip(dp1, dp2):
            pass
        traverse_dps(dp2)  # This should not raise any error either

    def test_map_iterdatapipe(self):
        target_length = 10
        input_dp = dp.iter.IterableWrapper(range(target_length))

        def fn(item, dtype=torch.float, *, sum=False):
            data = torch.tensor(item, dtype=dtype)
            return data if not sum else data.sum()

        # Functional Test: apply to each element correctly
        map_dp = input_dp.map(fn)
        self.assertEqual(target_length, len(map_dp))
        for x, y in zip(map_dp, range(target_length)):
            self.assertEqual(x, torch.tensor(y, dtype=torch.float))

        # Functional Test: works with partial function
        map_dp = input_dp.map(partial(fn, dtype=torch.int, sum=True))
        for x, y in zip(map_dp, range(target_length)):
            self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())

        # __len__ Test: inherits length from source DataPipe
        self.assertEqual(target_length, len(map_dp))

        input_dp_nl = IDP_NoLen(range(target_length))
        map_dp_nl = input_dp_nl.map(lambda x: x)
        for x, y in zip(map_dp_nl, range(target_length)):
            self.assertEqual(x, torch.tensor(y, dtype=torch.float))

        # __len__ Test: inherits length from source DataPipe - raises error when invalid
        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
            len(map_dp_nl)

        # Reset Test: DataPipe resets properly
        n_elements_before_reset = 5
        res_before_reset, res_after_reset = reset_after_n_next_calls(
            map_dp, n_elements_before_reset
        )
        self.assertEqual(list(range(n_elements_before_reset)), res_before_reset)
        self.assertEqual(list(range(10)), res_after_reset)

    @suppress_warnings  # Suppress warning for lambda fn
    def test_map_tuple_list_with_col_iterdatapipe(self):
        def fn_11(d):
            return -d

        def fn_1n(d):
            return -d, d

        def fn_n1(d0, d1):
            return d0 + d1

        def fn_nn(d0, d1):
            return -d0, -d1, d0 + d1

        def fn_n1_def(d0, d1=1):
            return d0 + d1

        def fn_n1_kwargs(d0, d1, **kwargs):
            return d0 + d1

        def fn_n1_pos(d0, d1, *args):
            return d0 + d1

        def fn_n1_sep_pos(d0, *args, d1):
            return d0 + d1

        def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
            return d0 + d1

        p_fn_n1 = partial(fn_n1, d1=1)
        p_fn_cmplx = partial(fn_cmplx, d2=2)
        p_fn_cmplx_large_arg = partial(
            fn_cmplx, d2={i: list(range(i)) for i in range(10_000)}
        )

        def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
            for constr in (list, tuple):
                datapipe = dp.iter.IterableWrapper(
                    [constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))]
                )
                if ref_fn is None:
                    with self.assertRaises(error):
                        res_dp = datapipe.map(fn, input_col, output_col)
                        list(res_dp)
                else:
                    res_dp = datapipe.map(fn, input_col, output_col)
                    ref_dp = datapipe.map(ref_fn)
                    self.assertEqual(list(res_dp), list(ref_dp))
                    # Reset
                    self.assertEqual(list(res_dp), list(ref_dp))

        _helper(lambda data: data, fn_n1_def, 0, 1)
        _helper(
            lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2
        )
        _helper(lambda data: data, p_fn_n1, 0, 1)
        _helper(lambda data: data, p_fn_cmplx, 0, 1)
        _helper(lambda data: data, p_fn_cmplx_large_arg, 0, 1)
        _helper(
            lambda data: (data[0], data[1], data[0] + data[1]), p_fn_cmplx, [0, 1], 2
        )
        _helper(lambda data: (data[0] + data[1],), fn_n1_pos, [0, 1, 2])

        # Replacing with one input column and default output column
        _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
        _helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1)
        # The index of input column is out of range
        _helper(None, fn_1n, 3, error=IndexError)
        # Unmatched input columns with fn arguments
        _helper(None, fn_n1, 1, error=ValueError)
        _helper(None, fn_n1, [0, 1, 2], error=ValueError)
        _helper(None, operator.add, 0, error=ValueError)
        _helper(None, operator.add, [0, 1, 2], error=ValueError)
        _helper(None, fn_cmplx, 0, 1, ValueError)
        _helper(None, fn_n1_pos, 1, error=ValueError)
        _helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError)
        _helper(None, p_fn_n1, [0, 1], error=ValueError)
        _helper(None, fn_1n, [1, 2], error=ValueError)
        # _helper(None, p_fn_cmplx, [0, 1, 2], error=ValueError)
        _helper(None, fn_n1_sep_pos, [0, 1, 2], error=ValueError)
        # Fn has keyword-only arguments
        _helper(None, fn_n1_kwargs, 1, error=ValueError)
        _helper(None, fn_cmplx, [0, 1], 2, ValueError)

        # Replacing with multiple input columns and default output column (the left-most input column)
        _helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
        _helper(
            lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])),
            fn_nn,
            [2, 1],
        )

        # output_col can only be specified when input_col is not None
        _helper(None, fn_n1, None, 1, error=ValueError)
        # output_col can only be single-element list or tuple
        _helper(None, fn_n1, None, [0, 1], error=ValueError)
        # Single-element list as output_col
        _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
        # Replacing with one input column and single specified output column
        _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
        _helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
        # The index of output column is out of range
        _helper(None, fn_1n, 1, 3, error=IndexError)
        _helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
        _helper(
            lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]),
            fn_nn,
            [1, 2],
            0,
        )

        # Appending the output at the end
        _helper(lambda data: (*data, -data[1]), fn_11, 1, -1)
        _helper(lambda data: (*data, (-data[1], data[1])), fn_1n, 1, -1)
        _helper(lambda data: (*data, data[0] + data[2]), fn_n1, [0, 2], -1)
        _helper(
            lambda data: (*data, (-data[1], -data[2], data[1] + data[2])),
            fn_nn,
            [1, 2],
            -1,
        )

        # Handling built-in functions (e.g. `dict`, `iter`, `int`, `str`) whose signatures cannot be inspected
        _helper(lambda data: (str(data[0]), data[1], data[2]), str, 0)
        _helper(lambda data: (data[0], data[1], int(data[2])), int, 2)

        # Handle nn.Module and Callable (without __name__ implemented)
        _helper(lambda data: (data[0] + 1, data[1], data[2]), Add1Module(), 0)
        _helper(lambda data: (data[0] + 1, data[1], data[2]), Add1Callable(), 0)

    @suppress_warnings  # Suppress warning for lambda fn
    @skipIfTorchDynamo()
    def test_map_dict_with_col_iterdatapipe(self):
        def fn_11(d):
            return -d

        def fn_1n(d):
            return -d, d

        def fn_n1(d0, d1):
            return d0 + d1

        def fn_nn(d0, d1):
            return -d0, -d1, d0 + d1

        def fn_n1_def(d0, d1=1):
            return d0 + d1

        p_fn_n1 = partial(fn_n1, d1=1)

        def fn_n1_pos(d0, d1, *args):
            return d0 + d1

        def fn_n1_kwargs(d0, d1, **kwargs):
            return d0 + d1

        def fn_kwonly(*, d0, d1):
            return d0 + d1

        def fn_has_nondefault_kwonly(d0, *, d1):
            return d0 + d1

        def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
            return d0 + d1

        p_fn_cmplx = partial(fn_cmplx, d2=2)
        p_fn_cmplx_large_arg = partial(
            fn_cmplx, d2={i: list(range(i)) for i in range(10_000)}
        )

        # Prevent modification in-place to support resetting
        def _dict_update(data, newdata, remove_idx=None):
            _data = dict(data)
            _data.update(newdata)
            if remove_idx:
                for idx in remove_idx:
                    del _data[idx]
            return _data

        def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
            datapipe = dp.iter.IterableWrapper(
                [
                    {"x": 0, "y": 1, "z": 2},
                    {"x": 3, "y": 4, "z": 5},
                    {"x": 6, "y": 7, "z": 8},
                ]
            )
            if ref_fn is None:
                with self.assertRaises(error):
                    res_dp = datapipe.map(fn, input_col, output_col)
                    list(res_dp)
            else:
                res_dp = datapipe.map(fn, input_col, output_col)
                ref_dp = datapipe.map(ref_fn)
                self.assertEqual(list(res_dp), list(ref_dp))
                # Reset
                self.assertEqual(list(res_dp), list(ref_dp))

        _helper(lambda data: data, fn_n1_def, "x", "y")
        _helper(lambda data: data, p_fn_n1, "x", "y")
        _helper(lambda data: data, p_fn_cmplx, "x", "y")
        _helper(lambda data: data, p_fn_cmplx_large_arg, "x", "y")
        _helper(
            lambda data: _dict_update(data, {"z": data["x"] + data["y"]}),
            p_fn_cmplx,
            ["x", "y", "z"],
            "z",
        )

        _helper(
            lambda data: _dict_update(data, {"z": data["x"] + data["y"]}),
            fn_n1_def,
            ["x", "y"],
            "z",
        )

        _helper(None, fn_n1_pos, "x", error=ValueError)
        _helper(None, fn_n1_kwargs, "x", error=ValueError)
        # non-default kw-only args
        _helper(None, fn_kwonly, ["x", "y"], error=ValueError)
        _helper(None, fn_has_nondefault_kwonly, ["x", "y"], error=ValueError)
        _helper(None, fn_cmplx, ["x", "y"], error=ValueError)

        # Replacing with one input column and default output column
        _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
        _helper(
            lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y"
        )
        # The key of input column is not in dict
        _helper(None, fn_1n, "a", error=KeyError)
        # Unmatched input columns with fn arguments
        _helper(None, fn_n1, "y", error=ValueError)
        _helper(None, fn_1n, ["x", "y"], error=ValueError)
        _helper(None, fn_n1_def, ["x", "y", "z"], error=ValueError)
        _helper(None, p_fn_n1, ["x", "y"], error=ValueError)
        _helper(None, fn_n1_kwargs, ["x", "y", "z"], error=ValueError)
        # Replacing with multiple input columns and default output column (the left-most input column)
        _helper(
            lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]),
            fn_n1,
            ["z", "x"],
        )
        _helper(
            lambda data: _dict_update(
                data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]
            ),
            fn_nn,
            ["z", "y"],
        )

        # output_col can only be specified when input_col is not None
        _helper(None, fn_n1, None, "x", error=ValueError)
        # output_col can only be single-element list or tuple
        _helper(None, fn_n1, None, ["x", "y"], error=ValueError)
        # Single-element list as output_col
        _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
        # Replacing with one input column and single specified output column
        _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", "x")
        _helper(
            lambda data: _dict_update(data, {"z": (-data["y"], data["y"])}),
            fn_1n,
            "y",
            "z",
        )
        _helper(
            lambda data: _dict_update(data, {"y": data["x"] + data["z"]}),
            fn_n1,
            ["x", "z"],
            "y",
        )
        _helper(
            lambda data: _dict_update(
                data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])}
            ),
            fn_nn,
            ["y", "z"],
            "x",
        )

        # Adding new key to dict for the output
        _helper(lambda data: _dict_update(data, {"a": -data["y"]}), fn_11, "y", "a")
        _helper(
            lambda data: _dict_update(data, {"a": (-data["y"], data["y"])}),
            fn_1n,
            "y",
            "a",
        )
        _helper(
            lambda data: _dict_update(data, {"a": data["x"] + data["z"]}),
            fn_n1,
            ["x", "z"],
            "a",
        )
        _helper(
            lambda data: _dict_update(
                data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])}
            ),
            fn_nn,
            ["y", "z"],
            "a",
        )

    def test_collate_iterdatapipe(self):
        arrs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
        input_dp = dp.iter.IterableWrapper(arrs)

        def _collate_fn(batch, default_type=torch.float):
            return torch.tensor(sum(batch), dtype=default_type)

        # Functional Test: defaults to the default collate function when a custom one is not specified
        collate_dp = input_dp.collate()
        for x, y in zip(arrs, collate_dp):
            self.assertEqual(torch.tensor(x), y)

        # Functional Test: custom collate function
        collate_dp = input_dp.collate(collate_fn=_collate_fn)
        for x, y in zip(arrs, collate_dp):
            self.assertEqual(torch.tensor(sum(x), dtype=torch.float), y)

        # Functional Test: custom, partial collate function
        collate_dp = input_dp.collate(partial(_collate_fn, default_type=torch.int))
collate_dp): 1770 self.assertEqual(torch.tensor(sum(x), dtype=torch.int), y) 1771 1772 # Reset Test: reset the DataPipe and results are still correct 1773 n_elements_before_reset = 1 1774 res_before_reset, res_after_reset = reset_after_n_next_calls( 1775 collate_dp, n_elements_before_reset 1776 ) 1777 self.assertEqual([torch.tensor(6, dtype=torch.int)], res_before_reset) 1778 for x, y in zip(arrs, res_after_reset): 1779 self.assertEqual(torch.tensor(sum(x), dtype=torch.int), y) 1780 1781 # __len__ Test: __len__ is inherited 1782 self.assertEqual(len(input_dp), len(collate_dp)) 1783 1784 # __len__ Test: verify that it has no valid __len__ when the source doesn't have it 1785 input_dp_nl = IDP_NoLen(arrs) 1786 collate_dp_nl = input_dp_nl.collate() 1787 with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): 1788 len(collate_dp_nl) 1789 for x, y in zip(arrs, collate_dp_nl): 1790 self.assertEqual(torch.tensor(x), y) 1791 1792 def test_batch_iterdatapipe(self): 1793 arrs = list(range(10)) 1794 input_dp = dp.iter.IterableWrapper(arrs) 1795 1796 # Functional Test: raise error when input argument `batch_size = 0` 1797 with self.assertRaises(AssertionError): 1798 input_dp.batch(batch_size=0) 1799 1800 # Functional Test: by default, do not drop the last batch 1801 bs = 3 1802 batch_dp = input_dp.batch(batch_size=bs) 1803 self.assertEqual(len(batch_dp), 4) 1804 for i, batch in enumerate(batch_dp): 1805 self.assertEqual(len(batch), 1 if i == 3 else bs) 1806 self.assertEqual(batch, arrs[i * bs : i * bs + len(batch)]) 1807 1808 # Functional Test: Drop the last batch when specified 1809 bs = 4 1810 batch_dp = input_dp.batch(batch_size=bs, drop_last=True) 1811 for i, batch in enumerate(batch_dp): 1812 self.assertEqual(batch, arrs[i * bs : i * bs + len(batch)]) 1813 1814 # __len__ test: verifying that the overall length and of each batch is correct 1815 for i, batch in enumerate(batch_dp): 1816 self.assertEqual(len(batch), bs) 1817 1818 # __len__ Test: the length is missing if the source DataPipe doesn't have length 1819 self.assertEqual(len(batch_dp), 2) 1820 input_dp_nl = IDP_NoLen(range(10)) 1821 batch_dp_nl = input_dp_nl.batch(batch_size=2) 1822 with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): 1823 len(batch_dp_nl) 1824 1825 # Reset Test: Ensures that the DataPipe can properly reset 1826 n_elements_before_reset = 1 1827 res_before_reset, res_after_reset = reset_after_n_next_calls( 1828 batch_dp, n_elements_before_reset 1829 ) 1830 self.assertEqual([[0, 1, 2, 3]], res_before_reset) 1831 self.assertEqual([[0, 1, 2, 3], [4, 5, 6, 7]], res_after_reset) 1832 1833 def test_unbatch_iterdatapipe(self): 1834 target_length = 6 1835 prebatch_dp = dp.iter.IterableWrapper(range(target_length)) 1836 1837 # Functional Test: Unbatch DataPipe should be the same as pre-batch DataPipe 1838 input_dp = prebatch_dp.batch(3) 1839 unbatch_dp = input_dp.unbatch() 1840 self.assertEqual(len(list(unbatch_dp)), target_length) # __len__ is as expected 1841 for i, res in zip(range(target_length), unbatch_dp): 1842 self.assertEqual(i, res) 1843 1844 # Functional Test: unbatch works for an input with nested levels 1845 input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]]) 1846 unbatch_dp = input_dp.unbatch() 1847 self.assertEqual(len(list(unbatch_dp)), target_length) 1848 for i, res in zip(range(target_length), unbatch_dp): 1849 self.assertEqual(i, res) 1850 1851 input_dp = dp.iter.IterableWrapper([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]) 1852 1853 # Functional Test: unbatch 
works for an input with nested levels 1854 unbatch_dp = input_dp.unbatch() 1855 expected_dp = [[0, 1], [2, 3], [4, 5], [6, 7]] 1856 self.assertEqual(len(list(unbatch_dp)), 4) 1857 for j, res in zip(expected_dp, unbatch_dp): 1858 self.assertEqual(j, res) 1859 1860 # Functional Test: unbatching multiple levels at the same time 1861 unbatch_dp = input_dp.unbatch(unbatch_level=2) 1862 expected_dp2 = [0, 1, 2, 3, 4, 5, 6, 7] 1863 self.assertEqual(len(list(unbatch_dp)), 8) 1864 for i, res in zip(expected_dp2, unbatch_dp): 1865 self.assertEqual(i, res) 1866 1867 # Functional Test: unbatching all levels at the same time 1868 unbatch_dp = input_dp.unbatch(unbatch_level=-1) 1869 self.assertEqual(len(list(unbatch_dp)), 8) 1870 for i, res in zip(expected_dp2, unbatch_dp): 1871 self.assertEqual(i, res) 1872 1873 # Functional Test: raises error when input unbatch_level is less than -1 1874 input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]]) 1875 with self.assertRaises(ValueError): 1876 unbatch_dp = input_dp.unbatch(unbatch_level=-2) 1877 for i in unbatch_dp: 1878 print(i) 1879 1880 # Functional Test: raises error when input unbatch_level is too high 1881 with self.assertRaises(IndexError): 1882 unbatch_dp = input_dp.unbatch(unbatch_level=5) 1883 for i in unbatch_dp: 1884 print(i) 1885 1886 # Reset Test: unbatch_dp resets properly 1887 input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]]) 1888 unbatch_dp = input_dp.unbatch(unbatch_level=-1) 1889 n_elements_before_reset = 3 1890 res_before_reset, res_after_reset = reset_after_n_next_calls( 1891 unbatch_dp, n_elements_before_reset 1892 ) 1893 self.assertEqual([0, 1, 2], res_before_reset) 1894 self.assertEqual([0, 1, 2, 3, 4, 5], res_after_reset) 1895 1896 def test_filter_datapipe(self): 1897 input_ds = dp.iter.IterableWrapper(range(10)) 1898 1899 def _filter_fn(data, val): 1900 return data >= val 1901 1902 # Functional Test: filter works with partial function 1903 filter_dp = input_ds.filter(partial(_filter_fn, val=5)) 1904 self.assertEqual(list(filter_dp), list(range(5, 10))) 1905 1906 def _non_bool_fn(data): 1907 return 1 1908 1909 # Functional Test: filter function must return bool 1910 filter_dp = input_ds.filter(filter_fn=_non_bool_fn) 1911 with self.assertRaises(ValueError): 1912 temp = list(filter_dp) 1913 1914 # Functional Test: Specify input_col 1915 tuple_input_ds = dp.iter.IterableWrapper([(d - 1, d, d + 1) for d in range(10)]) 1916 1917 # Single input_col 1918 input_col_1_dp = tuple_input_ds.filter(partial(_filter_fn, val=5), input_col=1) 1919 self.assertEqual( 1920 list(input_col_1_dp), [(d - 1, d, d + 1) for d in range(5, 10)] 1921 ) 1922 1923 # Multiple input_col 1924 def _mul_filter_fn(a, b): 1925 return a + b < 10 1926 1927 input_col_2_dp = tuple_input_ds.filter(_mul_filter_fn, input_col=[0, 2]) 1928 self.assertEqual(list(input_col_2_dp), [(d - 1, d, d + 1) for d in range(5)]) 1929 1930 # Invalid input_col 1931 with self.assertRaises(ValueError): 1932 tuple_input_ds.filter(_mul_filter_fn, input_col=0) 1933 1934 p_mul_filter_fn = partial(_mul_filter_fn, b=1) 1935 out = tuple_input_ds.filter(p_mul_filter_fn, input_col=0) 1936 self.assertEqual(list(out), [(d - 1, d, d + 1) for d in range(10)]) 1937 1938 def _mul_filter_fn_with_defaults(a, b=1): 1939 return a + b < 10 1940 1941 out = tuple_input_ds.filter(_mul_filter_fn_with_defaults, input_col=0) 1942 self.assertEqual(list(out), [(d - 1, d, d + 1) for d in range(10)]) 1943 1944 def _mul_filter_fn_with_kw_only(*, a, b): 1945 return a + b < 10 1946 1947 with
self.assertRaises(ValueError): 1948 tuple_input_ds.filter(_mul_filter_fn_with_kw_only, input_col=0) 1949 1950 def _mul_filter_fn_with_kw_only_1_default(*, a, b=1): 1951 return a + b < 10 1952 1953 with self.assertRaises(ValueError): 1954 tuple_input_ds.filter(_mul_filter_fn_with_kw_only_1_default, input_col=0) 1955 1956 # __len__ Test: DataPipe has no valid len 1957 with self.assertRaisesRegex(TypeError, r"has no len"): 1958 len(filter_dp) 1959 1960 # Reset Test: DataPipe resets correctly 1961 filter_dp = input_ds.filter(partial(_filter_fn, val=5)) 1962 n_elements_before_reset = 3 1963 res_before_reset, res_after_reset = reset_after_n_next_calls( 1964 filter_dp, n_elements_before_reset 1965 ) 1966 self.assertEqual(list(range(5, 10))[:n_elements_before_reset], res_before_reset) 1967 self.assertEqual(list(range(5, 10)), res_after_reset) 1968 1969 def test_sampler_iterdatapipe(self): 1970 input_dp = dp.iter.IterableWrapper(range(10)) 1971 # Default SequentialSampler 1972 sampled_dp = dp.iter.Sampler(input_dp) # type: ignore[var-annotated] 1973 self.assertEqual(len(sampled_dp), 10) 1974 for i, x in enumerate(sampled_dp): 1975 self.assertEqual(x, i) 1976 1977 # RandomSampler 1978 random_sampled_dp = dp.iter.Sampler( 1979 input_dp, sampler=RandomSampler, sampler_kwargs={"replacement": True} 1980 ) # type: ignore[var-annotated] # noqa: B950 1981 1982 # Requires `__len__` to build SamplerDataPipe 1983 input_dp_nolen = IDP_NoLen(range(10)) 1984 with self.assertRaises(AssertionError): 1985 sampled_dp = dp.iter.Sampler(input_dp_nolen) 1986 1987 def test_stream_reader_iterdatapipe(self): 1988 from io import StringIO 1989 1990 input_dp = dp.iter.IterableWrapper( 1991 [("f1", StringIO("abcde")), ("f2", StringIO("bcdef"))] 1992 ) 1993 expected_res = ["abcde", "bcdef"] 1994 1995 # Functional Test: Read full chunk 1996 dp1 = input_dp.read_from_stream() 1997 self.assertEqual([d[1] for d in dp1], expected_res) 1998 1999 # Functional Test: Read full chunk 2000 dp2 = input_dp.read_from_stream(chunk=1) 2001 self.assertEqual([d[1] for d in dp2], [c for s in expected_res for c in s]) 2002 2003 # `__len__` Test 2004 with self.assertRaises(TypeError): 2005 len(dp1) 2006 2007 def test_shuffler_iterdatapipe(self): 2008 input_dp = dp.iter.IterableWrapper(list(range(10))) 2009 2010 with self.assertRaises(AssertionError): 2011 shuffle_dp = input_dp.shuffle(buffer_size=0) 2012 2013 # Functional Test: No seed 2014 shuffler_dp = input_dp.shuffle() 2015 self.assertEqual(set(range(10)), set(shuffler_dp)) 2016 2017 # Functional Test: With global seed 2018 torch.manual_seed(123) 2019 shuffler_dp = input_dp.shuffle() 2020 res = list(shuffler_dp) 2021 torch.manual_seed(123) 2022 self.assertEqual(list(shuffler_dp), res) 2023 2024 # Functional Test: Set seed 2025 shuffler_dp = input_dp.shuffle().set_seed(123) 2026 res = list(shuffler_dp) 2027 shuffler_dp.set_seed(123) 2028 self.assertEqual(list(shuffler_dp), res) 2029 2030 # Functional Test: deactivate shuffling via set_shuffle 2031 unshuffled_dp = input_dp.shuffle().set_shuffle(False) 2032 self.assertEqual(list(unshuffled_dp), list(input_dp)) 2033 2034 # Reset Test: 2035 shuffler_dp = input_dp.shuffle() 2036 n_elements_before_reset = 5 2037 res_before_reset, res_after_reset = reset_after_n_next_calls( 2038 shuffler_dp, n_elements_before_reset 2039 ) 2040 self.assertEqual(5, len(res_before_reset)) 2041 for x in res_before_reset: 2042 self.assertTrue(x in set(range(10))) 2043 self.assertEqual(set(range(10)), set(res_after_reset)) 2044 2045 # __len__ Test: returns the length of the 
input DataPipe 2046 shuffler_dp = input_dp.shuffle() 2047 self.assertEqual(10, len(shuffler_dp)) 2048 exp = list(range(100)) 2049 2050 # Serialization Test 2051 from torch.utils.data.datapipes._hook_iterator import _SnapshotState 2052 2053 def _serialization_helper(bs): 2054 shuffler_dp = input_dp.shuffle(buffer_size=bs) 2055 it = iter(shuffler_dp) 2056 for _ in range(2): 2057 next(it) 2058 shuffler_dp_copy = pickle.loads(pickle.dumps(shuffler_dp)) 2059 _simple_graph_snapshot_restoration( 2060 shuffler_dp_copy.datapipe, 2061 shuffler_dp.datapipe._number_of_samples_yielded, 2062 ) 2063 2064 exp = list(it) 2065 shuffler_dp_copy._snapshot_state = _SnapshotState.Restored 2066 self.assertEqual(exp, list(shuffler_dp_copy)) 2067 2068 buffer_sizes = [2, 5, 15] 2069 for bs in buffer_sizes: 2070 _serialization_helper(bs) 2071 2072 def test_zip_iterdatapipe(self): 2073 # Functional Test: raises TypeError when an input is not of type `IterDataPipe` 2074 with self.assertRaises(TypeError): 2075 dp.iter.Zipper(dp.iter.IterableWrapper(range(10)), list(range(10))) # type: ignore[arg-type] 2076 2077 # Functional Test: raises TypeError when an input does not have valid length 2078 zipped_dp = dp.iter.Zipper( 2079 dp.iter.IterableWrapper(range(10)), IDP_NoLen(range(5)) 2080 ) # type: ignore[var-annotated] 2081 with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): 2082 len(zipped_dp) 2083 2084 # Functional Test: zips the results properly 2085 exp = [(i, i) for i in range(5)] 2086 self.assertEqual(list(zipped_dp), exp) 2087 2088 # Functional Test: zips the inputs properly even when lengths are different (zips to the shortest) 2089 zipped_dp = dp.iter.Zipper( 2090 dp.iter.IterableWrapper(range(10)), dp.iter.IterableWrapper(range(5)) 2091 ) 2092 2093 # __len__ Test: length matches the length of the shortest input 2094 self.assertEqual(len(zipped_dp), 5) 2095 2096 # Reset Test: 2097 n_elements_before_reset = 3 2098 res_before_reset, res_after_reset = reset_after_n_next_calls( 2099 zipped_dp, n_elements_before_reset 2100 ) 2101 expected_res = [(i, i) for i in range(5)] 2102 self.assertEqual(expected_res[:n_elements_before_reset], res_before_reset) 2103 self.assertEqual(expected_res, res_after_reset) 2104 2105 2106class TestFunctionalMapDataPipe(TestCase): 2107 def _serialization_test_helper(self, datapipe, use_dill): 2108 if use_dill: 2109 serialized_dp = dill.dumps(datapipe) 2110 deserialized_dp = dill.loads(serialized_dp) 2111 else: 2112 serialized_dp = pickle.dumps(datapipe) 2113 deserialized_dp = pickle.loads(serialized_dp) 2114 try: 2115 self.assertEqual(list(datapipe), list(deserialized_dp)) 2116 except AssertionError as e: 2117 print(f"{datapipe} is failing.") 2118 raise e 2119 2120 def _serialization_test_for_single_dp(self, dp, use_dill=False): 2121 # 1. Testing for serialization before any iteration starts 2122 self._serialization_test_helper(dp, use_dill) 2123 # 2. Testing for serialization after DataPipe is partially read 2124 it = iter(dp) 2125 _ = next(it) 2126 self._serialization_test_helper(dp, use_dill) 2127 # 3. 
Testing for serialization after DataPipe is fully read 2128 _ = list(dp) 2129 self._serialization_test_helper(dp, use_dill) 2130 2131 def test_serializable(self): 2132 picklable_datapipes: List = [ 2133 (dp.map.Batcher, None, (2,), {}), 2134 (dp.map.Concater, None, (dp.map.SequenceWrapper(range(10)),), {}), 2135 (dp.map.Mapper, None, (), {}), 2136 (dp.map.Mapper, None, (_fake_fn,), {}), 2137 (dp.map.Mapper, None, (partial(_fake_add, 1),), {}), 2138 (dp.map.SequenceWrapper, range(10), (), {}), 2139 (dp.map.Shuffler, dp.map.SequenceWrapper([0] * 5), (), {}), 2140 (dp.map.Zipper, None, (dp.map.SequenceWrapper(range(10)),), {}), 2141 ] 2142 for dpipe, custom_input, dp_args, dp_kwargs in picklable_datapipes: 2143 if custom_input is None: 2144 custom_input = dp.map.SequenceWrapper(range(10)) 2145 datapipe = dpipe(custom_input, *dp_args, **dp_kwargs) # type: ignore[call-arg] 2146 self._serialization_test_for_single_dp(datapipe) 2147 2148 def test_serializable_with_dill(self): 2149 """Only for DataPipes that take in a function as argument""" 2150 input_dp = dp.map.SequenceWrapper(range(10)) 2151 2152 datapipes_with_lambda_fn: List[ 2153 Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]] 2154 ] = [ 2155 (dp.map.Mapper, (lambda_fn1,), {}), 2156 ] 2157 2158 def _local_fns(): 2159 def _fn1(x): 2160 return x 2161 2162 return _fn1 2163 2164 fn1 = _local_fns() 2165 2166 datapipes_with_local_fn: List[ 2167 Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]] 2168 ] = [ 2169 (dp.map.Mapper, (fn1,), {}), 2170 ] 2171 2172 if HAS_DILL: 2173 for dpipe, dp_args, dp_kwargs in ( 2174 datapipes_with_lambda_fn + datapipes_with_local_fn 2175 ): 2176 _ = dill.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg] 2177 else: 2178 msgs = ( 2179 r"^Lambda function is not supported by pickle", 2180 r"^Local function is not supported by pickle", 2181 ) 2182 for dps, msg in zip( 2183 (datapipes_with_lambda_fn, datapipes_with_local_fn), msgs 2184 ): 2185 for dpipe, dp_args, dp_kwargs in dps: 2186 with self.assertWarnsRegex(UserWarning, msg): 2187 datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] 2188 with self.assertRaises((pickle.PicklingError, AttributeError)): 2189 pickle.dumps(datapipe) 2190 2191 def test_docstring(self): 2192 """ 2193 Ensure functional form of MapDataPipe has the correct docstring from 2194 the class form. 2195 2196 Regression test for https://github.com/pytorch/data/issues/792. 
2197 """ 2198 input_dp = dp.map.SequenceWrapper(range(10)) 2199 2200 for dp_funcname in [ 2201 "batch", 2202 "concat", 2203 "map", 2204 "shuffle", 2205 "zip", 2206 ]: 2207 if sys.version_info >= (3, 9): 2208 docstring = pydoc.render_doc( 2209 thing=getattr(input_dp, dp_funcname), forceload=True 2210 ) 2211 elif sys.version_info < (3, 9): 2212 # pydoc works differently on Python 3.8, see 2213 # https://docs.python.org/3/whatsnew/3.9.html#pydoc 2214 docstring = getattr(input_dp, dp_funcname).__doc__ 2215 assert f"(functional name: ``{dp_funcname}``)" in docstring 2216 assert "Args:" in docstring 2217 assert "Example:" in docstring or "Examples:" in docstring 2218 2219 def test_sequence_wrapper_datapipe(self): 2220 seq = list(range(10)) 2221 input_dp = dp.map.SequenceWrapper(seq) 2222 2223 # Functional Test: all elements are equal in the same order 2224 self.assertEqual(seq, list(input_dp)) 2225 2226 # Functional Test: confirm deepcopy works by default 2227 seq.append(11) 2228 self.assertEqual(list(range(10)), list(input_dp)) # input_dp shouldn't have 11 2229 2230 # Functional Test: non-deepcopy version is working 2231 seq2 = [1, 2, 3] 2232 input_dp_non_deep = dp.map.SequenceWrapper(seq2, deepcopy=False) 2233 seq2.append(4) 2234 self.assertEqual(list(seq2), list(input_dp_non_deep)) # should have 4 2235 2236 # Reset Test: reset the DataPipe 2237 seq = list(range(10)) 2238 n_elements_before_reset = 5 2239 res_before_reset, res_after_reset = reset_after_n_next_calls( 2240 input_dp, n_elements_before_reset 2241 ) 2242 self.assertEqual(list(range(5)), res_before_reset) 2243 self.assertEqual(seq, res_after_reset) 2244 2245 # __len__ Test: inherits length from sequence 2246 self.assertEqual(len(seq), len(input_dp)) 2247 2248 def test_concat_mapdatapipe(self): 2249 input_dp1 = dp.map.SequenceWrapper(range(10)) 2250 input_dp2 = dp.map.SequenceWrapper(range(5)) 2251 2252 with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"): 2253 dp.map.Concater() 2254 2255 with self.assertRaisesRegex( 2256 TypeError, r"Expected all inputs to be `MapDataPipe`" 2257 ): 2258 dp.map.Concater(input_dp1, ()) # type: ignore[arg-type] 2259 2260 concat_dp = input_dp1.concat(input_dp2) 2261 self.assertEqual(len(concat_dp), 15) 2262 for index in range(15): 2263 self.assertEqual( 2264 concat_dp[index], (list(range(10)) + list(range(5)))[index] 2265 ) 2266 self.assertEqual(list(concat_dp), list(range(10)) + list(range(5))) 2267 2268 def test_zip_mapdatapipe(self): 2269 input_dp1 = dp.map.SequenceWrapper(range(10)) 2270 input_dp2 = dp.map.SequenceWrapper(range(5)) 2271 input_dp3 = dp.map.SequenceWrapper(range(15)) 2272 2273 # Functional Test: requires at least one input DataPipe 2274 with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"): 2275 dp.map.Zipper() 2276 2277 # Functional Test: all inputs must be MapDataPipes 2278 with self.assertRaisesRegex( 2279 TypeError, r"Expected all inputs to be `MapDataPipe`" 2280 ): 2281 dp.map.Zipper(input_dp1, ()) # type: ignore[arg-type] 2282 2283 # Functional Test: Zip the elements up as a tuples 2284 zip_dp = input_dp1.zip(input_dp2, input_dp3) 2285 self.assertEqual([(i, i, i) for i in range(5)], [zip_dp[i] for i in range(5)]) 2286 2287 # Functional Test: Raise IndexError when index equal or exceed the length of the shortest DataPipe 2288 with self.assertRaisesRegex(IndexError, r"out of range"): 2289 input_dp1.zip(input_dp2, input_dp3)[5] 2290 2291 # Functional Test: Ensure `zip` can combine `Batcher` with others 2292 dp1 = 
dp.map.SequenceWrapper(range(10)) 2293 shuffle_dp1 = dp1.batch(2) 2294 dp2 = dp.map.SequenceWrapper(range(10)) 2295 shuffle_dp2 = dp2.batch(3) 2296 zip_dp1 = shuffle_dp1.zip(shuffle_dp2) 2297 self.assertEqual(4, len(list(zip_dp1))) 2298 zip_dp2 = shuffle_dp1.zip(dp2) 2299 self.assertEqual(5, len(list(zip_dp2))) 2300 2301 # __len__ Test: returns the length of the shortest DataPipe 2302 zip_dp = input_dp1.zip(input_dp2, input_dp3) 2303 self.assertEqual(5, len(zip_dp)) 2304 2305 def test_shuffler_mapdatapipe(self): 2306 input_dp1 = dp.map.SequenceWrapper(range(10)) 2307 input_dp2 = dp.map.SequenceWrapper({"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}) 2308 2309 # Functional Test: Assumes 0-index when indices is not given 2310 shuffler_dp = input_dp1.shuffle() 2311 self.assertEqual(set(range(10)), set(shuffler_dp)) 2312 2313 # Functional Test: Custom indices are working 2314 shuffler_dp = input_dp2.shuffle(indices=["a", "b", "c", "d", "e"]) 2315 self.assertEqual(set(range(1, 6)), set(shuffler_dp)) 2316 2317 # Functional Test: With global seed 2318 torch.manual_seed(123) 2319 shuffler_dp = input_dp1.shuffle() 2320 res = list(shuffler_dp) 2321 torch.manual_seed(123) 2322 self.assertEqual(list(shuffler_dp), res) 2323 2324 # Functional Test: Set seed 2325 shuffler_dp = input_dp1.shuffle().set_seed(123) 2326 res = list(shuffler_dp) 2327 shuffler_dp.set_seed(123) 2328 self.assertEqual(list(shuffler_dp), res) 2329 2330 # Functional Test: deactivate shuffling via set_shuffle 2331 unshuffled_dp = input_dp1.shuffle().set_shuffle(False) 2332 self.assertEqual(list(unshuffled_dp), list(input_dp1)) 2333 2334 # Reset Test: 2335 shuffler_dp = input_dp1.shuffle() 2336 n_elements_before_reset = 5 2337 res_before_reset, res_after_reset = reset_after_n_next_calls( 2338 shuffler_dp, n_elements_before_reset 2339 ) 2340 self.assertEqual(5, len(res_before_reset)) 2341 for x in res_before_reset: 2342 self.assertTrue(x in set(range(10))) 2343 self.assertEqual(set(range(10)), set(res_after_reset)) 2344 2345 # __len__ Test: returns the length of the input DataPipe 2346 shuffler_dp = input_dp1.shuffle() 2347 self.assertEqual(10, len(shuffler_dp)) 2348 2349 # Serialization Test 2350 from torch.utils.data.datapipes._hook_iterator import _SnapshotState 2351 2352 shuffler_dp = input_dp1.shuffle() 2353 it = iter(shuffler_dp) 2354 for _ in range(2): 2355 next(it) 2356 shuffler_dp_copy = pickle.loads(pickle.dumps(shuffler_dp)) 2357 2358 exp = list(it) 2359 shuffler_dp_copy._snapshot_state = _SnapshotState.Restored 2360 self.assertEqual(exp, list(shuffler_dp_copy)) 2361 2362 def test_map_mapdatapipe(self): 2363 arr = range(10) 2364 input_dp = dp.map.SequenceWrapper(arr) 2365 2366 def fn(item, dtype=torch.float, *, sum=False): 2367 data = torch.tensor(item, dtype=dtype) 2368 return data if not sum else data.sum() 2369 2370 map_dp = input_dp.map(fn) 2371 self.assertEqual(len(input_dp), len(map_dp)) 2372 for index in arr: 2373 self.assertEqual( 2374 map_dp[index], torch.tensor(input_dp[index], dtype=torch.float) 2375 ) 2376 2377 map_dp = input_dp.map(partial(fn, dtype=torch.int, sum=True)) 2378 self.assertEqual(len(input_dp), len(map_dp)) 2379 for index in arr: 2380 self.assertEqual( 2381 map_dp[index], torch.tensor(input_dp[index], dtype=torch.int).sum() 2382 ) 2383 2384 def test_batch_mapdatapipe(self): 2385 arr = list(range(13)) 2386 input_dp = dp.map.SequenceWrapper(arr) 2387 2388 # Functional Test: batches top level by default 2389 batch_dp = dp.map.Batcher(input_dp, batch_size=2) 2390 self.assertEqual( 2391 [[0, 1], [2, 3], [4, 5], 
[6, 7], [8, 9], [10, 11], [12]], list(batch_dp) 2392 ) 2393 2394 # Functional Test: drop_last on command 2395 batch_dp = dp.map.Batcher(input_dp, batch_size=2, drop_last=True) 2396 self.assertEqual( 2397 [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]], list(batch_dp) 2398 ) 2399 2400 # Functional Test: nested batching 2401 batch_dp_2 = batch_dp.batch(batch_size=3) 2402 self.assertEqual( 2403 [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]], list(batch_dp_2) 2404 ) 2405 2406 # Reset Test: 2407 n_elements_before_reset = 3 2408 res_before_reset, res_after_reset = reset_after_n_next_calls( 2409 batch_dp, n_elements_before_reset 2410 ) 2411 self.assertEqual([[0, 1], [2, 3], [4, 5]], res_before_reset) 2412 self.assertEqual( 2413 [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]], res_after_reset 2414 ) 2415 2416 # __len__ Test: 2417 self.assertEqual(6, len(batch_dp)) 2418 self.assertEqual(2, len(batch_dp_2)) 2419 2420 2421# Metaclass conflict for Python 3.6 2422# Multiple inheritance with NamedTuple is not supported for Python 3.9 2423_generic_namedtuple_allowed = sys.version_info >= (3, 7) and sys.version_info < (3, 9) 2424if _generic_namedtuple_allowed: 2425 2426 class InvalidData(NamedTuple, Generic[T_co]): 2427 name: str 2428 data: T_co 2429 2430 2431class TestTyping(TestCase): 2432 def test_isinstance(self): 2433 class A(IterDataPipe): 2434 pass 2435 2436 class B(IterDataPipe): 2437 pass 2438 2439 a = A() 2440 self.assertTrue(isinstance(a, A)) 2441 self.assertFalse(isinstance(a, B)) 2442 2443 def test_protocol(self): 2444 try: 2445 from typing import Protocol # type: ignore[attr-defined] 2446 except ImportError: 2447 from typing import _Protocol # type: ignore[attr-defined] 2448 2449 Protocol = _Protocol 2450 2451 class P(Protocol): 2452 pass 2453 2454 class A(IterDataPipe[P]): 2455 pass 2456 2457 @skipTyping 2458 def test_subtype(self): 2459 from torch.utils.data.datapipes._typing import issubtype 2460 2461 basic_type = (int, str, bool, float, complex, list, tuple, dict, set, T_co) 2462 for t in basic_type: 2463 self.assertTrue(issubtype(t, t)) 2464 self.assertTrue(issubtype(t, Any)) 2465 if t == T_co: 2466 self.assertTrue(issubtype(Any, t)) 2467 else: 2468 self.assertFalse(issubtype(Any, t)) 2469 for t1, t2 in itertools.product(basic_type, basic_type): 2470 if t1 == t2 or t2 == T_co: 2471 self.assertTrue(issubtype(t1, t2)) 2472 else: 2473 self.assertFalse(issubtype(t1, t2)) 2474 2475 T = TypeVar("T", int, str) 2476 S = TypeVar("S", bool, Union[str, int], Tuple[int, T]) # type: ignore[valid-type] 2477 types = ( 2478 (int, Optional[int]), 2479 (List, Union[int, list]), 2480 (Tuple[int, str], S), 2481 (Tuple[int, str], tuple), 2482 (T, S), 2483 (S, T_co), 2484 (T, Union[S, Set]), 2485 ) 2486 for sub, par in types: 2487 self.assertTrue(issubtype(sub, par)) 2488 self.assertFalse(issubtype(par, sub)) 2489 2490 subscriptable_types = { 2491 List: 1, 2492 Tuple: 2, # use 2 parameters 2493 Set: 1, 2494 Dict: 2, 2495 } 2496 for subscript_type, n in subscriptable_types.items(): 2497 for ts in itertools.combinations(types, n): 2498 subs, pars = zip(*ts) 2499 sub = subscript_type[subs] # type: ignore[index] 2500 par = subscript_type[pars] # type: ignore[index] 2501 self.assertTrue(issubtype(sub, par)) 2502 self.assertFalse(issubtype(par, sub)) 2503 # Non-recursive check 2504 self.assertTrue(issubtype(par, sub, recursive=False)) 2505 2506 @skipTyping 2507 def test_issubinstance(self): 2508 from torch.utils.data.datapipes._typing import issubinstance 2509 2510 basic_data = (1, "1", True, 1.0, 
complex(1.0, 0.0)) 2511 basic_type = (int, str, bool, float, complex) 2512 S = TypeVar("S", bool, Union[str, int]) 2513 for d in basic_data: 2514 self.assertTrue(issubinstance(d, Any)) 2515 self.assertTrue(issubinstance(d, T_co)) 2516 if type(d) in (bool, int, str): 2517 self.assertTrue(issubinstance(d, S)) 2518 else: 2519 self.assertFalse(issubinstance(d, S)) 2520 for t in basic_type: 2521 if type(d) == t: 2522 self.assertTrue(issubinstance(d, t)) 2523 else: 2524 self.assertFalse(issubinstance(d, t)) 2525 # list/set 2526 dt = (([1, "1", 2], List), (set({1, "1", 2}), Set)) 2527 for d, t in dt: 2528 self.assertTrue(issubinstance(d, t)) 2529 self.assertTrue(issubinstance(d, t[T_co])) # type: ignore[index] 2530 self.assertFalse(issubinstance(d, t[int])) # type: ignore[index] 2531 2532 # dict 2533 d = {"1": 1, "2": 2.0} 2534 self.assertTrue(issubinstance(d, Dict)) 2535 self.assertTrue(issubinstance(d, Dict[str, T_co])) 2536 self.assertFalse(issubinstance(d, Dict[str, int])) 2537 2538 # tuple 2539 d = (1, "1", 2) 2540 self.assertTrue(issubinstance(d, Tuple)) 2541 self.assertTrue(issubinstance(d, Tuple[int, str, T_co])) 2542 self.assertFalse(issubinstance(d, Tuple[int, Any])) 2543 self.assertFalse(issubinstance(d, Tuple[int, int, int])) 2544 2545 # Static checking annotation 2546 @skipTyping 2547 def test_compile_time(self): 2548 with self.assertRaisesRegex(TypeError, r"Expected 'Iterator' as the return"): 2549 2550 class InvalidDP1(IterDataPipe[int]): 2551 def __iter__(self) -> str: # type: ignore[misc, override] 2552 yield 0 2553 2554 with self.assertRaisesRegex(TypeError, r"Expected return type of '__iter__'"): 2555 2556 class InvalidDP2(IterDataPipe[Tuple]): 2557 def __iter__(self) -> Iterator[int]: # type: ignore[override] 2558 yield 0 2559 2560 with self.assertRaisesRegex(TypeError, r"Expected return type of '__iter__'"): 2561 2562 class InvalidDP3(IterDataPipe[Tuple[int, str]]): 2563 def __iter__(self) -> Iterator[tuple]: # type: ignore[override] 2564 yield (0,) 2565 2566 if _generic_namedtuple_allowed: 2567 with self.assertRaisesRegex( 2568 TypeError, r"is not supported by Python typing" 2569 ): 2570 2571 class InvalidDP4(IterDataPipe["InvalidData[int]"]): # type: ignore[type-arg, misc] 2572 pass 2573 2574 class DP1(IterDataPipe[Tuple[int, str]]): 2575 def __init__(self, length): 2576 self.length = length 2577 2578 def __iter__(self) -> Iterator[Tuple[int, str]]: 2579 for d in range(self.length): 2580 yield d, str(d) 2581 2582 self.assertTrue(issubclass(DP1, IterDataPipe)) 2583 dp1 = DP1(10) 2584 self.assertTrue(DP1.type.issubtype(dp1.type) and dp1.type.issubtype(DP1.type)) # type: ignore[attr-defined] 2585 dp1_ = DP1(5) 2586 self.assertEqual(dp1.type, dp1_.type) 2587 2588 with self.assertRaisesRegex(TypeError, r"is not a generic class"): 2589 2590 class InvalidDP5(DP1[tuple]): # type: ignore[type-arg] 2591 def __iter__(self) -> Iterator[tuple]: # type: ignore[override] 2592 yield (0,) 2593 2594 class DP2(IterDataPipe[T_co]): 2595 def __iter__(self) -> Iterator[T_co]: 2596 yield from range(10) # type: ignore[misc] 2597 2598 self.assertTrue(issubclass(DP2, IterDataPipe)) 2599 dp2 = DP2() # type: ignore[var-annotated] 2600 self.assertTrue(DP2.type.issubtype(dp2.type) and dp2.type.issubtype(DP2.type)) # type: ignore[attr-defined] 2601 dp2_ = DP2() # type: ignore[var-annotated] 2602 self.assertEqual(dp2.type, dp2_.type) 2603 2604 class DP3(IterDataPipe[Tuple[T_co, str]]): 2605 r"""DataPipe without fixed type with __init__ function""" 2606 2607 def __init__(self, datasource): 2608 
self.datasource = datasource 2609 2610 def __iter__(self) -> Iterator[Tuple[T_co, str]]: 2611 for d in self.datasource: 2612 yield d, str(d) 2613 2614 self.assertTrue(issubclass(DP3, IterDataPipe)) 2615 dp3 = DP3(range(10)) # type: ignore[var-annotated] 2616 self.assertTrue(DP3.type.issubtype(dp3.type) and dp3.type.issubtype(DP3.type)) # type: ignore[attr-defined] 2617 dp3_ = DP3(5) # type: ignore[var-annotated] 2618 self.assertEqual(dp3.type, dp3_.type) 2619 2620 class DP4(IterDataPipe[tuple]): 2621 r"""DataPipe without __iter__ annotation""" 2622 2623 def __iter__(self): 2624 raise NotImplementedError 2625 2626 self.assertTrue(issubclass(DP4, IterDataPipe)) 2627 dp4 = DP4() 2628 self.assertTrue(dp4.type.param == tuple) 2629 2630 class DP5(IterDataPipe): 2631 r"""DataPipe without type annotation""" 2632 2633 def __iter__(self) -> Iterator[str]: 2634 raise NotImplementedError 2635 2636 self.assertTrue(issubclass(DP5, IterDataPipe)) 2637 dp5 = DP5() 2638 from torch.utils.data.datapipes._typing import issubtype 2639 2640 self.assertTrue( 2641 issubtype(dp5.type.param, Any) and issubtype(Any, dp5.type.param) 2642 ) 2643 2644 class DP6(IterDataPipe[int]): 2645 r"""DataPipe with plain Iterator""" 2646 2647 def __iter__(self) -> Iterator: 2648 raise NotImplementedError 2649 2650 self.assertTrue(issubclass(DP6, IterDataPipe)) 2651 dp6 = DP6() 2652 self.assertTrue(dp6.type.param == int) 2653 2654 class DP7(IterDataPipe[Awaitable[T_co]]): 2655 r"""DataPipe with abstract base class""" 2656 2657 self.assertTrue(issubclass(DP7, IterDataPipe)) 2658 self.assertTrue(DP7.type.param == Awaitable[T_co]) # type: ignore[attr-defined] 2659 2660 class DP8(DP7[str]): 2661 r"""DataPipe subclass from a DataPipe with abc type""" 2662 2663 self.assertTrue(issubclass(DP8, IterDataPipe)) 2664 self.assertTrue(DP8.type.param == Awaitable[str]) # type: ignore[attr-defined] 2665 2666 @skipTyping 2667 def test_construct_time(self): 2668 class DP0(IterDataPipe[Tuple]): 2669 @argument_validation 2670 def __init__(self, dp: IterDataPipe): 2671 self.dp = dp 2672 2673 def __iter__(self) -> Iterator[Tuple]: 2674 for d in self.dp: 2675 yield d, str(d) 2676 2677 class DP1(IterDataPipe[int]): 2678 @argument_validation 2679 def __init__(self, dp: IterDataPipe[Tuple[int, str]]): 2680 self.dp = dp 2681 2682 def __iter__(self) -> Iterator[int]: 2683 for a, b in self.dp: 2684 yield a 2685 2686 # Non-DataPipe input with DataPipe hint 2687 datasource = [(1, "1"), (2, "2"), (3, "3")] 2688 with self.assertRaisesRegex( 2689 TypeError, r"Expected argument 'dp' as a IterDataPipe" 2690 ): 2691 dp0 = DP0(datasource) 2692 2693 dp0 = DP0(dp.iter.IterableWrapper(range(10))) 2694 with self.assertRaisesRegex( 2695 TypeError, r"Expected type of argument 'dp' as a subtype" 2696 ): 2697 dp1 = DP1(dp0) 2698 2699 @skipTyping 2700 def test_runtime(self): 2701 class DP(IterDataPipe[Tuple[int, T_co]]): 2702 def __init__(self, datasource): 2703 self.ds = datasource 2704 2705 @runtime_validation 2706 def __iter__(self) -> Iterator[Tuple[int, T_co]]: 2707 yield from self.ds 2708 2709 dss = ([(1, "1"), (2, "2")], [(1, 1), (2, "2")]) 2710 for ds in dss: 2711 dp0 = DP(ds) # type: ignore[var-annotated] 2712 self.assertEqual(list(dp0), ds) 2713 # Reset __iter__ 2714 self.assertEqual(list(dp0), ds) 2715 2716 dss = ( 2717 [(1, 1), ("2", 2)], # type: ignore[assignment, list-item] 2718 [[1, "1"], [2, "2"]], # type: ignore[list-item] 2719 [1, "1", 2, "2"], 2720 ) 2721 for ds in dss: 2722 dp0 = DP(ds) 2723 with self.assertRaisesRegex( 2724 RuntimeError, r"Expected an 
instance as subtype" 2725 ): 2726 list(dp0) 2727 2728 with runtime_validation_disabled(): 2729 self.assertEqual(list(dp0), ds) 2730 with runtime_validation_disabled(): 2731 self.assertEqual(list(dp0), ds) 2732 2733 with self.assertRaisesRegex( 2734 RuntimeError, r"Expected an instance as subtype" 2735 ): 2736 list(dp0) 2737 2738 @skipTyping 2739 def test_reinforce(self): 2740 T = TypeVar("T", int, str) 2741 2742 class DP(IterDataPipe[T]): 2743 def __init__(self, ds): 2744 self.ds = ds 2745 2746 @runtime_validation 2747 def __iter__(self) -> Iterator[T]: 2748 yield from self.ds 2749 2750 ds = list(range(10)) 2751 # Valid type reinforcement 2752 dp0 = DP(ds).reinforce_type(int) 2753 self.assertTrue(dp0.type, int) 2754 self.assertEqual(list(dp0), ds) 2755 2756 # Invalid type 2757 with self.assertRaisesRegex(TypeError, r"'expected_type' must be a type"): 2758 dp1 = DP(ds).reinforce_type(1) 2759 2760 # Type is not subtype 2761 with self.assertRaisesRegex( 2762 TypeError, r"Expected 'expected_type' as subtype of" 2763 ): 2764 dp2 = DP(ds).reinforce_type(float) 2765 2766 # Invalid data at runtime 2767 dp3 = DP(ds).reinforce_type(str) 2768 with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"): 2769 list(dp3) 2770 2771 # Context Manager to disable the runtime validation 2772 with runtime_validation_disabled(): 2773 self.assertEqual(list(dp3), ds) 2774 2775 2776class NumbersDataset(IterDataPipe): 2777 def __init__(self, size=10): 2778 self.size = size 2779 2780 def __iter__(self): 2781 yield from range(self.size) 2782 2783 def __len__(self): 2784 return self.size 2785 2786 2787class TestGraph(TestCase): 2788 class CustomIterDataPipe(IterDataPipe): 2789 def add_v(self, x): 2790 return x + self.v 2791 2792 def __init__(self, source_dp, v=1): 2793 self._dp = source_dp.map(self.add_v) 2794 self.v = 1 2795 2796 def __iter__(self): 2797 yield from self._dp 2798 2799 def __hash__(self): 2800 raise NotImplementedError 2801 2802 def test_simple_traverse(self): 2803 numbers_dp = NumbersDataset(size=50) 2804 shuffled_dp = numbers_dp.shuffle() 2805 sharded_dp = shuffled_dp.sharding_filter() 2806 mapped_dp = sharded_dp.map(lambda x: x * 10) 2807 graph = traverse_dps(mapped_dp) 2808 expected: Dict[Any, Any] = { 2809 id(mapped_dp): ( 2810 mapped_dp, 2811 { 2812 id(sharded_dp): ( 2813 sharded_dp, 2814 { 2815 id(shuffled_dp): ( 2816 shuffled_dp, 2817 {id(numbers_dp): (numbers_dp, {})}, 2818 ) 2819 }, 2820 ) 2821 }, 2822 ) 2823 } 2824 self.assertEqual(expected, graph) 2825 2826 dps = torch.utils.data.graph_settings.get_all_graph_pipes(graph) 2827 self.assertEqual(len(dps), 4) 2828 for datapipe in (numbers_dp, shuffled_dp, sharded_dp, mapped_dp): 2829 self.assertTrue(datapipe in dps) 2830 2831 def test_traverse_forked(self): 2832 numbers_dp = NumbersDataset(size=50) 2833 dp0, dp1, dp2 = numbers_dp.fork(num_instances=3) 2834 dp0_upd = dp0.map(lambda x: x * 10) 2835 dp1_upd = dp1.filter(lambda x: x % 3 == 1) 2836 combined_dp = dp0_upd.mux(dp1_upd, dp2) 2837 graph = traverse_dps(combined_dp) 2838 expected = { 2839 id(combined_dp): ( 2840 combined_dp, 2841 { 2842 id(dp0_upd): ( 2843 dp0_upd, 2844 { 2845 id(dp0): ( 2846 dp0, 2847 { 2848 id(dp0.main_datapipe): ( 2849 dp0.main_datapipe, 2850 { 2851 id(dp0.main_datapipe.main_datapipe): ( 2852 dp0.main_datapipe.main_datapipe, 2853 {}, 2854 ) 2855 }, 2856 ) 2857 }, 2858 ) 2859 }, 2860 ), 2861 id(dp1_upd): ( 2862 dp1_upd, 2863 { 2864 id(dp1): ( 2865 dp1, 2866 { 2867 id(dp1.main_datapipe): ( 2868 dp1.main_datapipe, 2869 { 2870 
id(dp1.main_datapipe.main_datapipe): ( 2871 dp1.main_datapipe.main_datapipe, 2872 {}, 2873 ) 2874 }, 2875 ) 2876 }, 2877 ) 2878 }, 2879 ), 2880 id(dp2): ( 2881 dp2, 2882 { 2883 id(dp2.main_datapipe): ( 2884 dp2.main_datapipe, 2885 { 2886 id(dp2.main_datapipe.main_datapipe): ( 2887 dp2.main_datapipe.main_datapipe, 2888 {}, 2889 ) 2890 }, 2891 ) 2892 }, 2893 ), 2894 }, 2895 ) 2896 } 2897 self.assertEqual(expected, graph) 2898 2899 dps = torch.utils.data.graph_settings.get_all_graph_pipes(graph) 2900 self.assertEqual(len(dps), 8) 2901 for _dp in [ 2902 numbers_dp, 2903 dp0.main_datapipe, 2904 dp0, 2905 dp1, 2906 dp2, 2907 dp0_upd, 2908 dp1_upd, 2909 combined_dp, 2910 ]: 2911 self.assertTrue(_dp in dps) 2912 2913 def test_traverse_mapdatapipe(self): 2914 source_dp = dp.map.SequenceWrapper(range(10)) 2915 map_dp = source_dp.map(partial(_fake_add, 1)) 2916 graph = traverse_dps(map_dp) 2917 expected: Dict[Any, Any] = { 2918 id(map_dp): (map_dp, {id(source_dp): (source_dp, {})}) 2919 } 2920 self.assertEqual(expected, graph) 2921 2922 def test_traverse_mixdatapipe(self): 2923 source_map_dp = dp.map.SequenceWrapper(range(10)) 2924 iter_dp = dp.iter.IterableWrapper(source_map_dp) 2925 graph = traverse_dps(iter_dp) 2926 expected: Dict[Any, Any] = { 2927 id(iter_dp): (iter_dp, {id(source_map_dp): (source_map_dp, {})}) 2928 } 2929 self.assertEqual(expected, graph) 2930 2931 def test_traverse_circular_datapipe(self): 2932 source_iter_dp = dp.iter.IterableWrapper(list(range(10))) 2933 circular_dp = TestGraph.CustomIterDataPipe(source_iter_dp) 2934 graph = traverse_dps(circular_dp) 2935 # See issue: https://github.com/pytorch/data/issues/535 2936 expected: Dict[Any, Any] = { 2937 id(circular_dp): ( 2938 circular_dp, 2939 { 2940 id(circular_dp._dp): ( 2941 circular_dp._dp, 2942 {id(source_iter_dp): (source_iter_dp, {})}, 2943 ) 2944 }, 2945 ) 2946 } 2947 self.assertEqual(expected, graph) 2948 2949 dps = torch.utils.data.graph_settings.get_all_graph_pipes(graph) 2950 self.assertEqual(len(dps), 3) 2951 for _dp in [circular_dp, circular_dp._dp, source_iter_dp]: 2952 self.assertTrue(_dp in dps) 2953 2954 def test_traverse_unhashable_datapipe(self): 2955 source_iter_dp = dp.iter.IterableWrapper(list(range(10))) 2956 unhashable_dp = TestGraph.CustomIterDataPipe(source_iter_dp) 2957 graph = traverse_dps(unhashable_dp) 2958 with self.assertRaises(NotImplementedError): 2959 hash(unhashable_dp) 2960 expected: Dict[Any, Any] = { 2961 id(unhashable_dp): ( 2962 unhashable_dp, 2963 { 2964 id(unhashable_dp._dp): ( 2965 unhashable_dp._dp, 2966 {id(source_iter_dp): (source_iter_dp, {})}, 2967 ) 2968 }, 2969 ) 2970 } 2971 self.assertEqual(expected, graph) 2972 2973 2974def unbatch(x): 2975 return x[0] 2976 2977 2978class TestSerialization(TestCase): 2979 @skipIfNoDill 2980 def test_spawn_lambdas_iter(self): 2981 idp = dp.iter.IterableWrapper(range(3)).map(lambda x: x + 1).shuffle() 2982 dl = DataLoader( 2983 idp, 2984 num_workers=2, 2985 shuffle=True, 2986 multiprocessing_context="spawn", 2987 collate_fn=unbatch, 2988 batch_size=1, 2989 ) 2990 result = list(dl) 2991 self.assertEqual([1, 1, 2, 2, 3, 3], sorted(result)) 2992 2993 @skipIfNoDill 2994 def test_spawn_lambdas_map(self): 2995 mdp = dp.map.SequenceWrapper(range(3)).map(lambda x: x + 1).shuffle() 2996 dl = DataLoader( 2997 mdp, 2998 num_workers=2, 2999 shuffle=True, 3000 multiprocessing_context="spawn", 3001 collate_fn=unbatch, 3002 batch_size=1, 3003 ) 3004 result = list(dl) 3005 self.assertEqual([1, 1, 2, 2, 3, 3], sorted(result)) 3006 3007 3008class 
TestCircularSerialization(TestCase): 3009 class CustomIterDataPipe(IterDataPipe): 3010 @staticmethod 3011 def add_one(x): 3012 return x + 1 3013 3014 @classmethod 3015 def classify(cls, x): 3016 return 0 3017 3018 def add_v(self, x): 3019 return x + self.v 3020 3021 def __init__(self, fn, source_dp=None): 3022 self.fn = fn 3023 self.source_dp = ( 3024 source_dp if source_dp else dp.iter.IterableWrapper([1, 2, 4]) 3025 ) 3026 self._dp = ( 3027 self.source_dp.map(self.add_one) 3028 .map(self.add_v) 3029 .demux(2, self.classify)[0] 3030 ) 3031 self.v = 1 3032 3033 def __iter__(self): 3034 yield from self._dp 3035 3036 def test_circular_serialization_with_pickle(self): 3037 # Test for circular reference issue with pickle 3038 dp1 = TestCircularSerialization.CustomIterDataPipe(fn=_fake_fn) 3039 self.assertTrue(list(dp1) == list(pickle.loads(pickle.dumps(dp1)))) 3040 3041 child_1 = dp1._dp 3042 dm_1 = child_1.main_datapipe 3043 m2_1 = dm_1.main_datapipe 3044 m1_1 = m2_1.datapipe 3045 src_1 = m1_1.datapipe 3046 3047 res1 = traverse_dps(dp1) 3048 exp_res_1 = { 3049 id(dp1): ( 3050 dp1, 3051 { 3052 id(src_1): (src_1, {}), 3053 id(child_1): ( 3054 child_1, 3055 { 3056 id(dm_1): ( 3057 dm_1, 3058 { 3059 id(m2_1): ( 3060 m2_1, 3061 {id(m1_1): (m1_1, {id(src_1): (src_1, {})})}, 3062 ) 3063 }, 3064 ) 3065 }, 3066 ), 3067 }, 3068 ) 3069 } 3070 self.assertEqual(res1, exp_res_1) 3071 dp2 = TestCircularSerialization.CustomIterDataPipe(fn=_fake_fn, source_dp=dp1) 3072 self.assertTrue(list(dp2) == list(pickle.loads(pickle.dumps(dp2)))) 3073 3074 child_2 = dp2._dp 3075 dm_2 = child_2.main_datapipe 3076 m2_2 = dm_2.main_datapipe 3077 m1_2 = m2_2.datapipe 3078 3079 res2 = traverse_dps(dp2) 3080 exp_res_2 = { 3081 id(dp2): ( 3082 dp2, 3083 { 3084 id(dp1): ( 3085 dp1, 3086 { 3087 id(src_1): (src_1, {}), 3088 id(child_1): ( 3089 child_1, 3090 { 3091 id(dm_1): ( 3092 dm_1, 3093 { 3094 id(m2_1): ( 3095 m2_1, 3096 { 3097 id(m1_1): ( 3098 m1_1, 3099 {id(src_1): (src_1, {})}, 3100 ) 3101 }, 3102 ) 3103 }, 3104 ) 3105 }, 3106 ), 3107 }, 3108 ), 3109 id(child_2): ( 3110 child_2, 3111 { 3112 id(dm_2): ( 3113 dm_2, 3114 { 3115 id(m2_2): ( 3116 m2_2, 3117 { 3118 id(m1_2): ( 3119 m1_2, 3120 { 3121 id(dp1): ( 3122 dp1, 3123 { 3124 id(src_1): (src_1, {}), 3125 id(child_1): ( 3126 child_1, 3127 { 3128 id(dm_1): ( 3129 dm_1, 3130 { 3131 id(m2_1): ( 3132 m2_1, 3133 { 3134 id( 3135 m1_1 3136 ): ( 3137 m1_1, 3138 { 3139 id( 3140 src_1 3141 ): ( 3142 src_1, 3143 {}, 3144 ) 3145 }, 3146 ) 3147 }, 3148 ) 3149 }, 3150 ) 3151 }, 3152 ), 3153 }, 3154 ), 3155 }, 3156 ) 3157 }, 3158 ) 3159 }, 3160 ) 3161 }, 3162 ), 3163 }, 3164 ) 3165 } 3166 self.assertEqual(res2, exp_res_2) 3167 3168 class LambdaIterDataPipe(CustomIterDataPipe): 3169 def __init__(self, fn, source_dp=None): 3170 super().__init__(fn, source_dp) 3171 self.container = [ 3172 lambda x: x + 1, 3173 ] 3174 self.lambda_fn = lambda x: x + 1 3175 self._dp = ( 3176 self.source_dp.map(self.add_one) 3177 .map(self.lambda_fn) 3178 .map(self.add_v) 3179 .demux(2, self.classify)[0] 3180 ) 3181 3182 @skipIfNoDill 3183 @skipIf(True, "Dill Tests") 3184 def test_circular_serialization_with_dill(self): 3185 # Test for circular reference issue with dill 3186 dp1 = TestCircularSerialization.LambdaIterDataPipe(lambda x: x + 1) 3187 self.assertTrue(list(dp1) == list(dill.loads(dill.dumps(dp1)))) 3188 3189 child_1 = dp1._dp 3190 dm_1 = child_1.main_datapipe 3191 m2_1 = dm_1.main_datapipe 3192 m1_1 = m2_1.datapipe 3193 src_1 = m1_1.datapipe 3194 3195 res1 = traverse_dps(dp1) 3196 3197 exp_res_1 = { 
3198 id(dp1): ( 3199 dp1, 3200 { 3201 id(src_1): (src_1, {}), 3202 id(child_1): ( 3203 child_1, 3204 { 3205 id(dm_1): ( 3206 dm_1, 3207 { 3208 id(m2_1): ( 3209 m2_1, 3210 {id(m1_1): (m1_1, {id(src_1): (src_1, {})})}, 3211 ) 3212 }, 3213 ) 3214 }, 3215 ), 3216 }, 3217 ) 3218 } 3219 3220 self.assertEqual(res1, exp_res_1) 3221 3222 dp2 = TestCircularSerialization.LambdaIterDataPipe(fn=_fake_fn, source_dp=dp1) 3223 self.assertTrue(list(dp2) == list(dill.loads(dill.dumps(dp2)))) 3224 3225 child_2 = dp2._dp 3226 dm_2 = child_2.main_datapipe 3227 m2_2 = dm_2.main_datapipe 3228 m1_2 = m2_2.datapipe 3229 3230 res2 = traverse_dps(dp2) 3231 exp_res_2 = { 3232 id(dp2): ( 3233 dp2, 3234 { 3235 id(dp1): ( 3236 dp1, 3237 { 3238 id(src_1): (src_1, {}), 3239 id(child_1): ( 3240 child_1, 3241 { 3242 id(dm_1): ( 3243 dm_1, 3244 { 3245 id(m2_1): ( 3246 m2_1, 3247 { 3248 id(m1_1): ( 3249 m1_1, 3250 {id(src_1): (src_1, {})}, 3251 ) 3252 }, 3253 ) 3254 }, 3255 ) 3256 }, 3257 ), 3258 }, 3259 ), 3260 id(child_2): ( 3261 child_2, 3262 { 3263 id(dm_2): ( 3264 dm_2, 3265 { 3266 id(m2_2): ( 3267 m2_2, 3268 { 3269 id(m1_2): ( 3270 m1_2, 3271 { 3272 id(dp1): ( 3273 dp1, 3274 { 3275 id(src_1): (src_1, {}), 3276 id(child_1): ( 3277 child_1, 3278 { 3279 id(dm_1): ( 3280 dm_1, 3281 { 3282 id(m2_1): ( 3283 m2_1, 3284 { 3285 id( 3286 m1_1 3287 ): ( 3288 m1_1, 3289 { 3290 id( 3291 src_1 3292 ): ( 3293 src_1, 3294 {}, 3295 ) 3296 }, 3297 ) 3298 }, 3299 ) 3300 }, 3301 ) 3302 }, 3303 ), 3304 }, 3305 ), 3306 }, 3307 ) 3308 }, 3309 ) 3310 }, 3311 ) 3312 }, 3313 ), 3314 }, 3315 ) 3316 } 3317 self.assertEqual(res2, exp_res_2) 3318 3319 3320class CustomShardingIterDataPipe(IterDataPipe): 3321 def __init__(self, dp): 3322 self.dp = dp 3323 self.num_of_instances = 1 3324 self.instance_id = 0 3325 3326 def apply_sharding(self, num_of_instances, instance_id): 3327 self.num_of_instances = num_of_instances 3328 self.instance_id = instance_id 3329 3330 def __iter__(self): 3331 for i, d in enumerate(self.dp): 3332 if i % self.num_of_instances == self.instance_id: 3333 yield d 3334 3335 3336class TestSharding(TestCase): 3337 def _get_pipeline(self): 3338 numbers_dp = NumbersDataset(size=10) 3339 dp0, dp1 = numbers_dp.fork(num_instances=2) 3340 dp0_upd = dp0.map(_mul_10) 3341 dp1_upd = dp1.filter(_mod_3_test) 3342 combined_dp = dp0_upd.mux(dp1_upd) 3343 return combined_dp 3344 3345 def _get_dill_pipeline(self): 3346 numbers_dp = NumbersDataset(size=10) 3347 dp0, dp1 = numbers_dp.fork(num_instances=2) 3348 dp0_upd = dp0.map(lambda x: x * 10) 3349 dp1_upd = dp1.filter(lambda x: x % 3 == 1) 3350 combined_dp = dp0_upd.mux(dp1_upd) 3351 return combined_dp 3352 3353 def test_simple_sharding(self): 3354 sharded_dp = self._get_pipeline().sharding_filter() 3355 torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, 1) 3356 items = list(sharded_dp) 3357 self.assertEqual([1, 20], items) 3358 3359 all_items = [0, 1, 10, 4, 20, 7] 3360 items = [] 3361 for i in range(3): 3362 sharded_dp = self._get_pipeline().sharding_filter() 3363 torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, i) 3364 items += list(sharded_dp) 3365 self.assertEqual(sorted(all_items), sorted(items)) 3366 3367 def test_sharding_groups(self): 3368 def construct_sharded_pipe(): 3369 sharding_pipes = [] 3370 dp = NumbersDataset(size=90) 3371 dp = dp.sharding_filter( 3372 sharding_group_filter=SHARDING_PRIORITIES.DISTRIBUTED 3373 ) 3374 sharding_pipes.append(dp) 3375 dp = dp.sharding_filter( 3376 sharding_group_filter=SHARDING_PRIORITIES.MULTIPROCESSING 3377 ) 3378 
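# Each sharding_filter stacked here is tagged with a different sharding group
# (DISTRIBUTED above, MULTIPROCESSING here, and the custom group 300 added next),
# so apply_sharding can configure each stage independently; the per-group
# (num_of_instances, instance_id) settings then compose along the pipeline,
# which is why only 90 / (2 * 5 * 3) = 3 elements survive in the assertions below.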
sharding_pipes.append(dp) 3379 dp = dp.sharding_filter(sharding_group_filter=300) 3380 sharding_pipes.append(dp) 3381 return dp, sharding_pipes 3382 3383 dp, sharding_pipes = construct_sharded_pipe() 3384 3385 for pipe in sharding_pipes: 3386 pipe.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DISTRIBUTED) 3387 pipe.apply_sharding( 3388 5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING 3389 ) 3390 pipe.apply_sharding(3, 1, sharding_group=300) 3391 3392 actual = list(dp) 3393 expected = [17, 47, 77] 3394 self.assertEqual(expected, actual) 3395 self.assertEqual(3, len(dp)) 3396 3397 dp, _ = construct_sharded_pipe() 3398 dp.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DEFAULT) 3399 with self.assertRaises(Exception): 3400 dp.apply_sharding(5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING) 3401 3402 dp, _ = construct_sharded_pipe() 3403 dp.apply_sharding(5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING) 3404 with self.assertRaises(Exception): 3405 dp.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DEFAULT) 3406 3407 # Test tud.datapipes.iter.grouping.SHARDING_PRIORITIES for backward compatibility 3408 # TODO: Remove this test once tud.datapipes.iter.grouping.SHARDING_PRIORITIES is deprecated 3409 def test_sharding_groups_in_legacy_grouping_package(self): 3410 with self.assertWarnsRegex( 3411 FutureWarning, 3412 r"Please use `SHARDING_PRIORITIES` " 3413 "from the `torch.utils.data.datapipes.iter.sharding`", 3414 ): 3415 from torch.utils.data.datapipes.iter.grouping import ( 3416 SHARDING_PRIORITIES as LEGACY_SHARDING_PRIORITIES, 3417 ) 3418 3419 def construct_sharded_pipe(): 3420 sharding_pipes = [] 3421 dp = NumbersDataset(size=90) 3422 dp = dp.sharding_filter( 3423 sharding_group_filter=LEGACY_SHARDING_PRIORITIES.DISTRIBUTED 3424 ) 3425 sharding_pipes.append(dp) 3426 dp = dp.sharding_filter( 3427 sharding_group_filter=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING 3428 ) 3429 sharding_pipes.append(dp) 3430 dp = dp.sharding_filter(sharding_group_filter=300) 3431 sharding_pipes.append(dp) 3432 return dp, sharding_pipes 3433 3434 dp, sharding_pipes = construct_sharded_pipe() 3435 3436 for pipe in sharding_pipes: 3437 pipe.apply_sharding( 3438 2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DISTRIBUTED 3439 ) 3440 pipe.apply_sharding( 3441 5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING 3442 ) 3443 pipe.apply_sharding(3, 1, sharding_group=300) 3444 3445 actual = list(dp) 3446 expected = [17, 47, 77] 3447 self.assertEqual(expected, actual) 3448 self.assertEqual(3, len(dp)) 3449 3450 dp, _ = construct_sharded_pipe() 3451 dp.apply_sharding(2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DEFAULT) 3452 with self.assertRaises(Exception): 3453 dp.apply_sharding( 3454 5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING 3455 ) 3456 3457 dp, _ = construct_sharded_pipe() 3458 dp.apply_sharding( 3459 5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING 3460 ) 3461 with self.assertRaises(Exception): 3462 dp.apply_sharding(2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DEFAULT) 3463 3464 def test_legacy_custom_sharding(self): 3465 dp = self._get_pipeline() 3466 sharded_dp = CustomShardingIterDataPipe(dp) 3467 torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, 1) 3468 items = list(sharded_dp) 3469 self.assertEqual([1, 20], items) 3470 3471 def test_sharding_length(self): 3472 numbers_dp = dp.iter.IterableWrapper(range(13)) 3473 sharded_dp0 = numbers_dp.sharding_filter() 3474
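# With 13 inputs dealt round-robin across 3 shards, shard 0 keeps indices
# 0, 3, 6, 9, 12 (5 elements) while shards 1 and 2 keep 4 elements each,
# which is what the length assertions below verify.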
torch.utils.data.graph_settings.apply_sharding(sharded_dp0, 3, 0) 3475 sharded_dp1 = numbers_dp.sharding_filter() 3476 torch.utils.data.graph_settings.apply_sharding(sharded_dp1, 3, 1) 3477 sharded_dp2 = numbers_dp.sharding_filter() 3478 torch.utils.data.graph_settings.apply_sharding(sharded_dp2, 3, 2) 3479 self.assertEqual(13, len(numbers_dp)) 3480 self.assertEqual(5, len(sharded_dp0)) 3481 self.assertEqual(4, len(sharded_dp1)) 3482 self.assertEqual(4, len(sharded_dp2)) 3483 3484 numbers_dp = dp.iter.IterableWrapper(range(1)) 3485 sharded_dp0 = numbers_dp.sharding_filter() 3486 torch.utils.data.graph_settings.apply_sharding(sharded_dp0, 2, 0) 3487 sharded_dp1 = numbers_dp.sharding_filter() 3488 torch.utils.data.graph_settings.apply_sharding(sharded_dp1, 2, 1) 3489 self.assertEqual(1, len(sharded_dp0)) 3490 self.assertEqual(0, len(sharded_dp1)) 3491 3492 def test_old_dataloader(self): 3493 dp0 = self._get_pipeline() 3494 expected = list(dp0) 3495 3496 dp0 = self._get_pipeline().sharding_filter() 3497 dl = DataLoader(dp0, batch_size=1, shuffle=False, num_workers=2) 3498 items = list(dl) 3499 3500 self.assertEqual(sorted(expected), sorted(items)) 3501 3502 def test_legacy_custom_sharding_with_old_dataloader(self): 3503 dp0 = self._get_pipeline() 3504 expected = list(dp0) 3505 3506 dp0 = self._get_pipeline() 3507 dp0 = CustomShardingIterDataPipe(dp0) 3508 dl = DataLoader(dp0, batch_size=1, shuffle=False, num_workers=2) 3509 items = list(dl) 3510 3511 self.assertEqual(sorted(expected), sorted(items)) 3512 3513 def test_multi_sharding(self): 3514 # Raises Error when multiple sharding on the single branch 3515 numbers_dp = dp.iter.IterableWrapper(range(13)) 3516 sharded_dp = numbers_dp.sharding_filter() 3517 sharded_dp = sharded_dp.sharding_filter() 3518 with self.assertRaisesRegex( 3519 RuntimeError, "Sharding twice on a single pipeline" 3520 ): 3521 torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, 0) 3522 3523 # Raises Error when sharding on both data source and branch 3524 numbers_dp = dp.iter.IterableWrapper(range(13)).sharding_filter() 3525 dp1, dp2 = numbers_dp.fork(2) 3526 sharded_dp = dp1.sharding_filter() 3527 zip_dp = dp2.zip(sharded_dp) 3528 with self.assertRaisesRegex( 3529 RuntimeError, "Sharding twice on a single pipeline" 3530 ): 3531 torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0) 3532 3533 # Raises Error when multiple sharding on the branch and end 3534 numbers_dp = dp.iter.IterableWrapper(range(13)) 3535 dp1, dp2 = numbers_dp.fork(2) 3536 sharded_dp = dp1.sharding_filter() 3537 zip_dp = dp2.zip(sharded_dp).sharding_filter() 3538 with self.assertRaisesRegex( 3539 RuntimeError, "Sharding twice on a single pipeline" 3540 ): 3541 torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0) 3542 3543 # Single sharding_filter on data source 3544 numbers_dp = dp.iter.IterableWrapper(range(13)).sharding_filter() 3545 dp1, dp2 = numbers_dp.fork(2) 3546 zip_dp = dp1.zip(dp2) 3547 torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0) 3548 self.assertEqual(list(zip_dp), [(i * 3, i * 3) for i in range(13 // 3 + 1)]) 3549 3550 # Single sharding_filter per branch 3551 numbers_dp = dp.iter.IterableWrapper(range(13)) 3552 dp1, dp2 = numbers_dp.fork(2) 3553 sharded_dp1 = dp1.sharding_filter() 3554 sharded_dp2 = dp2.sharding_filter() 3555 zip_dp = sharded_dp1.zip(sharded_dp2) 3556 torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0) 3557 self.assertEqual(list(zip_dp), [(i * 3, i * 3) for i in range(13 // 3 + 1)]) 3558 3559 3560class 
TestIterDataPipeSingletonConstraint(TestCase): 3561 r""" 3562 Each `IterDataPipe` can only have one active iterator. Whenever a new iterator is created, older 3563 iterators are invalidated. These tests aim to ensure `IterDataPipe` follows this behavior. 3564 """ 3565 3566 def _check_single_iterator_invalidation_logic(self, source_dp: IterDataPipe): 3567 r""" 3568 Given an IterDataPipe, verifies that the iterator can be read, reset, and the creation of 3569 a second iterator invalidates the first one. 3570 """ 3571 it1 = iter(source_dp) 3572 self.assertEqual(list(range(10)), list(it1)) 3573 it1 = iter(source_dp) 3574 self.assertEqual( 3575 list(range(10)), list(it1) 3576 ) # A fresh iterator can be read in full again 3577 it1 = iter(source_dp) 3578 self.assertEqual(0, next(it1)) 3579 it2 = iter(source_dp) # This should invalidate `it1` 3580 self.assertEqual(0, next(it2)) # Should read from the beginning again 3581 with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"): 3582 next(it1) 3583 3584 def test_iterdatapipe_singleton_generator(self): 3585 r""" 3586 Testing for the case where IterDataPipe's `__iter__` is a generator function. 3587 """ 3588 3589 # Functional Test: Check if invalidation logic is correct 3590 source_dp: IterDataPipe = dp.iter.IterableWrapper(range(10)) 3591 self._check_single_iterator_invalidation_logic(source_dp) 3592 3593 # Functional Test: extend the test to a pipeline 3594 dps = source_dp.map(_fake_fn).filter(_fake_filter_fn) 3595 self._check_single_iterator_invalidation_logic(dps) 3596 3597 # Functional Test: multiple simultaneous references to the same DataPipe fails 3598 with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"): 3599 for _ in zip(source_dp, source_dp): 3600 pass 3601 3602 # Functional Test: sequential references work 3603 for _ in zip(list(source_dp), list(source_dp)): 3604 pass 3605 3606 def test_iterdatapipe_singleton_self_next(self): 3607 r""" 3608 Testing for the case where IterDataPipe's `__iter__` returns `self` and there is a `__next__` method. 3609 Note that the following DataPipe is a singleton by default (because `__iter__` returns `self`). 3610 """
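# Illustrative sketch of the pattern exercised below (it mirrors the checks that follow,
# using only the `_CustomIterDP_Self` class defined next): since `__iter__` resets the
# wrapped iterator and returns `self`, re-iterating reuses the same object, e.g.
#     source_dp = _CustomIterDP_Self(range(3))
#     assert list(source_dp) == [0, 1, 2]
#     assert list(iter(source_dp)) == [0, 1, 2]  # `iter()` resets and returns the same DataPipe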

    def test_iterdatapipe_singleton_self_next(self):
        r"""
        Testing for the case where IterDataPipe's `__iter__` returns `self` and there is a `__next__` method.
        Note that the following DataPipe is a singleton by default (because `__iter__` returns `self`).
        """

        class _CustomIterDP_Self(IterDataPipe):
            def __init__(self, iterable):
                self.source = iterable
                self.iterable = iter(iterable)

            def __iter__(self):
                self.reset()
                return self

            def __next__(self):
                return next(self.iterable)

            def reset(self):
                self.iterable = iter(self.source)

        # Functional Test: Check that every `__iter__` call returns the same object
        source_dp = _CustomIterDP_Self(range(10))
        res = list(source_dp)
        it = iter(source_dp)
        self.assertEqual(res, list(it))

        # Functional Test: Check if invalidation logic is correct
        source_dp = _CustomIterDP_Self(range(10))
        self._check_single_iterator_invalidation_logic(source_dp)
        self.assertEqual(
            1, next(source_dp)
        )  # `source_dp` is still valid and can be read

        # Functional Test: extend the test to a pipeline
        source_dp = _CustomIterDP_Self(
            dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn)
        )
        self._check_single_iterator_invalidation_logic(source_dp)
        self.assertEqual(
            1, next(source_dp)
        )  # `source_dp` is still valid and can be read

        # Functional Test: multiple simultaneous references to the same DataPipe fails
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            for _ in zip(source_dp, source_dp):
                pass

    def test_iterdatapipe_singleton_new_object(self):
        r"""
        Testing for the case where IterDataPipe's `__iter__` is not a generator and does not
        return `self`, and there isn't a `__next__` method.
        """

        class _CustomIterDP(IterDataPipe):
            def __init__(self, iterable):
                self.iterable = iter(iterable)

            def __iter__(self):  # Note that this doesn't reset
                return self.iterable  # Intentionally not returning `self`

        # Functional Test: Check if invalidation logic is correct
        source_dp = _CustomIterDP(range(10))
        it1 = iter(source_dp)
        self.assertEqual(0, next(it1))
        it2 = iter(source_dp)
        self.assertEqual(1, next(it2))
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it1)

        # Functional Test: extend the test to a pipeline
        source_dp = _CustomIterDP(
            dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn)
        )
        it1 = iter(source_dp)
        self.assertEqual(0, next(it1))
        it2 = iter(source_dp)
        self.assertEqual(1, next(it2))
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it1)

        # Functional Test: multiple simultaneous references to the same DataPipe fails
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            for _ in zip(source_dp, source_dp):
                pass

    def test_iterdatapipe_singleton_buggy(self):
        r"""
        Buggy test case where IterDataPipe's `__iter__` returns a new object, but also has
        a `__next__` method.
        """

        class _CustomIterDP(IterDataPipe):
            def __init__(self, iterable):
                self.source = iterable
                self.iterable = iter(iterable)

            def __iter__(self):
                return iter(self.source)  # Intentionally not returning `self`

            def __next__(self):
                return next(self.iterable)

        # Functional Test: Check if invalidation logic is correct
        source_dp = _CustomIterDP(range(10))
        self._check_single_iterator_invalidation_logic(source_dp)
        self.assertEqual(0, next(source_dp))  # `__next__` is unrelated to `__iter__`

        # Functional Test: Special case to show `__next__` is unrelated to `__iter__`
        source_dp = _CustomIterDP(range(10))
        self.assertEqual(0, next(source_dp))
        it1 = iter(source_dp)
        self.assertEqual(0, next(it1))
        self.assertEqual(1, next(source_dp))
        it2 = iter(source_dp)  # invalidates `it1`
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it1)
        self.assertEqual(2, next(source_dp))  # not impacted by the creation of `it2`
        self.assertEqual(
            list(range(10)), list(it2)
        )  # `it2` still works because it is a new object
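
    # `fork` produces child DataPipes that read from one shared pass over the
    # source. Creating a fresh iterator on any child (or on the source itself)
    # resets that shared state, so iterators previously obtained from the other
    # children are invalidated as well; the test below exercises exactly this.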

    def test_iterdatapipe_singleton_constraint_multiple_outputs(self):
        r"""
        Testing for the case where IterDataPipe has multiple child DataPipes as outputs.
        """
        # Functional Test: all previous related iterators should be invalidated when a new iterator
        # is created from a ChildDataPipe
        source_dp: IterDataPipe = dp.iter.IterableWrapper(range(10))
        cdp1, cdp2 = source_dp.fork(num_instances=2)
        it1, it2 = iter(cdp1), iter(cdp2)
        self.assertEqual(list(range(10)), list(it1))
        self.assertEqual(list(range(10)), list(it2))
        it1, it2 = iter(cdp1), iter(cdp2)
        with warnings.catch_warnings(record=True) as wa:
            it3 = iter(cdp1)  # This should invalidate `it1` and `it2`
            self.assertEqual(len(wa), 1)
            self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it1)
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it2)
        self.assertEqual(0, next(it3))
        # The next line should not invalidate anything, as there was no new iterator created
        # for `cdp2` after `it2` was invalidated
        it4 = iter(cdp2)
        self.assertEqual(1, next(it3))  # An error shouldn't be raised here
        self.assertEqual(list(range(10)), list(it4))

        # Functional Test: invalidation when a new iterator is created from `source_dp`
        source_dp = dp.iter.IterableWrapper(range(10))
        cdp1, cdp2 = source_dp.fork(num_instances=2)
        it1, it2 = iter(cdp1), iter(cdp2)
        self.assertEqual(list(range(10)), list(it1))
        self.assertEqual(list(range(10)), list(it2))
        it1, it2 = iter(cdp1), iter(cdp2)
        self.assertEqual(0, next(it1))
        self.assertEqual(0, next(it2))
        it3 = iter(source_dp)  # note that a new iterator is created from `source_dp`
        self.assertEqual(
            0, next(it3)
        )  # `it3` should invalidate `it1` and `it2` since they both use `source_dp`
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it1)
        self.assertEqual(1, next(it3))

        # Functional Test: Extending the test to a pipeline
        source_dp = (
            dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn)
        )
        cdp1, cdp2 = source_dp.fork(num_instances=2)
        it1, it2 = iter(cdp1), iter(cdp2)
        self.assertEqual(list(range(10)), list(it1))
        self.assertEqual(list(range(10)), list(it2))
        it1, it2 = iter(cdp1), iter(cdp2)
        with warnings.catch_warnings(record=True) as wa:
            it3 = iter(cdp1)  # This should invalidate `it1` and `it2`
            self.assertEqual(len(wa), 1)
            self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it1)
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it2)
        with warnings.catch_warnings(record=True) as wa:
            it1, it2 = iter(cdp1), iter(cdp2)
            self.assertEqual(len(wa), 1)
            self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")
        self.assertEqual(0, next(it1))
        self.assertEqual(0, next(it2))
        it3 = iter(source_dp)  # note that a new iterator is created from `source_dp`
        self.assertEqual(
            0, next(it3)
        )  # `it3` should invalidate `it1` and `it2` since they both use `source_dp`
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it1)
        self.assertEqual(1, next(it3))
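

# Illustrative sketch (not part of the original test suite): `IterDataPipe`
# keeps a `_number_of_samples_yielded` counter for its current iterator, which
# the tests below rely on. A minimal sketch of the behavior being verified;
# the helper name is made up for illustration only.
def _samples_yielded_sketch():
    pipe = dp.iter.IterableWrapper(range(10))
    it = iter(pipe)
    for _ in range(4):
        next(it)
    # Expected to be 4 after four `next` calls, per TestIterDataPipeCountSampleYielded
    return pipe._number_of_samples_yielded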


class TestIterDataPipeCountSampleYielded(TestCase):
    def _yield_count_test_helper(self, datapipe, n_expected_samples):
        # Functional Test: Check if number of samples yielded is as expected
        res = list(datapipe)
        self.assertEqual(len(res), datapipe._number_of_samples_yielded)

        # Functional Test: Check if the count is correct when DataPipe is partially read
        it = iter(datapipe)
        res = []
        for i, value in enumerate(it):
            res.append(value)
            if i == n_expected_samples - 1:
                break
        self.assertEqual(n_expected_samples, datapipe._number_of_samples_yielded)

        # Functional Test: Check for reset behavior and if iterator also works
        it = iter(datapipe)  # reset the DataPipe
        res = list(it)
        self.assertEqual(len(res), datapipe._number_of_samples_yielded)

    def test_iterdatapipe_sample_yielded_generator_function(self):
        # Functional Test: `__iter__` is a generator function
        datapipe: IterDataPipe = dp.iter.IterableWrapper(range(10))
        self._yield_count_test_helper(datapipe, n_expected_samples=5)

    def test_iterdatapipe_sample_yielded_generator_function_exception(self):
        # Functional Test: `__iter__` is a custom generator function with exception
        class _CustomGeneratorFnDataPipe(IterDataPipe):
            # This class's `__iter__` has a Runtime Error
            def __iter__(self):
                yield 0
                yield 1
                yield 2
                raise RuntimeError("Custom test error after yielding 3 elements")
                yield 3

        # Functional Test: Ensure the count is correct even when exception is raised
        datapipe: IterDataPipe = _CustomGeneratorFnDataPipe()
        with self.assertRaisesRegex(
            RuntimeError, "Custom test error after yielding 3 elements"
        ):
            list(datapipe)
        self.assertEqual(3, datapipe._number_of_samples_yielded)

        # Functional Test: Check for reset behavior and if iterator also works
        it = iter(datapipe)  # reset the DataPipe
        with self.assertRaisesRegex(
            RuntimeError, "Custom test error after yielding 3 elements"
        ):
            list(it)
        self.assertEqual(3, datapipe._number_of_samples_yielded)

    def test_iterdatapipe_sample_yielded_return_self(self):
        class _CustomGeneratorDataPipe(IterDataPipe):
            # This class's `__iter__` is not a generator function
            def __init__(self) -> None:
                self.source = iter(range(10))

            def __iter__(self):
                return self.source

            def reset(self):
                self.source = iter(range(10))

        datapipe: IterDataPipe = _CustomGeneratorDataPipe()
        self._yield_count_test_helper(datapipe, n_expected_samples=5)

    def test_iterdatapipe_sample_yielded_next(self):
        class _CustomNextDataPipe(IterDataPipe):
            # This class's `__iter__` returns `self` and has a `__next__`
            def __init__(self) -> None:
                self.source = iter(range(10))

            def __iter__(self):
                return self

            def __next__(self):
                return next(self.source)

            def reset(self):
                self.source = iter(range(10))

        datapipe: IterDataPipe = _CustomNextDataPipe()
        self._yield_count_test_helper(datapipe, n_expected_samples=5)

    def test_iterdatapipe_sample_yielded_next_exception(self):
        class _CustomNextDataPipe(IterDataPipe):
            # This class's `__iter__` returns `self` and has a `__next__`
            def __init__(self) -> None:
                self.source = iter(range(10))
                self.count = 0

            def __iter__(self):
                return self

            def __next__(self):
                if self.count == 3:
                    raise RuntimeError("Custom test error after yielding 3 elements")
                self.count += 1
                return next(self.source)

            def reset(self):
                self.count = 0
                self.source = iter(range(10))

        # Functional Test: Ensure the count is correct even when exception is raised
        datapipe: IterDataPipe = _CustomNextDataPipe()
        with self.assertRaisesRegex(
            RuntimeError, "Custom test error after yielding 3 elements"
        ):
            list(datapipe)
        self.assertEqual(3, datapipe._number_of_samples_yielded)

        # Functional Test: Check for reset behavior and if iterator also works
        it = iter(datapipe)  # reset the DataPipe
        with self.assertRaisesRegex(
            RuntimeError, "Custom test error after yielding 3 elements"
        ):
            list(it)
        self.assertEqual(3, datapipe._number_of_samples_yielded)


class _CustomNonGeneratorTestDataPipe(IterDataPipe):
    def __init__(self) -> None:
        self.n = 10
        self.source = list(range(self.n))

    # This class's `__iter__` is not a generator function
    def __iter__(self):
        return iter(self.source)

    def __len__(self):
        return self.n


class _CustomSelfNextTestDataPipe(IterDataPipe):
    def __init__(self) -> None:
        self.n = 10
        self.iter = iter(range(self.n))

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.iter)

    def reset(self):
        self.iter = iter(range(self.n))

    def __len__(self):
        return self.n
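

# Illustrative sketch (not part of the original test suite):
# `_simple_graph_snapshot_restoration` fast-forwards a DataPipe graph by
# replaying `n_iterations` samples, so a subsequent read starts from element
# `n_iterations`. A minimal sketch of that behavior for a graph without a
# shuffler (so no RNG seeding is needed); the helper name is made up.
def _fast_forward_sketch(n_iterations=3):
    graph = dp.iter.IterableWrapper(range(10)).map(_mul_10)
    _simple_graph_snapshot_restoration(graph, n_iterations)
    # Expected: [30, 40, 50, 60, 70, 80, 90] for n_iterations=3
    return list(graph)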


class TestIterDataPipeGraphFastForward(TestCase):
    def _fast_forward_graph_test_helper(
        self, datapipe, fast_forward_fn, expected_res, n_iterations=3, rng=None
    ):
        if rng is None:
            rng = torch.Generator()
        rng = rng.manual_seed(0)
        torch.utils.data.graph_settings.apply_random_seed(datapipe, rng)

        # Test Case: fast forward works with list
        rng.manual_seed(0)
        fast_forward_fn(datapipe, n_iterations, rng)
        actual_res = list(datapipe)
        self.assertEqual(len(datapipe) - n_iterations, len(actual_res))
        self.assertEqual(expected_res[n_iterations:], actual_res)

        # Test Case: fast forward works with iterator
        rng.manual_seed(0)
        fast_forward_fn(datapipe, n_iterations, rng)
        it = iter(datapipe)
        actual_res = list(it)
        self.assertEqual(len(datapipe) - n_iterations, len(actual_res))
        self.assertEqual(expected_res[n_iterations:], actual_res)
        with self.assertRaises(StopIteration):
            next(it)

    def test_simple_snapshot_graph(self):
        graph1 = dp.iter.IterableWrapper(range(10))
        res1 = list(range(10))
        self._fast_forward_graph_test_helper(
            graph1, _simple_graph_snapshot_restoration, expected_res=res1
        )

        graph2 = graph1.map(_mul_10)
        res2 = [10 * x for x in res1]
        self._fast_forward_graph_test_helper(
            graph2, _simple_graph_snapshot_restoration, expected_res=res2
        )

        rng = torch.Generator()
        graph3 = graph2.shuffle()
        rng.manual_seed(0)
        torch.utils.data.graph_settings.apply_random_seed(graph3, rng)
        res3 = list(graph3)
        self._fast_forward_graph_test_helper(
            graph3, _simple_graph_snapshot_restoration, expected_res=res3
        )

        graph4 = graph3.map(_mul_10)
        res4 = [10 * x for x in res3]
        self._fast_forward_graph_test_helper(
            graph4, _simple_graph_snapshot_restoration, expected_res=res4
        )

        batch_size = 2
        graph5 = graph4.batch(batch_size)
        res5 = [
            res4[i : i + batch_size] for i in range(0, len(res4), batch_size)
        ]  # .batch(2)
        self._fast_forward_graph_test_helper(
            graph5, _simple_graph_snapshot_restoration, expected_res=res5
        )

        # With `fork` and `zip`
        cdp1, cdp2 = graph5.fork(2)
        graph6 = cdp1.zip(cdp2)
        rng = rng.manual_seed(100)
        torch.utils.data.graph_settings.apply_random_seed(graph6, rng)
        res6 = [(x, x) for x in res5]
        self._fast_forward_graph_test_helper(
            graph6, _simple_graph_snapshot_restoration, expected_res=res6
        )

        # With `fork` and `concat`
        graph7 = cdp1.concat(cdp2)
        res7 = res5 * 2
        self._fast_forward_graph_test_helper(
            graph7, _simple_graph_snapshot_restoration, expected_res=res7
        )

        # Raises an exception if the graph has already been restored
        with self.assertRaisesRegex(
            RuntimeError, "Snapshot restoration cannot be applied."
        ):
            _simple_graph_snapshot_restoration(graph7, 1)
            _simple_graph_snapshot_restoration(graph7, 1)

    def test_simple_snapshot_custom_non_generator(self):
        graph = _CustomNonGeneratorTestDataPipe()
        self._fast_forward_graph_test_helper(
            graph, _simple_graph_snapshot_restoration, expected_res=range(10)
        )

    def test_simple_snapshot_custom_self_next(self):
        graph = _CustomSelfNextTestDataPipe()
        self._fast_forward_graph_test_helper(
            graph, _simple_graph_snapshot_restoration, expected_res=range(10)
        )
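
    # The serialization tests below combine this fast-forward mechanism with
    # pickling: a partially consumed graph is pickled, and the deserialized copy
    # is fast-forwarded by `_number_of_samples_yielded` (with a freshly seeded
    # RNG) so that it resumes from the same position as the original iterator.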

    def _snapshot_test_helper(self, datapipe, expected_res, n_iter=3, rng=None):
        """
        Extend the previous tests with a serialization and deserialization round trip.
        """
        if rng is None:
            rng = torch.Generator()
        rng.manual_seed(0)
        torch.utils.data.graph_settings.apply_random_seed(datapipe, rng)
        it = iter(datapipe)
        for _ in range(n_iter):
            next(it)
        serialized_graph = pickle.dumps(datapipe)
        deserialized_graph = pickle.loads(serialized_graph)
        self.assertEqual(n_iter, datapipe._number_of_samples_yielded)
        self.assertEqual(n_iter, deserialized_graph._number_of_samples_yielded)

        rng_for_deserialized = torch.Generator()
        rng_for_deserialized.manual_seed(0)
        _simple_graph_snapshot_restoration(
            deserialized_graph, n_iter, rng=rng_for_deserialized
        )
        self.assertEqual(expected_res[n_iter:], list(it))
        self.assertEqual(expected_res[n_iter:], list(deserialized_graph))

    def test_simple_snapshot_graph_with_serialization(self):
        graph1 = dp.iter.IterableWrapper(range(10))
        res1 = list(range(10))
        self._snapshot_test_helper(graph1, expected_res=res1)

        graph2 = graph1.map(_mul_10)
        res2 = [10 * x for x in res1]
        self._snapshot_test_helper(graph2, expected_res=res2)

        rng = torch.Generator()
        graph3 = graph2.shuffle()
        rng.manual_seed(0)
        torch.utils.data.graph_settings.apply_random_seed(graph3, rng)
        res3 = list(graph3)
        self._snapshot_test_helper(graph3, expected_res=res3)

        graph4 = graph3.map(_mul_10)
        res4 = [10 * x for x in res3]
        self._snapshot_test_helper(graph4, expected_res=res4)

        batch_size = 2
        graph5 = graph4.batch(batch_size)
        res5 = [
            res4[i : i + batch_size] for i in range(0, len(res4), batch_size)
        ]  # .batch(2)
        self._snapshot_test_helper(graph5, expected_res=res5)

        # With `fork` and `zip`
        cdp1, cdp2 = graph5.fork(2)
        graph6 = cdp1.zip(cdp2)
        res6 = [(x, x) for x in res5]
        self._snapshot_test_helper(graph6, expected_res=res6)

        # With `fork` and `concat`
        graph7 = cdp1.concat(cdp2)
        res7 = res5 * 2
        self._snapshot_test_helper(graph7, expected_res=res7)

    def test_simple_snapshot_graph_repeated(self):
        cdp1, cdp2 = (
            dp.iter.IterableWrapper(range(10))
            .map(_mul_10)
            .shuffle()
            .map(_mul_10)
            .map(_mul_10)
            .fork(2)
        )
        graph = cdp1.zip(cdp2)

        rng = torch.Generator()
        rng.manual_seed(0)
        torch.utils.data.graph_settings.apply_random_seed(graph, rng)

        # Get expected result
        expected_res = list(graph)

        rng.manual_seed(0)
        torch.utils.data.graph_settings.apply_random_seed(graph, rng)
        it = iter(graph)
        n_iter = 3
        for _ in range(n_iter):
            next(it)

        # First serialization/deserialization
        serialized_graph = pickle.dumps(graph)
        deserialized_graph = pickle.loads(serialized_graph)

        rng_for_deserialized = torch.Generator()
        rng_for_deserialized.manual_seed(0)
        _simple_graph_snapshot_restoration(
            deserialized_graph,
            deserialized_graph._number_of_samples_yielded,
            rng=rng_for_deserialized,
        )

        it = iter(deserialized_graph)
        # Get the next element and ensure it is as expected
        self.assertEqual(expected_res[3], next(it))

        # Serialize/deserialize and fast-forward again to ensure it still works
        serialized_graph2 = pickle.dumps(deserialized_graph)
        deserialized_graph2 = pickle.loads(serialized_graph2)

        rng_for_deserialized = torch.Generator()
        rng_for_deserialized.manual_seed(0)
        _simple_graph_snapshot_restoration(
            deserialized_graph2,
            deserialized_graph._number_of_samples_yielded,
            rng=rng_for_deserialized,
        )

        # Ensure the remaining elements are as expected
        self.assertEqual(expected_res[4:], list(deserialized_graph2))


if __name__ == "__main__":
    run_tests()