# mypy: allow-untyped-defs
import copy
import warnings

from torch.utils.data.datapipes.datapipe import IterDataPipe


__all__ = ["IterableWrapperIterDataPipe"]


class IterableWrapperIterDataPipe(IterDataPipe):
    r"""
    Wraps an iterable object to create an IterDataPipe.

    Args:
        iterable: Iterable object to be wrapped into an IterDataPipe
        deepcopy: Option to deepcopy input iterable object for each
            iterator. The copy is made when the first element is read in ``iter()``.

    .. note::
        If ``deepcopy`` is explicitly set to ``False``, users should ensure
        that the data pipeline doesn't contain any in-place operations over
        the iterable instance to prevent data inconsistency across iterations.

    .. note::
        If ``deepcopy`` is ``True`` but the iterable cannot be deep-copied,
        a warning is issued and the original object is iterated directly,
        which behaves as if ``deepcopy`` were ``False`` for that pass.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper(range(10))
        >>> list(dp)
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    """

    def __init__(self, iterable, deepcopy=True):
        """Store the wrapped iterable and the copy-per-iteration flag."""
        self.iterable = iterable
        self.deepcopy = deepcopy

    def __iter__(self):
        """Yield the items of the wrapped iterable.

        When ``self.deepcopy`` is true, a fresh deep copy of the iterable is
        made for this pass so that in-place changes to yielded items do not
        leak into later passes. If the copy cannot be made, a warning is
        emitted and the original iterable is used instead.
        """
        source_data = self.iterable
        if self.deepcopy:
            try:
                source_data = copy.deepcopy(self.iterable)
            # For the case that data cannot be deep-copied,
            # all in-place operations will affect iterable variable.
            # When this DataPipe is iterated second time, it will
            # yield modified items.
            # ``copy.Error`` is the copy module's own "uncopyable object"
            # signal; it is handled the same way as ``TypeError`` so that
            # every "cannot deepcopy" outcome takes the documented fallback.
            except (TypeError, copy.Error):
                warnings.warn(
                    "The input iterable can not be deepcopied, "
                    "please be aware of in-place modification would affect source data."
                )
        yield from source_data

    def __len__(self):
        """Return ``len()`` of the wrapped iterable.

        Any exception raised by ``len(self.iterable)`` (for example
        ``TypeError`` for an object without ``__len__``) propagates unchanged.
        """
        return len(self.iterable)