1{ 2 "metadata": { 3 "language_info": { 4 "codemirror_mode": { 5 "name": "ipython", 6 "version": 3 7 }, 8 "file_extension": ".py", 9 "mimetype": "text/x-python", 10 "name": "python", 11 "nbconvert_exporter": "python", 12 "pygments_lexer": "ipython3", 13 "version": "3.6.10" 14 }, 15 "orig_nbformat": 2, 16 "kernelspec": { 17 "name": "python3610jvsc74a57bd0eb5e09632d6ea1cbf3eb9da7e37b7cf581db5ed13074b21cc44e159dc62acdab", 18 "display_name": "Python 3.6.10 64-bit ('dataloader': conda)" 19 } 20 }, 21 "nbformat": 4, 22 "nbformat_minor": 2, 23 "cells": [ 24 { 25 "source": [ 26 "## Standard flow control and data processing DataPipes" 27 ], 28 "cell_type": "markdown", 29 "metadata": {} 30 }, 31 { 32 "cell_type": "code", 33 "execution_count": 1, 34 "metadata": {}, 35 "outputs": [], 36 "source": [ 37 "from torch.utils.data import IterDataPipe" 38 ] 39 }, 40 { 41 "cell_type": "code", 42 "execution_count": 2, 43 "metadata": {}, 44 "outputs": [], 45 "source": [ 46 "# Example IterDataPipe\n", 47 "class ExampleIterPipe(IterDataPipe):\n", 48 " def __init__(self, range = 20):\n", 49 " self.range = range\n", 50 " def __iter__(self):\n", 51 " for i in range(self.range):\n", 52 " yield i" 53 ] 54 }, 55 { 56 "source": [ 57 "## Batch\n", 58 "\n", 59 "Function: `batch`\n", 60 "\n", 61 "Description: \n", 62 "\n", 63 "Alternatives:\n", 64 "\n", 65 "Arguments:\n", 66 " - `batch_size: int` desired batch size\n", 67 " - `unbatch_level:int = 0` if specified calls `unbatch(unbatch_level=unbatch_level)` on source datapipe before batching (see `unbatch`)\n", 68 " - `drop_last: bool = False`\n", 69 "\n", 70 "Example:\n", 71 "\n", 72 "Classic batching produce partial batches by default\n" 73 ], 74 "cell_type": "markdown", 75 "metadata": {} 76 }, 77 { 78 "cell_type": "code", 79 "execution_count": 3, 80 "metadata": {}, 81 "outputs": [ 82 { 83 "output_type": "stream", 84 "name": "stdout", 85 "text": [ 86 "[0, 1, 2]\n[3, 4, 5]\n[6, 7, 8]\n[9]\n" 87 ] 88 } 89 ], 90 "source": [ 91 "dp = 
ExampleIterPipe(10).batch(3)\n", 92 "for i in dp:\n", 93 " print(i)" 94 ] 95 }, 96 { 97 "source": [ 98 "To drop incomplete batches add `drop_last` argument" 99 ], 100 "cell_type": "markdown", 101 "metadata": {} 102 }, 103 { 104 "cell_type": "code", 105 "execution_count": 4, 106 "metadata": {}, 107 "outputs": [ 108 { 109 "output_type": "stream", 110 "name": "stdout", 111 "text": [ 112 "[0, 1, 2]\n[3, 4, 5]\n[6, 7, 8]\n" 113 ] 114 } 115 ], 116 "source": [ 117 "dp = ExampleIterPipe(10).batch(3, drop_last = True)\n", 118 "for i in dp:\n", 119 " print(i)" 120 ] 121 }, 122 { 123 "source": [ 124 "Sequential calling of `batch` produce nested batches" 125 ], 126 "cell_type": "markdown", 127 "metadata": {} 128 }, 129 { 130 "cell_type": "code", 131 "execution_count": 5, 132 "metadata": {}, 133 "outputs": [ 134 { 135 "output_type": "stream", 136 "name": "stdout", 137 "text": [ 138 "[[0, 1, 2], [3, 4, 5]]\n[[6, 7, 8], [9, 10, 11]]\n[[12, 13, 14], [15, 16, 17]]\n[[18, 19, 20], [21, 22, 23]]\n[[24, 25, 26], [27, 28, 29]]\n" 139 ] 140 } 141 ], 142 "source": [ 143 "dp = ExampleIterPipe(30).batch(3).batch(2)\n", 144 "for i in dp:\n", 145 " print(i)" 146 ] 147 }, 148 { 149 "source": [ 150 "It is possible to unbatch source data before applying the new batching rule using `unbatch_level` argument" 151 ], 152 "cell_type": "markdown", 153 "metadata": {} 154 }, 155 { 156 "cell_type": "code", 157 "execution_count": 6, 158 "metadata": {}, 159 "outputs": [ 160 { 161 "output_type": "stream", 162 "name": "stdout", 163 "text": [ 164 "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]\n[20, 21, 22, 23, 24, 25, 26, 27, 28, 29]\n" 165 ] 166 } 167 ], 168 "source": [ 169 "dp = ExampleIterPipe(30).batch(3).batch(2).batch(10, unbatch_level=-1)\n", 170 "for i in dp:\n", 171 " print(i)" 172 ] 173 }, 174 { 175 "source": [ 176 "## Unbatch\n", 177 "\n", 178 "Function: `unbatch`\n", 179 "\n", 180 "Description: \n", 181 "\n", 182 "Alternatives:\n", 183 "\n", 184 "Arguments:\n", 185 " 
`unbatch_level:int = 1`\n", 186 " \n", 187 "Example:" 188 ], 189 "cell_type": "markdown", 190 "metadata": {} 191 }, 192 { 193 "cell_type": "code", 194 "execution_count": 7, 195 "metadata": {}, 196 "outputs": [ 197 { 198 "output_type": "stream", 199 "name": "stdout", 200 "text": [ 201 "9\n0\n1\n2\n6\n7\n8\n3\n4\n5\n" 202 ] 203 } 204 ], 205 "source": [ 206 "dp = ExampleIterPipe(10).batch(3).shuffle().unbatch()\n", 207 "for i in dp:\n", 208 " print(i)" 209 ] 210 }, 211 { 212 "source": [ 213 "By default unbatching is applied only on the first layer, to unbatch deeper use `unbatch_level` argument" 214 ], 215 "cell_type": "markdown", 216 "metadata": {} 217 }, 218 { 219 "cell_type": "code", 220 "execution_count": 8, 221 "metadata": {}, 222 "outputs": [ 223 { 224 "output_type": "stream", 225 "name": "stdout", 226 "text": [ 227 "[0, 1]\n[2, 3]\n[4, 5]\n[6, 7]\n[8, 9]\n[10, 11]\n[12, 13]\n[14, 15]\n[16, 17]\n[18, 19]\n[20, 21]\n[22, 23]\n[24, 25]\n[26, 27]\n[28, 29]\n[30, 31]\n[32, 33]\n[34, 35]\n[36, 37]\n[38, 39]\n" 228 ] 229 } 230 ], 231 "source": [ 232 "dp = ExampleIterPipe(40).batch(2).batch(4).batch(3).unbatch(unbatch_level = 2)\n", 233 "for i in dp:\n", 234 " print(i)" 235 ] 236 }, 237 { 238 "source": [ 239 "Setting `unbatch_level` to `-1` will unbatch to the lowest level" 240 ], 241 "cell_type": "markdown", 242 "metadata": {} 243 }, 244 { 245 "cell_type": "code", 246 "execution_count": 9, 247 "metadata": {}, 248 "outputs": [ 249 { 250 "output_type": "stream", 251 "name": "stdout", 252 "text": [ 253 "0\n1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n11\n12\n13\n14\n15\n16\n17\n18\n19\n20\n21\n22\n23\n24\n25\n26\n27\n28\n29\n30\n31\n32\n33\n34\n35\n36\n37\n38\n39\n" 254 ] 255 } 256 ], 257 "source": [ 258 "dp = ExampleIterPipe(40).batch(2).batch(4).batch(3).unbatch(unbatch_level = -1)\n", 259 "for i in dp:\n", 260 " print(i)" 261 ] 262 }, 263 { 264 "source": [ 265 "## Map\n", 266 "\n", 267 "Function: `map`\n", 268 "\n", 269 "Description: \n", 270 "\n", 271 "Alternatives:\n", 272 "\n", 
273 "Arguments:\n", 274 " - `nesting_level: int = 0`\n", 275 " \n", 276 "Example:" 277 ], 278 "cell_type": "markdown", 279 "metadata": {} 280 }, 281 { 282 "cell_type": "code", 283 "execution_count": 10, 284 "metadata": {}, 285 "outputs": [ 286 { 287 "output_type": "stream", 288 "name": "stdout", 289 "text": [ 290 "0\n2\n4\n6\n8\n10\n12\n14\n16\n18\n" 291 ] 292 } 293 ], 294 "source": [ 295 "dp = ExampleIterPipe(10).map(lambda x: x * 2)\n", 296 "for i in dp:\n", 297 " print(i)" 298 ] 299 }, 300 { 301 "source": [ 302 "`map` by default applies function to every mini-batch as a whole\n" 303 ], 304 "cell_type": "markdown", 305 "metadata": {} 306 }, 307 { 308 "cell_type": "code", 309 "execution_count": 11, 310 "metadata": {}, 311 "outputs": [ 312 { 313 "output_type": "stream", 314 "name": "stdout", 315 "text": [ 316 "[0, 1, 2, 0, 1, 2]\n[3, 4, 5, 3, 4, 5]\n[6, 7, 8, 6, 7, 8]\n[9, 9]\n" 317 ] 318 } 319 ], 320 "source": [ 321 "dp = ExampleIterPipe(10).batch(3).map(lambda x: x * 2)\n", 322 "for i in dp:\n", 323 " print(i)" 324 ] 325 }, 326 { 327 "source": [ 328 "To apply function on individual items of the mini-batch use `nesting_level` argument" 329 ], 330 "cell_type": "markdown", 331 "metadata": {} 332 }, 333 { 334 "cell_type": "code", 335 "execution_count": 12, 336 "metadata": {}, 337 "outputs": [ 338 { 339 "output_type": "stream", 340 "name": "stdout", 341 "text": [ 342 "[[0, 2, 4], [6, 8, 10]]\n[[12, 14, 16], [18]]\n" 343 ] 344 } 345 ], 346 "source": [ 347 "dp = ExampleIterPipe(10).batch(3).batch(2).map(lambda x: x * 2, nesting_level = 2)\n", 348 "for i in dp:\n", 349 " print(i)" 350 ] 351 }, 352 { 353 "source": [ 354 "Setting `nesting_level` to `-1` will apply `map` function to the lowest level possible" 355 ], 356 "cell_type": "markdown", 357 "metadata": {} 358 }, 359 { 360 "cell_type": "code", 361 "execution_count": 13, 362 "metadata": {}, 363 "outputs": [ 364 { 365 "output_type": "stream", 366 "name": "stdout", 367 "text": [ 368 "[[[0, 2, 4], [6, 8, 10]], [[12, 14, 
16], [18]]]\n" 369 ] 370 } 371 ], 372 "source": [ 373 "dp = ExampleIterPipe(10).batch(3).batch(2).batch(2).map(lambda x: x * 2, nesting_level = -1)\n", 374 "for i in dp:\n", 375 " print(i)" 376 ] 377 }, 378 { 379 "source": [ 380 "## Filter\n", 381 "\n", 382 "Function: `filter`\n", 383 "\n", 384 "Description: \n", 385 "\n", 386 "Alternatives:\n", 387 "\n", 388 "Arguments:\n", 389 " - `nesting_level: int = 0`\n", 390 " - `drop_empty_batches = True` whether empty batches are dropped or not.\n", 391 " \n", 392 "Example:" 393 ], 394 "cell_type": "markdown", 395 "metadata": {} 396 }, 397 { 398 "cell_type": "code", 399 "execution_count": 14, 400 "metadata": {}, 401 "outputs": [ 402 { 403 "output_type": "stream", 404 "name": "stdout", 405 "text": [ 406 "0\n2\n4\n6\n8\n" 407 ] 408 } 409 ], 410 "source": [ 411 "dp = ExampleIterPipe(10).filter(lambda x: x % 2 == 0)\n", 412 "for i in dp:\n", 413 " print(i)" 414 ] 415 }, 416 { 417 "source": [ 418 "Classic `filter` by default applies the filter function to every mini-batch as a whole \n" 419 ], 420 "cell_type": "markdown", 421 "metadata": {} 422 }, 423 { 424 "cell_type": "code", 425 "execution_count": 15, 426 "metadata": {}, 427 "outputs": [ 428 { 429 "output_type": "stream", 430 "name": "stdout", 431 "text": [ 432 "[0, 1, 2]\n[3, 4, 5]\n[6, 7, 8]\n" 433 ] 434 } 435 ], 436 "source": [ 437 "dp = ExampleIterPipe(10)\n", 438 "dp = dp.batch(3).filter(lambda x: len(x) > 2)\n", 439 "for i in dp:\n", 440 " print(i)" 441 ] 442 }, 443 { 444 "source": [ 445 "You can apply the filter function on individual elements by setting the `nesting_level` argument" 446 ], 447 "cell_type": "markdown", 448 "metadata": {} 449 }, 450 { 451 "cell_type": "code", 452 "execution_count": 16, 453 "metadata": {}, 454 "outputs": [ 455 { 456 "output_type": "stream", 457 "name": "stdout", 458 "text": [ 459 "[5]\n[6, 7, 8]\n[9]\n" 460 ] 461 } 462 ], 463 "source": [ 464 "dp = ExampleIterPipe(10)\n", 465 "dp = dp.batch(3).filter(lambda x: x > 4, nesting_level = 1)\n", 466 
"for i in dp:\n", 467 " print(i)" 468 ] 469 }, 470 { 471 "source": [ 472 "If mini-batch ends with zero elements after filtering default behaviour would be to drop them from the response. You can override this behaviour using `drop_empty_batches` argument.\n" 473 ], 474 "cell_type": "markdown", 475 "metadata": {} 476 }, 477 { 478 "cell_type": "code", 479 "execution_count": 17, 480 "metadata": {}, 481 "outputs": [ 482 { 483 "output_type": "stream", 484 "name": "stdout", 485 "text": [ 486 "[]\n[5]\n[6, 7, 8]\n[9]\n" 487 ] 488 } 489 ], 490 "source": [ 491 "dp = ExampleIterPipe(10)\n", 492 "dp = dp.batch(3).filter(lambda x: x > 4, nesting_level = -1, drop_empty_batches = False)\n", 493 "for i in dp:\n", 494 " print(i)" 495 ] 496 }, 497 { 498 "cell_type": "code", 499 "execution_count": 18, 500 "metadata": {}, 501 "outputs": [ 502 { 503 "output_type": "stream", 504 "name": "stdout", 505 "text": [ 506 "[[[0, 1, 2], [3]], [[], [10, 11]]]\n[[[12, 13, 14], [15, 16, 17]], [[18, 19]]]\n" 507 ] 508 } 509 ], 510 "source": [ 511 "dp = ExampleIterPipe(20)\n", 512 "dp = dp.batch(3).batch(2).batch(2).filter(lambda x: x < 4 or x > 9 , nesting_level = -1, drop_empty_batches = False)\n", 513 "for i in dp:\n", 514 " print(i)" 515 ] 516 }, 517 { 518 "source": [ 519 "## Shuffle\n", 520 "\n", 521 "Function: `shuffle`\n", 522 "\n", 523 "Description: \n", 524 "\n", 525 "Alternatives:\n", 526 "\n", 527 "Arguments:\n", 528 " - `unbatch_level:int = 0` if specified calls `unbatch(unbatch_level=unbatch_level)` on source datapipe before batching (see `unbatch`)\n", 529 " - `buffer_size: int = 10000`\n", 530 " \n", 531 "Example:" 532 ], 533 "cell_type": "markdown", 534 "metadata": {} 535 }, 536 { 537 "cell_type": "code", 538 "execution_count": 19, 539 "metadata": {}, 540 "outputs": [ 541 { 542 "output_type": "stream", 543 "name": "stdout", 544 "text": [ 545 "2\n9\n4\n0\n3\n7\n8\n5\n6\n1\n" 546 ] 547 } 548 ], 549 "source": [ 550 "dp = ExampleIterPipe(10).shuffle()\n", 551 "for i in dp:\n", 552 " 
print(i)" 553 ] 554 }, 555 { 556 "source": [ 557 "`shuffle` operates on input mini-batches similar as on individual items" 558 ], 559 "cell_type": "markdown", 560 "metadata": {} 561 }, 562 { 563 "cell_type": "code", 564 "execution_count": 20, 565 "metadata": {}, 566 "outputs": [ 567 { 568 "output_type": "stream", 569 "name": "stdout", 570 "text": [ 571 "[0, 1, 2]\n[3, 4, 5]\n[9]\n[6, 7, 8]\n" 572 ] 573 } 574 ], 575 "source": [ 576 "dp = ExampleIterPipe(10).batch(3).shuffle()\n", 577 "for i in dp:\n", 578 " print(i)" 579 ] 580 }, 581 { 582 "source": [ 583 "To shuffle elements across batches use `shuffle(unbatch_level)` followed by `batch` pattern " 584 ], 585 "cell_type": "markdown", 586 "metadata": {} 587 }, 588 { 589 "cell_type": "code", 590 "execution_count": 21, 591 "metadata": {}, 592 "outputs": [ 593 { 594 "output_type": "stream", 595 "name": "stdout", 596 "text": [ 597 "[2, 1, 0]\n[7, 9, 6]\n[3, 5, 4]\n[8]\n" 598 ] 599 } 600 ], 601 "source": [ 602 "dp = ExampleIterPipe(10).batch(3).shuffle(unbatch_level = -1).batch(3)\n", 603 "for i in dp:\n", 604 " print(i)" 605 ] 606 }, 607 { 608 "source": [ 609 "## Collate\n", 610 "\n", 611 "Function: `collate`\n", 612 "\n", 613 "Description: \n", 614 "\n", 615 "Alternatives:\n", 616 "\n", 617 "Arguments:\n", 618 " \n", 619 "Example:" 620 ], 621 "cell_type": "markdown", 622 "metadata": {} 623 }, 624 { 625 "cell_type": "code", 626 "execution_count": 22, 627 "metadata": {}, 628 "outputs": [ 629 { 630 "output_type": "stream", 631 "name": "stdout", 632 "text": [ 633 "tensor([0, 1, 2])\ntensor([3, 4, 5])\ntensor([6, 7, 8])\ntensor([9])\n" 634 ] 635 } 636 ], 637 "source": [ 638 "dp = ExampleIterPipe(10).batch(3).collate()\n", 639 "for i in dp:\n", 640 " print(i)" 641 ] 642 }, 643 { 644 "source": [ 645 "## GroupBy\n", 646 "\n", 647 "Function: `groupby`\n", 648 "\n", 649 "Usage: `dp.groupby(lambda x: x[0])`\n", 650 "\n", 651 "Description: Batching items by combining items with same key into same batch \n", 652 "\n", 653 
"Arguments:\n", 654 " - `group_key_fn`\n", 655 " - `group_size` - yeild resulted group as soon as `group_size` elements accumulated\n", 656 " - `guaranteed_group_size:int = None`\n", 657 " - `unbatch_level:int = 0` if specified calls `unbatch(unbatch_level=unbatch_level)` on source datapipe before batching (see `unbatch`)\n", 658 "\n", 659 "#### Attention\n", 660 "As datasteam can be arbitrary large, grouping is done on best effort basis and there is no guarantee that same key will never present in the different groups. You can call it local groupby where locallity is the one DataPipe process/thread." 661 ], 662 "cell_type": "markdown", 663 "metadata": {} 664 }, 665 { 666 "cell_type": "code", 667 "execution_count": 23, 668 "metadata": {}, 669 "outputs": [ 670 { 671 "output_type": "stream", 672 "name": "stdout", 673 "text": [ 674 "[0, 3, 6, 9]\n[1, 4, 7]\n[5, 2, 8]\n" 675 ] 676 } 677 ], 678 "source": [ 679 "dp = ExampleIterPipe(10).shuffle().groupby(lambda x: x % 3)\n", 680 "for i in dp:\n", 681 " print(i)" 682 ] 683 }, 684 { 685 "source": [ 686 "By default group key function is applied to entire input (mini-batch)" 687 ], 688 "cell_type": "markdown", 689 "metadata": {} 690 }, 691 { 692 "cell_type": "code", 693 "execution_count": 24, 694 "metadata": {}, 695 "outputs": [ 696 { 697 "output_type": "stream", 698 "name": "stdout", 699 "text": [ 700 "[[0, 1, 2], [3, 4, 5], [6, 7, 8]]\n[[9]]\n" 701 ] 702 } 703 ], 704 "source": [ 705 "dp = ExampleIterPipe(10).batch(3).groupby(lambda x: len(x))\n", 706 "for i in dp:\n", 707 " print(i)" 708 ] 709 }, 710 { 711 "source": [ 712 "It is possible to unnest items from the mini-batches using `unbatch_level` argument" 713 ], 714 "cell_type": "markdown", 715 "metadata": {} 716 }, 717 { 718 "cell_type": "code", 719 "execution_count": 25, 720 "metadata": {}, 721 "outputs": [ 722 { 723 "output_type": "stream", 724 "name": "stdout", 725 "text": [ 726 "[0, 3, 6, 9]\n[1, 4, 7]\n[2, 5, 8]\n" 727 ] 728 } 729 ], 730 "source": [ 731 "dp = 
ExampleIterPipe(10).batch(3).groupby(lambda x: x % 3, unbatch_level = 1)\n", 732 "for i in dp:\n", 733 " print(i)" 734 ] 735 }, 736 { 737 "source": [ 738 "When internal buffer (defined by `buffer_size`) is overfilled, groupby will yield biggest group available" 739 ], 740 "cell_type": "markdown", 741 "metadata": {} 742 }, 743 { 744 "cell_type": "code", 745 "execution_count": 26, 746 "metadata": {}, 747 "outputs": [ 748 { 749 "output_type": "stream", 750 "name": "stdout", 751 "text": [ 752 "[9, 3]\n[13, 4, 7]\n[2, 11, 14, 5]\n[0, 6, 12]\n[1, 10]\n[8]\n" 753 ] 754 } 755 ], 756 "source": [ 757 "dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, buffer_size = 5)\n", 758 "for i in dp:\n", 759 " print(i)" 760 ] 761 }, 762 { 763 "source": [ 764 "`groupby` will produce `group_size` sized batches on as fast as possible basis" 765 ], 766 "cell_type": "markdown", 767 "metadata": {} 768 }, 769 { 770 "cell_type": "code", 771 "execution_count": 27, 772 "metadata": {}, 773 "outputs": [ 774 { 775 "output_type": "stream", 776 "name": "stdout", 777 "text": [ 778 "[6, 3, 12]\n[1, 16, 7]\n[2, 5, 8]\n[14, 11, 17]\n[15, 9, 0]\n[10, 4, 13]\n" 779 ] 780 } 781 ], 782 "source": [ 783 "dp = ExampleIterPipe(18).shuffle().groupby(lambda x: x % 3, group_size = 3)\n", 784 "for i in dp:\n", 785 " print(i)" 786 ] 787 }, 788 { 789 "source": [ 790 "Remaining groups must be at least `guaranteed_group_size` big. 
" 791 ], 792 "cell_type": "markdown", 793 "metadata": {} 794 }, 795 { 796 "cell_type": "code", 797 "execution_count": 28, 798 "metadata": {}, 799 "outputs": [ 800 { 801 "output_type": "stream", 802 "name": "stdout", 803 "text": [ 804 "[11, 2, 5]\n[1, 4, 10]\n[0, 9, 6]\n[14, 8]\n[13, 7]\n[12, 3]\n" 805 ] 806 } 807 ], 808 "source": [ 809 "dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, group_size = 3, guaranteed_group_size = 2)\n", 810 "for i in dp:\n", 811 " print(i)" 812 ] 813 }, 814 { 815 "source": [ 816 "Without defined `group_size` function will try to accumulate at least `guaranteed_group_size` elements before yielding resulted group" 817 ], 818 "cell_type": "markdown", 819 "metadata": {} 820 }, 821 { 822 "cell_type": "code", 823 "execution_count": 29, 824 "metadata": {}, 825 "outputs": [ 826 { 827 "output_type": "stream", 828 "name": "stdout", 829 "text": [ 830 "[3, 6, 9, 12, 0]\n[14, 2, 8, 11, 5]\n[7, 4, 1, 13, 10]\n" 831 ] 832 } 833 ], 834 "source": [ 835 "dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, guaranteed_group_size = 2)\n", 836 "for i in dp:\n", 837 " print(i)" 838 ] 839 }, 840 { 841 "source": [ 842 "This behaviour becomes noticable when data is bigger than buffer and some groups getting evicted before gathering all potential items" 843 ], 844 "cell_type": "markdown", 845 "metadata": {} 846 }, 847 { 848 "cell_type": "code", 849 "execution_count": 30, 850 "metadata": {}, 851 "outputs": [ 852 { 853 "output_type": "stream", 854 "name": "stdout", 855 "text": [ 856 "[0, 3]\n[1, 4, 7]\n[2, 5, 8]\n[6, 9, 12]\n[10, 13]\n[11, 14]\n" 857 ] 858 } 859 ], 860 "source": [ 861 "dp = ExampleIterPipe(15).groupby(lambda x: x % 3, guaranteed_group_size = 2, buffer_size = 6)\n", 862 "for i in dp:\n", 863 " print(i)" 864 ] 865 }, 866 { 867 "source": [ 868 "With randomness involved you might end up with incomplete groups (so next example expected to fail in most cases)" 869 ], 870 "cell_type": "markdown", 871 "metadata": {} 872 }, 873 { 874 
"cell_type": "code", 875 "execution_count": 31, 876 "metadata": {}, 877 "outputs": [ 878 { 879 "output_type": "stream", 880 "name": "stdout", 881 "text": [ 882 "[14, 5, 11]\n[1, 7, 4, 10]\n[0, 12, 6]\n[8, 2]\n[9, 3]\n" 883 ] 884 }, 885 { 886 "output_type": "error", 887 "ename": "Exception", 888 "evalue": "('Failed to group items', '[13]')", 889 "traceback": [ 890 "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 891 "\u001b[0;31mException\u001b[0m Traceback (most recent call last)", 892 "\u001b[0;32m<ipython-input-31-673b9dd7fb43>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mdp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mExampleIterPipe\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m15\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshuffle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgroupby\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mguaranteed_group_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuffer_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdp\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 893 "\u001b[0;32m~/dataset/pytorch/torch/utils/data/datapipes/iter/grouping.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 275\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 276\u001b[0m 
\u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mguaranteed_group_size\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mbiggest_size\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mguaranteed_group_size\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdrop_remaining\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 277\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Failed to group items'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbiggest_key\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 278\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 279\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mguaranteed_group_size\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mbiggest_size\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mguaranteed_group_size\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 894 "\u001b[0;31mException\u001b[0m: ('Failed to group items', '[13]')" 895 ] 896 } 897 ], 898 "source": [ 899 "dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, guaranteed_group_size = 2, buffer_size = 6)\n", 900 "for i in dp:\n", 901 " print(i)" 902 ] 903 }, 904 { 905 "source": [ 906 "To avoid this error and drop incomplete groups, use `drop_remaining` argument" 907 ], 908 "cell_type": "markdown", 909 "metadata": {} 910 }, 911 { 912 "cell_type": "code", 913 "execution_count": 32, 914 "metadata": {}, 915 "outputs": [ 916 { 917 
"output_type": "stream", 918 "name": "stdout", 919 "text": [ 920 "[5, 2, 14]\n[4, 7, 13, 1, 10]\n[12, 6, 3, 9]\n[8, 11]\n" 921 ] 922 } 923 ], 924 "source": [ 925 "dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, guaranteed_group_size = 2, buffer_size = 6, drop_remaining = True)\n", 926 "for i in dp:\n", 927 " print(i)" 928 ] 929 }, 930 { 931 "source": [ 932 "## Zip\n", 933 "\n", 934 "Function: `zip`\n", 935 "\n", 936 "Description: \n", 937 "\n", 938 "Alternatives:\n", 939 "\n", 940 "Arguments:\n", 941 " \n", 942 "Example:" 943 ], 944 "cell_type": "markdown", 945 "metadata": {} 946 }, 947 { 948 "cell_type": "code", 949 "execution_count": 35, 950 "metadata": {}, 951 "outputs": [ 952 { 953 "output_type": "stream", 954 "name": "stdout", 955 "text": [ 956 "(0, 3)\n(1, 0)\n(2, 4)\n(3, 2)\n(4, 1)\n" 957 ] 958 } 959 ], 960 "source": [ 961 "_dp = ExampleIterPipe(5).shuffle()\n", 962 "dp = ExampleIterPipe(5).zip(_dp)\n", 963 "for i in dp:\n", 964 " print(i)" 965 ] 966 }, 967 { 968 "source": [ 969 "## Fork\n", 970 "\n", 971 "Function: `fork`\n", 972 "\n", 973 "Description: \n", 974 "\n", 975 "Alternatives:\n", 976 "\n", 977 "Arguments:\n", 978 " \n", 979 "Example:" 980 ], 981 "cell_type": "markdown", 982 "metadata": {} 983 }, 984 { 985 "cell_type": "code", 986 "execution_count": 36, 987 "metadata": {}, 988 "outputs": [ 989 { 990 "output_type": "stream", 991 "name": "stdout", 992 "text": [ 993 "0\n1\n0\n1\n0\n1\n" 994 ] 995 } 996 ], 997 "source": [ 998 "dp = ExampleIterPipe(2)\n", 999 "dp1, dp2, dp3 = dp.fork(3)\n", 1000 "for i in dp1 + dp2 + dp3:\n", 1001 " print(i)" 1002 ] 1003 }, 1004 { 1005 "cell_type": "markdown", 1006 "metadata": {}, 1007 "source": [ 1008 "## Demultiplexer\n", 1009 "\n", 1010 "Function: `demux`\n", 1011 "\n", 1012 "Description: \n", 1013 "\n", 1014 "Alternatives:\n", 1015 "\n", 1016 "Arguments:\n", 1017 " \n", 1018 "Example:" 1019 ] 1020 }, 1021 { 1022 "cell_type": "code", 1023 "execution_count": 32, 1024 "metadata": {}, 1025 "outputs": [ 1026 
{ 1027 "name": "stdout", 1028 "output_type": "stream", 1029 "text": [ 1030 "1\n", 1031 "4\n", 1032 "7\n" 1033 ] 1034 } 1035 ], 1036 "source": [ 1037 "dp = ExampleIterPipe(10)\n", 1038 "dp1, dp2, dp3 = dp.demux(3, lambda x: x % 3)\n", 1039 "for i in dp2:\n", 1040 " print(i)" 1041 ] 1042 }, 1043 { 1044 "cell_type": "markdown", 1045 "metadata": {}, 1046 "source": [ 1047 "## Multiplexer\n", 1048 "\n", 1049 "Function: `mux`\n", 1050 "\n", 1051 "Description: \n", 1052 "\n", 1053 "Alternatives:\n", 1054 "\n", 1055 "Arguments:\n", 1056 " \n", 1057 "Example:" 1058 ] 1059 }, 1060 { 1061 "cell_type": "code", 1062 "execution_count": 34, 1063 "metadata": {}, 1064 "outputs": [ 1065 { 1066 "name": "stdout", 1067 "output_type": "stream", 1068 "text": [ 1069 "0\n", 1070 "0\n", 1071 "0\n", 1072 "1\n", 1073 "10\n", 1074 "100\n", 1075 "2\n", 1076 "20\n", 1077 "200\n" 1078 ] 1079 } 1080 ], 1081 "source": [ 1082 "dp1 = ExampleIterPipe(3)\n", 1083 "dp2 = ExampleIterPipe(3).map(lambda x: x * 10)\n", 1084 "dp3 = ExampleIterPipe(3).map(lambda x: x * 100)\n", 1085 "\n", 1086 "dp = dp1.mux(dp2, dp3)\n", 1087 "for i in dp:\n", 1088 " print(i)" 1089 ] 1090 }, 1091 { 1092 "source": [ 1093 "## Concat\n", 1094 "\n", 1095 "Function: `concat`\n", 1096 "\n", 1097 "Description: Returns a DataPipe with elements from the first datapipe followed by elements from the second datapipes\n", 1098 "\n", 1099 "Alternatives:\n", 1100 " \n", 1101 " `dp = dp.concat(dp2, dp3)`\n", 1102 " `dp = dp.concat(*datapipes_list)`\n", 1103 "\n", 1104 "Example:\n" 1105 ], 1106 "cell_type": "markdown", 1107 "metadata": {} 1108 }, 1109 { 1110 "cell_type": "code", 1111 "execution_count": 37, 1112 "metadata": {}, 1113 "outputs": [ 1114 { 1115 "output_type": "stream", 1116 "name": "stdout", 1117 "text": [ 1118 "0\n1\n2\n3\n0\n1\n2\n" 1119 ] 1120 } 1121 ], 1122 "source": [ 1123 "dp = ExampleIterPipe(4)\n", 1124 "dp2 = ExampleIterPipe(3)\n", 1125 "dp = dp.concat(dp2)\n", 1126 "for i in dp:\n", 1127 " print(i)" 1128 ] 1129 } 1130 ] 1131}