# mypy: allow-untyped-defs
import functools
from collections import namedtuple
from typing import Any, Callable, Dict, Iterator, List, Optional, Sized, TypeVar, Union

from torch.utils.data._utils.collate import default_collate
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.utils.common import (
    _check_unpickable_fn,
    validate_input_col,
)


__all__ = [
    "CollatorIterDataPipe",
    "MapperIterDataPipe",
]


_T_co = TypeVar("_T_co", covariant=True)


@functional_datapipe("map")
class MapperIterDataPipe(IterDataPipe[_T_co]):
    r"""
    Applies a function over each item from the source DataPipe (functional name: ``map``).

    The function can be any regular Python function or partial object. Lambda
    function is not recommended as it is not supported by pickle.

    Args:
        datapipe: Source Iterable DataPipe
        fn: Function being applied over each item
        input_col: Index or indices of the data to which ``fn`` is applied, such as:

            - ``None`` as default to apply ``fn`` to the data directly.
            - Integer(s) is used for list/tuple.
            - Key(s) is used for dict.

        output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified
            only when ``input_col`` is not ``None``. A single-element list or tuple is unwrapped
            to its only element.

            - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with
              multiple indices, the left-most one is used, and other indices will be removed.
            - Integer is used for list/tuple. ``-1`` represents to append result at the end.
            - Key is used for dict. New key is acceptable (including ``-1``).

    .. note::
        When ``input_col`` is given, list and dict items are modified in place; tuple
        items are copied to a list, modified, and converted back to a tuple.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
        >>> def add_one(x):
        ...     return x + 1
        >>> dp = IterableWrapper(range(10))
        >>> map_dp_1 = dp.map(add_one)  # Invocation via functional form is preferred
        >>> list(map_dp_1)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle`
        >>> # Use `functools.partial` or explicitly define the function instead
        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
        >>> list(map_dp_2)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    """

    datapipe: IterDataPipe
    fn: Callable

    def __init__(
        self,
        datapipe: IterDataPipe,
        fn: Callable,
        input_col=None,
        output_col=None,
    ) -> None:
        super().__init__()
        self.datapipe = datapipe

        # Pickle-compatibility check for ``fn`` (helper from ``datapipes.utils.common``).
        _check_unpickable_fn(fn)
        self.fn = fn  # type: ignore[assignment]

        self.input_col = input_col
        if input_col is None and output_col is not None:
            raise ValueError("`output_col` must be None when `input_col` is None.")
        if isinstance(output_col, (list, tuple)):
            # Exactly one destination is supported. Checking ``!= 1`` (not ``> 1``) also
            # rejects an empty sequence, which would otherwise fail with an IndexError below.
            if len(output_col) != 1:
                raise ValueError("`output_col` must be a single-element list or tuple")
            output_col = output_col[0]
        self.output_col = output_col
        validate_input_col(fn, input_col)

    def _apply_fn(self, data):
        """Apply ``fn`` to one item according to ``input_col`` / ``output_col``.

        Returns ``fn(data)`` when no columns are configured; otherwise returns ``data``
        with the result written back into it (a new tuple if ``data`` was a tuple).
        """
        if self.input_col is None and self.output_col is None:
            return self.fn(data)

        if self.input_col is None:
            res = self.fn(data)
        elif isinstance(self.input_col, (list, tuple)):
            args = tuple(data[col] for col in self.input_col)
            res = self.fn(*args)
        else:
            res = self.fn(data[self.input_col])

        # Copy tuple to list and run in-place modification because tuple is immutable.
        # Lists and dicts are modified in place.
        if isinstance(data, tuple):
            t_flag = True
            data = list(data)
        else:
            t_flag = False

        if self.output_col is None:
            if isinstance(self.input_col, (list, tuple)):
                data[self.input_col[0]] = res
                # Delete the remaining columns from the highest index down so that, for
                # non-negative integer indices, earlier deletions do not shift later ones.
                for idx in sorted(self.input_col[1:], reverse=True):
                    del data[idx]
            else:
                data[self.input_col] = res
        elif self.output_col == -1 and not isinstance(data, dict):
            # ``-1`` means "append" for sequences; for dicts it is an ordinary key.
            data.append(res)
        else:
            data[self.output_col] = res

        # Convert list back to tuple
        return tuple(data) if t_flag else data

    def __iter__(self) -> Iterator[_T_co]:
        for data in self.datapipe:
            yield self._apply_fn(data)

    def __len__(self) -> int:
        if isinstance(self.datapipe, Sized):
            return len(self.datapipe)
        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")


