xref: /aosp_15_r20/external/pytorch/torch/utils/data/standard_pipes.ipynb (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1{
2 "metadata": {
3  "language_info": {
4   "codemirror_mode": {
5    "name": "ipython",
6    "version": 3
7   },
8   "file_extension": ".py",
9   "mimetype": "text/x-python",
10   "name": "python",
11   "nbconvert_exporter": "python",
12   "pygments_lexer": "ipython3",
13   "version": "3.6.10"
14  },
15  "orig_nbformat": 2,
16  "kernelspec": {
17   "name": "python3610jvsc74a57bd0eb5e09632d6ea1cbf3eb9da7e37b7cf581db5ed13074b21cc44e159dc62acdab",
18   "display_name": "Python 3.6.10 64-bit ('dataloader': conda)"
19  }
20 },
21 "nbformat": 4,
22 "nbformat_minor": 2,
23 "cells": [
24  {
25   "source": [
26    "## Standard flow control and data processing DataPipes"
27   ],
28   "cell_type": "markdown",
29   "metadata": {}
30  },
31  {
32   "cell_type": "code",
33   "execution_count": 1,
34   "metadata": {},
35   "outputs": [],
36   "source": [
37    "from torch.utils.data import IterDataPipe"
38   ]
39  },
40  {
41   "cell_type": "code",
42   "execution_count": 2,
43   "metadata": {},
44   "outputs": [],
45   "source": [
46    "# Example IterDataPipe\n",
47    "class ExampleIterPipe(IterDataPipe):\n",
48    "    def __init__(self, range = 20):\n",
49    "        self.range = range\n",
50    "    def __iter__(self):\n",
51    "        for i in range(self.range):\n",
52    "            yield i"
53   ]
54  },
55  {
56   "source": [
57    "## Batch\n",
58    "\n",
59    "Function: `batch`\n",
60    "\n",
61    "Description: \n",
62    "\n",
63    "Alternatives:\n",
64    "\n",
65    "Arguments:\n",
66    "  - `batch_size: int` desired batch size\n",
67    "  - `unbatch_level:int = 0` if specified calls `unbatch(unbatch_level=unbatch_level)` on source datapipe before batching (see `unbatch`)\n",
68    "  - `drop_last: bool = False`\n",
69    "\n",
70    "Example:\n",
71    "\n",
72    "Classic batching produces partial batches by default\n"
73   ],
74   "cell_type": "markdown",
75   "metadata": {}
76  },
77  {
78   "cell_type": "code",
79   "execution_count": 3,
80   "metadata": {},
81   "outputs": [
82    {
83     "output_type": "stream",
84     "name": "stdout",
85     "text": [
86      "[0, 1, 2]\n[3, 4, 5]\n[6, 7, 8]\n[9]\n"
87     ]
88    }
89   ],
90   "source": [
91    "dp = ExampleIterPipe(10).batch(3)\n",
92    "for i in dp:\n",
93    "    print(i)"
94   ]
95  },
96  {
97   "source": [
98    "To drop incomplete batches add `drop_last` argument"
99   ],
100   "cell_type": "markdown",
101   "metadata": {}
102  },
103  {
104   "cell_type": "code",
105   "execution_count": 4,
106   "metadata": {},
107   "outputs": [
108    {
109     "output_type": "stream",
110     "name": "stdout",
111     "text": [
112      "[0, 1, 2]\n[3, 4, 5]\n[6, 7, 8]\n"
113     ]
114    }
115   ],
116   "source": [
117    "dp = ExampleIterPipe(10).batch(3, drop_last = True)\n",
118    "for i in dp:\n",
119    "    print(i)"
120   ]
121  },
122  {
123   "source": [
124    "Sequential calling of `batch` produces nested batches"
125   ],
126   "cell_type": "markdown",
127   "metadata": {}
128  },
129  {
130   "cell_type": "code",
131   "execution_count": 5,
132   "metadata": {},
133   "outputs": [
134    {
135     "output_type": "stream",
136     "name": "stdout",
137     "text": [
138      "[[0, 1, 2], [3, 4, 5]]\n[[6, 7, 8], [9, 10, 11]]\n[[12, 13, 14], [15, 16, 17]]\n[[18, 19, 20], [21, 22, 23]]\n[[24, 25, 26], [27, 28, 29]]\n"
139     ]
140    }
141   ],
142   "source": [
143    "dp = ExampleIterPipe(30).batch(3).batch(2)\n",
144    "for i in dp:\n",
145    "    print(i)"
146   ]
147  },
148  {
149   "source": [
150    "It is possible to unbatch source data before applying the new batching rule using `unbatch_level` argument"
151   ],
152   "cell_type": "markdown",
153   "metadata": {}
154  },
155  {
156   "cell_type": "code",
157   "execution_count": 6,
158   "metadata": {},
159   "outputs": [
160    {
161     "output_type": "stream",
162     "name": "stdout",
163     "text": [
164      "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]\n[20, 21, 22, 23, 24, 25, 26, 27, 28, 29]\n"
165     ]
166    }
167   ],
168   "source": [
169    "dp = ExampleIterPipe(30).batch(3).batch(2).batch(10, unbatch_level=-1)\n",
170    "for i in dp:\n",
171    "    print(i)"
172   ]
173  },
174  {
175   "source": [
176    "## Unbatch\n",
177    "\n",
178    "Function: `unbatch`\n",
179    "\n",
180    "Description: \n",
181    "\n",
182    "Alternatives:\n",
183    "\n",
184    "Arguments:\n",
185    "    `unbatch_level:int = 1`\n",
186    " \n",
187    "Example:"
188   ],
189   "cell_type": "markdown",
190   "metadata": {}
191  },
192  {
193   "cell_type": "code",
194   "execution_count": 7,
195   "metadata": {},
196   "outputs": [
197    {
198     "output_type": "stream",
199     "name": "stdout",
200     "text": [
201      "9\n0\n1\n2\n6\n7\n8\n3\n4\n5\n"
202     ]
203    }
204   ],
205   "source": [
206    "dp = ExampleIterPipe(10).batch(3).shuffle().unbatch()\n",
207    "for i in dp:\n",
208    "    print(i)"
209   ]
210  },
211  {
212   "source": [
213    "By default unbatching is applied only on the first layer, to unbatch deeper use `unbatch_level` argument"
214   ],
215   "cell_type": "markdown",
216   "metadata": {}
217  },
218  {
219   "cell_type": "code",
220   "execution_count": 8,
221   "metadata": {},
222   "outputs": [
223    {
224     "output_type": "stream",
225     "name": "stdout",
226     "text": [
227      "[0, 1]\n[2, 3]\n[4, 5]\n[6, 7]\n[8, 9]\n[10, 11]\n[12, 13]\n[14, 15]\n[16, 17]\n[18, 19]\n[20, 21]\n[22, 23]\n[24, 25]\n[26, 27]\n[28, 29]\n[30, 31]\n[32, 33]\n[34, 35]\n[36, 37]\n[38, 39]\n"
228     ]
229    }
230   ],
231   "source": [
232    "dp = ExampleIterPipe(40).batch(2).batch(4).batch(3).unbatch(unbatch_level = 2)\n",
233    "for i in dp:\n",
234    "    print(i)"
235   ]
236  },
237  {
238   "source": [
239    "Setting `unbatch_level` to `-1` will unbatch to the lowest level"
240   ],
241   "cell_type": "markdown",
242   "metadata": {}
243  },
244  {
245   "cell_type": "code",
246   "execution_count": 9,
247   "metadata": {},
248   "outputs": [
249    {
250     "output_type": "stream",
251     "name": "stdout",
252     "text": [
253      "0\n1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n11\n12\n13\n14\n15\n16\n17\n18\n19\n20\n21\n22\n23\n24\n25\n26\n27\n28\n29\n30\n31\n32\n33\n34\n35\n36\n37\n38\n39\n"
254     ]
255    }
256   ],
257   "source": [
258    "dp = ExampleIterPipe(40).batch(2).batch(4).batch(3).unbatch(unbatch_level = -1)\n",
259    "for i in dp:\n",
260    "    print(i)"
261   ]
262  },
263  {
264   "source": [
265    "## Map\n",
266    "\n",
267    "Function: `map`\n",
268    "\n",
269    "Description: \n",
270    "\n",
271    "Alternatives:\n",
272    "\n",
273    "Arguments:\n",
274    "  - `nesting_level: int = 0`\n",
275    " \n",
276    "Example:"
277   ],
278   "cell_type": "markdown",
279   "metadata": {}
280  },
281  {
282   "cell_type": "code",
283   "execution_count": 10,
284   "metadata": {},
285   "outputs": [
286    {
287     "output_type": "stream",
288     "name": "stdout",
289     "text": [
290      "0\n2\n4\n6\n8\n10\n12\n14\n16\n18\n"
291     ]
292    }
293   ],
294   "source": [
295    "dp = ExampleIterPipe(10).map(lambda x: x * 2)\n",
296    "for i in dp:\n",
297    "    print(i)"
298   ]
299  },
300  {
301   "source": [
302    "`map` by default applies function to every mini-batch as a whole\n"
303   ],
304   "cell_type": "markdown",
305   "metadata": {}
306  },
307  {
308   "cell_type": "code",
309   "execution_count": 11,
310   "metadata": {},
311   "outputs": [
312    {
313     "output_type": "stream",
314     "name": "stdout",
315     "text": [
316      "[0, 1, 2, 0, 1, 2]\n[3, 4, 5, 3, 4, 5]\n[6, 7, 8, 6, 7, 8]\n[9, 9]\n"
317     ]
318    }
319   ],
320   "source": [
321    "dp = ExampleIterPipe(10).batch(3).map(lambda x: x * 2)\n",
322    "for i in dp:\n",
323    "    print(i)"
324   ]
325  },
326  {
327   "source": [
328    "To apply function on individual items of the mini-batch use `nesting_level` argument"
329   ],
330   "cell_type": "markdown",
331   "metadata": {}
332  },
333  {
334   "cell_type": "code",
335   "execution_count": 12,
336   "metadata": {},
337   "outputs": [
338    {
339     "output_type": "stream",
340     "name": "stdout",
341     "text": [
342      "[[0, 2, 4], [6, 8, 10]]\n[[12, 14, 16], [18]]\n"
343     ]
344    }
345   ],
346   "source": [
347    "dp = ExampleIterPipe(10).batch(3).batch(2).map(lambda x: x * 2, nesting_level = 2)\n",
348    "for i in dp:\n",
349    "    print(i)"
350   ]
351  },
352  {
353   "source": [
354    "Setting `nesting_level` to `-1` will apply `map` function to the lowest level possible"
355   ],
356   "cell_type": "markdown",
357   "metadata": {}
358  },
359  {
360   "cell_type": "code",
361   "execution_count": 13,
362   "metadata": {},
363   "outputs": [
364    {
365     "output_type": "stream",
366     "name": "stdout",
367     "text": [
368      "[[[0, 2, 4], [6, 8, 10]], [[12, 14, 16], [18]]]\n"
369     ]
370    }
371   ],
372   "source": [
373    "dp = ExampleIterPipe(10).batch(3).batch(2).batch(2).map(lambda x: x * 2, nesting_level = -1)\n",
374    "for i in dp:\n",
375    "    print(i)"
376   ]
377  },
378  {
379   "source": [
380    "## Filter\n",
381    "\n",
382    "Function: `filter`\n",
383    "\n",
384    "Description: \n",
385    "\n",
386    "Alternatives:\n",
387    "\n",
388    "Arguments:\n",
389    "  - `nesting_level: int = 0`\n",
390    "  - `drop_empty_batches = True` whether empty batches are dropped or not.\n",
391    " \n",
392    "Example:"
393   ],
394   "cell_type": "markdown",
395   "metadata": {}
396  },
397  {
398   "cell_type": "code",
399   "execution_count": 14,
400   "metadata": {},
401   "outputs": [
402    {
403     "output_type": "stream",
404     "name": "stdout",
405     "text": [
406      "0\n2\n4\n6\n8\n"
407     ]
408    }
409   ],
410   "source": [
411    "dp = ExampleIterPipe(10).filter(lambda x: x % 2 == 0)\n",
412    "for i in dp:\n",
413    "    print(i)"
414   ]
415  },
416  {
417   "source": [
418    "Classic `filter` by default applies the filter function to every mini-batch as a whole\n"
419   ],
420   "cell_type": "markdown",
421   "metadata": {}
422  },
423  {
424   "cell_type": "code",
425   "execution_count": 15,
426   "metadata": {},
427   "outputs": [
428    {
429     "output_type": "stream",
430     "name": "stdout",
431     "text": [
432      "[0, 1, 2]\n[3, 4, 5]\n[6, 7, 8]\n"
433     ]
434    }
435   ],
436   "source": [
437    "dp = ExampleIterPipe(10)\n",
438    "dp = dp.batch(3).filter(lambda x: len(x) > 2)\n",
439    "for i in dp:\n",
440    "    print(i)"
441   ]
442  },
443  {
444   "source": [
445    "You can apply filter function on individual elements by setting `nesting_level` argument"
446   ],
447   "cell_type": "markdown",
448   "metadata": {}
449  },
450  {
451   "cell_type": "code",
452   "execution_count": 16,
453   "metadata": {},
454   "outputs": [
455    {
456     "output_type": "stream",
457     "name": "stdout",
458     "text": [
459      "[5]\n[6, 7, 8]\n[9]\n"
460     ]
461    }
462   ],
463   "source": [
464    "dp = ExampleIterPipe(10)\n",
465    "dp = dp.batch(3).filter(lambda x: x > 4, nesting_level = 1)\n",
466    "for i in dp:\n",
467    "    print(i)"
468   ]
469  },
470  {
471   "source": [
472    "If mini-batch ends with zero elements after filtering default behaviour would be to drop them from the response. You can override this behaviour using `drop_empty_batches` argument.\n"
473   ],
474   "cell_type": "markdown",
475   "metadata": {}
476  },
477  {
478   "cell_type": "code",
479   "execution_count": 17,
480   "metadata": {},
481   "outputs": [
482    {
483     "output_type": "stream",
484     "name": "stdout",
485     "text": [
486      "[]\n[5]\n[6, 7, 8]\n[9]\n"
487     ]
488    }
489   ],
490   "source": [
491    "dp = ExampleIterPipe(10)\n",
492    "dp = dp.batch(3).filter(lambda x: x > 4, nesting_level = -1, drop_empty_batches = False)\n",
493    "for i in dp:\n",
494    "    print(i)"
495   ]
496  },
497  {
498   "cell_type": "code",
499   "execution_count": 18,
500   "metadata": {},
501   "outputs": [
502    {
503     "output_type": "stream",
504     "name": "stdout",
505     "text": [
506      "[[[0, 1, 2], [3]], [[], [10, 11]]]\n[[[12, 13, 14], [15, 16, 17]], [[18, 19]]]\n"
507     ]
508    }
509   ],
510   "source": [
511    "dp = ExampleIterPipe(20)\n",
512    "dp = dp.batch(3).batch(2).batch(2).filter(lambda x: x < 4 or x > 9 , nesting_level = -1, drop_empty_batches = False)\n",
513    "for i in dp:\n",
514    "    print(i)"
515   ]
516  },
517  {
518   "source": [
519    "## Shuffle\n",
520    "\n",
521    "Function: `shuffle`\n",
522    "\n",
523    "Description: \n",
524    "\n",
525    "Alternatives:\n",
526    "\n",
527    "Arguments:\n",
528    "  - `unbatch_level:int = 0` if specified calls `unbatch(unbatch_level=unbatch_level)` on source datapipe before batching (see `unbatch`)\n",
529    "  - `buffer_size: int = 10000`\n",
530    " \n",
531    "Example:"
532   ],
533   "cell_type": "markdown",
534   "metadata": {}
535  },
536  {
537   "cell_type": "code",
538   "execution_count": 19,
539   "metadata": {},
540   "outputs": [
541    {
542     "output_type": "stream",
543     "name": "stdout",
544     "text": [
545      "2\n9\n4\n0\n3\n7\n8\n5\n6\n1\n"
546     ]
547    }
548   ],
549   "source": [
550    "dp = ExampleIterPipe(10).shuffle()\n",
551    "for i in dp:\n",
552    "    print(i)"
553   ]
554  },
555  {
556   "source": [
557    "`shuffle` operates on input mini-batches similarly to how it operates on individual items"
558   ],
559   "cell_type": "markdown",
560   "metadata": {}
561  },
562  {
563   "cell_type": "code",
564   "execution_count": 20,
565   "metadata": {},
566   "outputs": [
567    {
568     "output_type": "stream",
569     "name": "stdout",
570     "text": [
571      "[0, 1, 2]\n[3, 4, 5]\n[9]\n[6, 7, 8]\n"
572     ]
573    }
574   ],
575   "source": [
576    "dp = ExampleIterPipe(10).batch(3).shuffle()\n",
577    "for i in dp:\n",
578    "    print(i)"
579   ]
580  },
581  {
582   "source": [
583    "To shuffle elements across batches use `shuffle(unbatch_level)` followed by `batch` pattern "
584   ],
585   "cell_type": "markdown",
586   "metadata": {}
587  },
588  {
589   "cell_type": "code",
590   "execution_count": 21,
591   "metadata": {},
592   "outputs": [
593    {
594     "output_type": "stream",
595     "name": "stdout",
596     "text": [
597      "[2, 1, 0]\n[7, 9, 6]\n[3, 5, 4]\n[8]\n"
598     ]
599    }
600   ],
601   "source": [
602    "dp = ExampleIterPipe(10).batch(3).shuffle(unbatch_level = -1).batch(3)\n",
603    "for i in dp:\n",
604    "    print(i)"
605   ]
606  },
607  {
608   "source": [
609    "## Collate\n",
610    "\n",
611    "Function: `collate`\n",
612    "\n",
613    "Description: \n",
614    "\n",
615    "Alternatives:\n",
616    "\n",
617    "Arguments:\n",
618    " \n",
619    "Example:"
620   ],
621   "cell_type": "markdown",
622   "metadata": {}
623  },
624  {
625   "cell_type": "code",
626   "execution_count": 22,
627   "metadata": {},
628   "outputs": [
629    {
630     "output_type": "stream",
631     "name": "stdout",
632     "text": [
633      "tensor([0, 1, 2])\ntensor([3, 4, 5])\ntensor([6, 7, 8])\ntensor([9])\n"
634     ]
635    }
636   ],
637   "source": [
638    "dp = ExampleIterPipe(10).batch(3).collate()\n",
639    "for i in dp:\n",
640    "    print(i)"
641   ]
642  },
643  {
644   "source": [
645    "## GroupBy\n",
646    "\n",
647    "Function: `groupby`\n",
648    "\n",
649    "Usage: `dp.groupby(lambda x: x[0])`\n",
650    "\n",
651    "Description: Batching items by combining items with same key into same batch \n",
652    "\n",
653    "Arguments:\n",
654    " - `group_key_fn`\n",
655    " - `group_size` - yield the resulting group as soon as `group_size` elements are accumulated\n",
656    " - `guaranteed_group_size:int = None`\n",
657    " - `unbatch_level:int = 0` if specified calls `unbatch(unbatch_level=unbatch_level)` on source datapipe before batching (see `unbatch`)\n",
658    "\n",
659    "#### Attention\n",
660    "As the data stream can be arbitrarily large, grouping is done on a best-effort basis and there is no guarantee that the same key will never be present in different groups. You can call it a local groupby, where the locality is the one DataPipe process/thread."
661   ],
662   "cell_type": "markdown",
663   "metadata": {}
664  },
665  {
666   "cell_type": "code",
667   "execution_count": 23,
668   "metadata": {},
669   "outputs": [
670    {
671     "output_type": "stream",
672     "name": "stdout",
673     "text": [
674      "[0, 3, 6, 9]\n[1, 4, 7]\n[5, 2, 8]\n"
675     ]
676    }
677   ],
678   "source": [
679    "dp = ExampleIterPipe(10).shuffle().groupby(lambda x: x % 3)\n",
680    "for i in dp:\n",
681    "    print(i)"
682   ]
683  },
684  {
685   "source": [
686    "By default group key function is applied to entire input (mini-batch)"
687   ],
688   "cell_type": "markdown",
689   "metadata": {}
690  },
691  {
692   "cell_type": "code",
693   "execution_count": 24,
694   "metadata": {},
695   "outputs": [
696    {
697     "output_type": "stream",
698     "name": "stdout",
699     "text": [
700      "[[0, 1, 2], [3, 4, 5], [6, 7, 8]]\n[[9]]\n"
701     ]
702    }
703   ],
704   "source": [
705    "dp = ExampleIterPipe(10).batch(3).groupby(lambda x: len(x))\n",
706    "for i in dp:\n",
707    "    print(i)"
708   ]
709  },
710  {
711   "source": [
712    "It is possible to unnest items from the mini-batches using `unbatch_level` argument"
713   ],
714   "cell_type": "markdown",
715   "metadata": {}
716  },
717  {
718   "cell_type": "code",
719   "execution_count": 25,
720   "metadata": {},
721   "outputs": [
722    {
723     "output_type": "stream",
724     "name": "stdout",
725     "text": [
726      "[0, 3, 6, 9]\n[1, 4, 7]\n[2, 5, 8]\n"
727     ]
728    }
729   ],
730   "source": [
731    "dp = ExampleIterPipe(10).batch(3).groupby(lambda x: x % 3, unbatch_level = 1)\n",
732    "for i in dp:\n",
733    "    print(i)"
734   ]
735  },
736  {
737   "source": [
738    "When internal buffer (defined by `buffer_size`) is overfilled, groupby will yield biggest group available"
739   ],
740   "cell_type": "markdown",
741   "metadata": {}
742  },
743  {
744   "cell_type": "code",
745   "execution_count": 26,
746   "metadata": {},
747   "outputs": [
748    {
749     "output_type": "stream",
750     "name": "stdout",
751     "text": [
752      "[9, 3]\n[13, 4, 7]\n[2, 11, 14, 5]\n[0, 6, 12]\n[1, 10]\n[8]\n"
753     ]
754    }
755   ],
756   "source": [
757    "dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, buffer_size = 5)\n",
758    "for i in dp:\n",
759    "    print(i)"
760   ]
761  },
762  {
763   "source": [
764    "`groupby` will produce `group_size`-sized batches on an as-fast-as-possible basis"
765   ],
766   "cell_type": "markdown",
767   "metadata": {}
768  },
769  {
770   "cell_type": "code",
771   "execution_count": 27,
772   "metadata": {},
773   "outputs": [
774    {
775     "output_type": "stream",
776     "name": "stdout",
777     "text": [
778      "[6, 3, 12]\n[1, 16, 7]\n[2, 5, 8]\n[14, 11, 17]\n[15, 9, 0]\n[10, 4, 13]\n"
779     ]
780    }
781   ],
782   "source": [
783    "dp = ExampleIterPipe(18).shuffle().groupby(lambda x: x % 3, group_size = 3)\n",
784    "for i in dp:\n",
785    "    print(i)"
786   ]
787  },
788  {
789   "source": [
790    "Remaining groups must be at least `guaranteed_group_size` big. "
791   ],
792   "cell_type": "markdown",
793   "metadata": {}
794  },
795  {
796   "cell_type": "code",
797   "execution_count": 28,
798   "metadata": {},
799   "outputs": [
800    {
801     "output_type": "stream",
802     "name": "stdout",
803     "text": [
804      "[11, 2, 5]\n[1, 4, 10]\n[0, 9, 6]\n[14, 8]\n[13, 7]\n[12, 3]\n"
805     ]
806    }
807   ],
808   "source": [
809    "dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, group_size = 3, guaranteed_group_size = 2)\n",
810    "for i in dp:\n",
811    "    print(i)"
812   ]
813  },
814  {
815   "source": [
816    "Without a defined `group_size`, the function will try to accumulate at least `guaranteed_group_size` elements before yielding the resulting group"
817   ],
818   "cell_type": "markdown",
819   "metadata": {}
820  },
821  {
822   "cell_type": "code",
823   "execution_count": 29,
824   "metadata": {},
825   "outputs": [
826    {
827     "output_type": "stream",
828     "name": "stdout",
829     "text": [
830      "[3, 6, 9, 12, 0]\n[14, 2, 8, 11, 5]\n[7, 4, 1, 13, 10]\n"
831     ]
832    }
833   ],
834   "source": [
835    "dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, guaranteed_group_size = 2)\n",
836    "for i in dp:\n",
837    "    print(i)"
838   ]
839  },
840  {
841   "source": [
842    "This behaviour becomes noticeable when the data is bigger than the buffer and some groups get evicted before gathering all potential items"
843   ],
844   "cell_type": "markdown",
845   "metadata": {}
846  },
847  {
848   "cell_type": "code",
849   "execution_count": 30,
850   "metadata": {},
851   "outputs": [
852    {
853     "output_type": "stream",
854     "name": "stdout",
855     "text": [
856      "[0, 3]\n[1, 4, 7]\n[2, 5, 8]\n[6, 9, 12]\n[10, 13]\n[11, 14]\n"
857     ]
858    }
859   ],
860   "source": [
861    "dp = ExampleIterPipe(15).groupby(lambda x: x % 3, guaranteed_group_size = 2, buffer_size = 6)\n",
862    "for i in dp:\n",
863    "    print(i)"
864   ]
865  },
866  {
867   "source": [
868    "With randomness involved you might end up with incomplete groups (so the next example is expected to fail in most cases)"
869   ],
870   "cell_type": "markdown",
871   "metadata": {}
872  },
873  {
874   "cell_type": "code",
875   "execution_count": 31,
876   "metadata": {},
877   "outputs": [
878    {
879     "output_type": "stream",
880     "name": "stdout",
881     "text": [
882      "[14, 5, 11]\n[1, 7, 4, 10]\n[0, 12, 6]\n[8, 2]\n[9, 3]\n"
883     ]
884    },
885    {
886     "output_type": "error",
887     "ename": "Exception",
888     "evalue": "('Failed to group items', '[13]')",
889     "traceback": [
890      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
891      "\u001b[0;31mException\u001b[0m                                 Traceback (most recent call last)",
892      "\u001b[0;32m<ipython-input-31-673b9dd7fb43>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mdp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mExampleIterPipe\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m15\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshuffle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgroupby\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mguaranteed_group_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuffer_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdp\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
893      "\u001b[0;32m~/dataset/pytorch/torch/utils/data/datapipes/iter/grouping.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    275\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    276\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mguaranteed_group_size\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mbiggest_size\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mguaranteed_group_size\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdrop_remaining\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 277\u001b[0;31m                 \u001b[0;32mraise\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Failed to group items'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbiggest_key\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    278\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    279\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mguaranteed_group_size\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mbiggest_size\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mguaranteed_group_size\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
894      "\u001b[0;31mException\u001b[0m: ('Failed to group items', '[13]')"
895     ]
896    }
897   ],
898   "source": [
899    "dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, guaranteed_group_size = 2, buffer_size = 6)\n",
900    "for i in dp:\n",
901    "    print(i)"
902   ]
903  },
904  {
905   "source": [
906    "To avoid this error and drop incomplete groups, use `drop_remaining` argument"
907   ],
908   "cell_type": "markdown",
909   "metadata": {}
910  },
911  {
912   "cell_type": "code",
913   "execution_count": 32,
914   "metadata": {},
915   "outputs": [
916    {
917     "output_type": "stream",
918     "name": "stdout",
919     "text": [
920      "[5, 2, 14]\n[4, 7, 13, 1, 10]\n[12, 6, 3, 9]\n[8, 11]\n"
921     ]
922    }
923   ],
924   "source": [
925    "dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, guaranteed_group_size = 2, buffer_size = 6, drop_remaining = True)\n",
926    "for i in dp:\n",
927    "    print(i)"
928   ]
929  },
930  {
931   "source": [
932    "## Zip\n",
933    "\n",
934    "Function: `zip`\n",
935    "\n",
936    "Description: \n",
937    "\n",
938    "Alternatives:\n",
939    "\n",
940    "Arguments:\n",
941    " \n",
942    "Example:"
943   ],
944   "cell_type": "markdown",
945   "metadata": {}
946  },
947  {
948   "cell_type": "code",
949   "execution_count": 35,
950   "metadata": {},
951   "outputs": [
952    {
953     "output_type": "stream",
954     "name": "stdout",
955     "text": [
956      "(0, 3)\n(1, 0)\n(2, 4)\n(3, 2)\n(4, 1)\n"
957     ]
958    }
959   ],
960   "source": [
961    "_dp = ExampleIterPipe(5).shuffle()\n",
962    "dp = ExampleIterPipe(5).zip(_dp)\n",
963    "for i in dp:\n",
964    "    print(i)"
965   ]
966  },
967  {
968   "source": [
969    "## Fork\n",
970    "\n",
971    "Function: `fork`\n",
972    "\n",
973    "Description: \n",
974    "\n",
975    "Alternatives:\n",
976    "\n",
977    "Arguments:\n",
978    " \n",
979    "Example:"
980   ],
981   "cell_type": "markdown",
982   "metadata": {}
983  },
984  {
985   "cell_type": "code",
986   "execution_count": 36,
987   "metadata": {},
988   "outputs": [
989    {
990     "output_type": "stream",
991     "name": "stdout",
992     "text": [
993      "0\n1\n0\n1\n0\n1\n"
994     ]
995    }
996   ],
997   "source": [
998    "dp = ExampleIterPipe(2)\n",
999    "dp1, dp2, dp3 = dp.fork(3)\n",
1000    "for i in dp1 + dp2 + dp3:\n",
1001    "    print(i)"
1002   ]
1003  },
1004  {
1005   "cell_type": "markdown",
1006   "metadata": {},
1007   "source": [
1008    "## Demultiplexer\n",
1009    "\n",
1010    "Function: `demux`\n",
1011    "\n",
1012    "Description: \n",
1013    "\n",
1014    "Alternatives:\n",
1015    "\n",
1016    "Arguments:\n",
1017    " \n",
1018    "Example:"
1019   ]
1020  },
1021  {
1022   "cell_type": "code",
1023   "execution_count": 32,
1024   "metadata": {},
1025   "outputs": [
1026    {
1027     "name": "stdout",
1028     "output_type": "stream",
1029     "text": [
1030      "1\n",
1031      "4\n",
1032      "7\n"
1033     ]
1034    }
1035   ],
1036   "source": [
1037    "dp = ExampleIterPipe(10)\n",
1038    "dp1, dp2, dp3 = dp.demux(3, lambda x: x % 3)\n",
1039    "for i in dp2:\n",
1040    "    print(i)"
1041   ]
1042  },
1043  {
1044   "cell_type": "markdown",
1045   "metadata": {},
1046   "source": [
1047    "## Multiplexer\n",
1048    "\n",
1049    "Function: `mux`\n",
1050    "\n",
1051    "Description: \n",
1052    "\n",
1053    "Alternatives:\n",
1054    "\n",
1055    "Arguments:\n",
1056    " \n",
1057    "Example:"
1058   ]
1059  },
1060  {
1061   "cell_type": "code",
1062   "execution_count": 34,
1063   "metadata": {},
1064   "outputs": [
1065    {
1066     "name": "stdout",
1067     "output_type": "stream",
1068     "text": [
1069      "0\n",
1070      "0\n",
1071      "0\n",
1072      "1\n",
1073      "10\n",
1074      "100\n",
1075      "2\n",
1076      "20\n",
1077      "200\n"
1078     ]
1079    }
1080   ],
1081   "source": [
1082    "dp1 = ExampleIterPipe(3)\n",
1083    "dp2 = ExampleIterPipe(3).map(lambda x: x * 10)\n",
1084    "dp3 = ExampleIterPipe(3).map(lambda x: x * 100)\n",
1085    "\n",
1086    "dp = dp1.mux(dp2, dp3)\n",
1087    "for i in dp:\n",
1088    "    print(i)"
1089   ]
1090  },
1091  {
1092   "source": [
1093    "## Concat\n",
1094    "\n",
1095    "Function: `concat`\n",
1096    "\n",
1097    "Description: Returns a DataPipe with elements from the first datapipe followed by elements from the second datapipes\n",
1098    "\n",
1099    "Alternatives:\n",
1100    "    \n",
1101    "    `dp = dp.concat(dp2, dp3)`\n",
1102    "    `dp = dp.concat(*datapipes_list)`\n",
1103    "\n",
1104    "Example:\n"
1105   ],
1106   "cell_type": "markdown",
1107   "metadata": {}
1108  },
1109  {
1110   "cell_type": "code",
1111   "execution_count": 37,
1112   "metadata": {},
1113   "outputs": [
1114    {
1115     "output_type": "stream",
1116     "name": "stdout",
1117     "text": [
1118      "0\n1\n2\n3\n0\n1\n2\n"
1119     ]
1120    }
1121   ],
1122   "source": [
1123    "dp = ExampleIterPipe(4)\n",
1124    "dp2 = ExampleIterPipe(3)\n",
1125    "dp = dp.concat(dp2)\n",
1126    "for i in dp:\n",
1127    "    print(i)"
1128   ]
1129  }
1130 ]
1131}