The [`datapipes`](https://github.com/pytorch/pytorch/tree/main/torch/utils/data/datapipes) folder holds the implementation of the `IterDataPipe` and `MapDataPipe`.

This document serves as an entry point for DataPipe implementation.

## Implementing DataPipe
As a working example, let us implement an `IterDataPipe` that applies a callable over data under [`iter`](https://github.com/pytorch/pytorch/tree/main/torch/utils/data/datapipes/iter).
For `MapDataPipe`, please refer to the files in the [map](https://github.com/pytorch/pytorch/tree/main/torch/utils/data/datapipes/map) folder and implement the corresponding `__getitem__` method.

### Naming
The naming convention for a DataPipe is Operation-er with the suffix `IterDataPipe`, because each DataPipe behaves like a container that applies an operation to data yielded from its source DataPipe.
When imported into the `iter` module under `datapipes`, each DataPipe is aliased as Op-er without the `IterDataPipe` suffix.
Please check [`__init__.py`](https://github.com/pytorch/pytorch/blob/main/torch/utils/data/datapipes/iter/__init__.py) in the `iter` module for how each DataPipe class is aliased.
Following this convention, our example `IterDataPipe` that maps a function is named `MapperIterDataPipe` and aliased as `iter.Mapper` under `datapipes`.

### Constructor
As a Dataset is now constructed from a stack of DataPipes, each DataPipe normally takes a source DataPipe as its first argument.
```py
class MapperIterDataPipe(IterDataPipe):
    def __init__(self, dp, fn):
        super().__init__()
        self.dp = dp
        self.fn = fn
```
Note:
- Avoid loading data from the source DataPipe in the `__init__` function, in order to support lazy data loading and save memory.
- If an `IterDataPipe` instance holds data in memory, beware of in-place modification of that data. When a second iterator is created from the instance, the data may have already changed.
  Please take the [`IterableWrapper`](https://github.com/pytorch/pytorch/blob/main/torch/utils/data/datapipes/iter/utils.py) class as a reference for how to `deepcopy` data for each iterator.

### Iterator
For `IterDataPipe`, an `__iter__` function is needed to consume data from the source `IterDataPipe`, then apply the operation to the data before yielding.
```py
class MapperIterDataPipe(IterDataPipe):
    ...

    def __iter__(self):
        for d in self.dp:
            yield self.fn(d)
```

### Length
In the most common cases, as with the example of `MapperIterDataPipe` above, the `__len__` method of a DataPipe should return the length of the source DataPipe.
Take care that `__len__` must be computed dynamically, because the length of the source DataPipe might change after initialization (for example, if sharding is applied).

```py
class MapperIterDataPipe(IterDataPipe):
    ...

    def __len__(self):
        return len(self.dp)
```
Note that the `__len__` method is optional for `IterDataPipe`.
For `CSVParserIterDataPipe` in the [Using DataPipe section](#using-datapipe), `__len__` is not implemented because the size of each file stream is unknown to us before loading it.

Besides, in some special cases, a `__len__` method can be provided that either returns an integer length or raises an error, depending on the arguments of the DataPipe.
In that case, the error must be a `TypeError` to support Python's built-in functions like `list(dp)`.
Please check NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] in PyTorch for the detailed reason.

### Registering DataPipe with functional API
Each DataPipe can be registered to support the functional API using the decorator `functional_datapipe`.
```py
@functional_datapipe("map")
class MapperIterDataPipe(IterDataPipe):
    ...
```
Then, a stack of DataPipes can be constructed in a functional-programming manner.
65```py 66>>> import torch.utils.data.datapipes as dp 67>>> datapipes1 = dp.iter.FileOpener(['a.file', 'b.file']).map(fn=decoder).shuffle().batch(2) 68 69>>> datapipes2 = dp.iter.FileOpener(['a.file', 'b.file']) 70>>> datapipes2 = dp.iter.Mapper(datapipes2) 71>>> datapipes2 = dp.iter.Shuffler(datapipes2) 72>>> datapipes2 = dp.iter.Batcher(datapipes2, 2) 73``` 74In the above example, `datapipes1` and `datapipes2` represent the exact same stack of `IterDataPipe`-s. 75 76## Using DataPipe 77For example, we want to load data from CSV files with the following data pipeline: 78- List all csv files 79- Load csv files 80- Parse csv file and yield rows 81 82To support the above pipeline, `CSVParser` is registered as `parse_csv_files` to consume file streams and expand them as rows. 83```py 84@functional_datapipe("parse_csv_files") 85class CSVParserIterDataPipe(IterDataPipe): 86 def __init__(self, dp, **fmtparams): 87 self.dp = dp 88 self.fmtparams = fmtparams 89 90 def __iter__(self): 91 for filename, stream in self.dp: 92 reader = csv.reader(stream, **self.fmtparams) 93 for row in reader: 94 yield filename, row 95``` 96Then, the pipeline can be assembled as following: 97```py 98>>> import torch.utils.data.datapipes as dp 99 100>>> FOLDER = 'path/2/csv/folder' 101>>> datapipe = dp.iter.FileLister([FOLDER]).filter(fn=lambda filename: filename.endswith('.csv')) 102>>> datapipe = dp.iter.FileOpener(datapipe, mode='rt') 103>>> datapipe = datapipe.parse_csv_files(delimiter=' ') 104 105>>> for d in datapipe: # Start loading data 106... pass 107``` 108