def _collate_helper(conversion, item):
    """Collate a batch holding a single DataFrame into a ``CollateResult`` namedtuple.

    Args:
        conversion: Mapping from column name to a callable applied to that column.
            Columns not present in the mapping use TorchArrow's default collation.
        item: Batch whose ``items`` attribute holds the DataFrame(s); ``item[0]`` is the
            DataFrame that gets collated.

    Returns:
        A ``CollateResult`` namedtuple with one field per DataFrame column, in the order
        returned by ``df_wrapper.get_columns``.

    Raises:
        RuntimeError: if the batch holds more than one DataFrame, a ``conversion`` key is
            not a column, a ``conversion`` value is not callable, or TorchArrow cannot be
            imported when a default collation is needed.
    """
    # TODO(VitalyFedyunin): Verify that item is any sort of batch
    if len(item.items) > 1:
        # TODO(VitalyFedyunin): Compact all batch dataframes into one
        raise RuntimeError("Only supports one DataFrame per batch")
    df = item[0]
    columns_name = df_wrapper.get_columns(df)
    tuple_names: List = []
    tuple_values: List = []

    for name in conversion.keys():
        if name not in columns_name:
            raise RuntimeError("Conversion keys missmatch")

    for name in columns_name:
        if name in conversion:
            if not callable(conversion[name]):
                raise RuntimeError(
                    "Collate (DF)DataPipe requires callable as dict values"
                )
            collation_fn = conversion[name]
        else:
            # TODO(VitalyFedyunin): Add default collation into df_wrapper
            try:
                import torcharrow.pytorch as tap  # type: ignore[import]

                collation_fn = tap.rec.Default()
            except Exception as e:
                raise RuntimeError(
                    "unable to import default collation function from the TorchArrow"
                ) from e

        tuple_names.append(str(name))
        tuple_values.append(collation_fn(df[name]))

    # TODO(VitalyFedyunin): We can dynamically extract types from the tuple_values here
    # TODO(VitalyFedyunin): Instead of ignoring mypy error, make sure tuple_names is not empty
    # NOTE(review): namedtuple requires valid identifiers, so column names that are not
    # (e.g. containing spaces) will raise ValueError here — confirm whether that is intended.
    tpl_cls = namedtuple("CollateResult", tuple_names)  # type: ignore[misc]
    return tpl_cls(*tuple_values)


@functional_datapipe("collate")
class CollatorIterDataPipe(MapperIterDataPipe):
    r"""
    Collates samples from DataPipe to Tensor(s) by a custom collate function (functional name: ``collate``).

    By default, it uses :func:`torch.utils.data.default_collate`.

    .. note::
        While writing a custom collate function, you can import :func:`torch.utils.data.default_collate` for the
        default behavior and `functools.partial` to specify any additional arguments.

    Args:
        datapipe: Iterable DataPipe being collated
        conversion: Either a callable applied to each item (default:
            :func:`torch.utils.data.default_collate`), or a dict mapping DataFrame column
            names to per-column collation callables; in the dict form each item must be a
            batch holding a single DataFrame. Ignored when ``collate_fn`` is given.
        collate_fn: Customized collate function to collect and combine data or a batch of data.
            Takes precedence over ``conversion`` when provided.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Convert integer data to float Tensor
        >>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        ...     def __len__(self):
        ...         return self.end - self.start
        ...
        >>> ds = MyIterDataPipe(start=3, end=7)
        >>> print(list(ds))
        [3, 4, 5, 6]
        >>> def collate_fn(batch):
        ...     return torch.tensor(batch, dtype=torch.float)
        ...
        >>> collated_ds = CollatorIterDataPipe(ds, collate_fn=collate_fn)
        >>> print(list(collated_ds))
        [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
    """

    def __init__(
        self,
        datapipe: IterDataPipe,
        conversion: Union[
            Callable[..., Any], Dict[Union[str, Any], Union[Callable, Any]], None
        ] = default_collate,
        collate_fn: Optional[Callable] = None,
    ) -> None:
        # TODO(VitalyFedyunin): Replace `Callable[..., Any]` with `Callable[[IColumn], Any]`
        # TODO(VitalyFedyunin): Replace with `Dict[Union[str, IColumn], Union[Callable, Enum]]`
        if collate_fn is not None:
            super().__init__(datapipe, fn=collate_fn)
        elif callable(conversion):
            super().__init__(datapipe, fn=conversion)
        else:
            # TODO(VitalyFedyunin): Validate passed dictionary
            # NOTE(review): ``conversion=None`` also lands here and only fails once the
            # first batch reaches ``_collate_helper`` (it calls ``conversion.keys()``).
            collate_fn = functools.partial(_collate_helper, conversion)
            super().__init__(datapipe, fn=collate_fn)