xref: /aosp_15_r20/external/pytorch/torch/utils/data/datapipes/iter/callable.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3from collections import namedtuple
4from typing import Any, Callable, Dict, Iterator, List, Optional, Sized, TypeVar, Union
5
6from torch.utils.data._utils.collate import default_collate
7from torch.utils.data.datapipes._decorator import functional_datapipe
8from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
9from torch.utils.data.datapipes.datapipe import IterDataPipe
10from torch.utils.data.datapipes.utils.common import (
11    _check_unpickable_fn,
12    validate_input_col,
13)
14
15
16__all__ = [
17    "CollatorIterDataPipe",
18    "MapperIterDataPipe",
19]
20
21
22_T_co = TypeVar("_T_co", covariant=True)
23
24
@functional_datapipe("map")
class MapperIterDataPipe(IterDataPipe[_T_co]):
    r"""
    Applies a function over each item from the source DataPipe (functional name: ``map``).

    Any regular Python callable or :func:`functools.partial` object works as
    ``fn``. Lambdas are discouraged because they cannot be pickled.

    Args:
        datapipe: Source Iterable DataPipe
        fn: Function being applied over each item
        input_col: Index or indices of data which ``fn`` is applied, such as:

            - ``None`` as default to apply ``fn`` to the data directly.
            - Integer(s) is used for list/tuple.
            - Key(s) is used for dict.

        output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified
            only when ``input_col`` is not ``None``

            - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with
              multiple indices, the left-most one is used, and other indices will be removed.
            - Integer is used for list/tuple. ``-1`` represents to append result at the end.
            - Key is used for dict. New key is acceptable.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
        >>> def add_one(x):
        ...     return x + 1
        >>> dp = IterableWrapper(range(10))
        >>> map_dp_1 = dp.map(add_one)  # Invocation via functional form is preferred
        >>> list(map_dp_1)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle`
        >>> # Use `functools.partial` or explicitly define the function instead
        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
        >>> list(map_dp_2)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    """

    datapipe: IterDataPipe
    fn: Callable

    def __init__(
        self,
        datapipe: IterDataPipe,
        fn: Callable,
        input_col=None,
        output_col=None,
    ) -> None:
        super().__init__()
        self.datapipe = datapipe

        # Warn/raise early on functions (e.g. lambdas) that cannot be pickled.
        _check_unpickable_fn(fn)
        self.fn = fn  # type: ignore[assignment]

        self.input_col = input_col
        # `output_col` is only meaningful when an input column was selected.
        if output_col is not None and input_col is None:
            raise ValueError("`output_col` must be None when `input_col` is None.")
        if isinstance(output_col, (list, tuple)):
            if len(output_col) > 1:
                raise ValueError("`output_col` must be a single-element list or tuple")
            # Unwrap the single-element sequence to a scalar index/key.
            output_col = output_col[0]
        self.output_col = output_col
        validate_input_col(fn, input_col)

    def _apply_fn(self, data):
        # Fast path: no column selection at all — map the whole element.
        if self.input_col is None and self.output_col is None:
            return self.fn(data)

        # Select the fn input according to `input_col`.
        if self.input_col is None:
            res = self.fn(data)
        elif isinstance(self.input_col, (list, tuple)):
            # Multiple columns are passed as positional arguments.
            res = self.fn(*(data[col] for col in self.input_col))
        else:
            res = self.fn(data[self.input_col])

        # Tuples are immutable, so edit a list copy and convert back later.
        was_tuple = isinstance(data, tuple)
        if was_tuple:
            data = list(data)

        if self.output_col is None:
            if isinstance(self.input_col, (list, tuple)):
                # Store the result at the first selected index, then drop the
                # remaining input indices (highest first so positions stay valid).
                data[self.input_col[0]] = res
                for idx in sorted(self.input_col[1:], reverse=True):
                    del data[idx]
            else:
                data[self.input_col] = res
        elif self.output_col == -1:
            # -1 means "append the result at the end".
            data.append(res)
        else:
            data[self.output_col] = res

        return tuple(data) if was_tuple else data

    def __iter__(self) -> Iterator[_T_co]:
        yield from map(self._apply_fn, self.datapipe)

    def __len__(self) -> int:
        if not isinstance(self.datapipe, Sized):
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
        return len(self.datapipe)
135
136
def _collate_helper(conversion, item):
    """Collate a one-DataFrame batch into a ``CollateResult`` named tuple.

    Args:
        conversion: Dict mapping column names to collation callables. Columns
            absent from the dict fall back to TorchArrow's default collation.
        item: A batch holding exactly one DataFrame (exposes ``.items`` and
            index access — presumably a ``DataChunk``; TODO confirm).

    Returns:
        A ``CollateResult`` namedtuple with one collated value per column.

    Raises:
        RuntimeError: If the batch holds more than one DataFrame, if a
            conversion key does not match a DataFrame column, if a conversion
            value is not callable, or if TorchArrow cannot be imported for
            default collation.
    """
    # TODO(VitalyFedyunin): Verify that item is any sort of batch
    if len(item.items) > 1:
        # TODO(VitalyFedyunin): Compact all batch dataframes into one
        raise RuntimeError("Only supports one DataFrame per batch")
    df = item[0]
    columns_name = df_wrapper.get_columns(df)

    # Every conversion key must correspond to an actual DataFrame column.
    for name in conversion.keys():
        if name not in columns_name:
            raise RuntimeError("Conversion keys mismatch")

    # Import TorchArrow lazily and at most once (the original re-ran the
    # import attempt inside the loop for every non-converted column).
    tap = None

    tuple_names: List = []
    tuple_values: List = []
    for name in columns_name:
        if name in conversion:
            if not callable(conversion[name]):
                raise RuntimeError(
                    "Collate (DF)DataPipe requires callable as dict values"
                )
            collation_fn = conversion[name]
        else:
            # TODO(VitalyFedyunin): Add default collation into df_wrapper
            if tap is None:
                try:
                    import torcharrow.pytorch as tap  # type: ignore[import]
                except Exception as e:
                    raise RuntimeError(
                        "unable to import default collation function from the TorchArrow"
                    ) from e
            collation_fn = tap.rec.Default()

        tuple_names.append(str(name))
        tuple_values.append(collation_fn(df[name]))

    # TODO(VitalyFedyunin): We can dynamically extract types from the tuple_values here
    # TODO(VitalyFedyunin): Instead of ignoring mypy error, make sure tuple_names is not empty
    tpl_cls = namedtuple("CollateResult", tuple_names)  # type: ignore[misc]
    # Use a descriptive local name instead of shadowing the builtin `tuple`.
    return tpl_cls(*tuple_values)
178
179
@functional_datapipe("collate")
class CollatorIterDataPipe(MapperIterDataPipe):
    r"""
    Collates samples from DataPipe to Tensor(s) by a custom collate function (functional name: ``collate``).

    By default, it uses :func:`torch.utils.data.default_collate`.

    .. note::
        While writing a custom collate function, you can import :func:`torch.utils.data.default_collate` for the
        default behavior and `functools.partial` to specify any additional arguments.

    Args:
        datapipe: Iterable DataPipe being collated
        conversion: Either a callable applied to each sample, or a dict mapping
            column names to per-column collation callables (for DataFrame
            pipes). Ignored when ``collate_fn`` is given.
        collate_fn: Customized collate function to collect and combine data or a batch of data.
            Default function collates to Tensor(s) based on data type.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Convert integer data to float Tensor
        >>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        ...     def __len__(self):
        ...         return self.end - self.start
        ...
        >>> ds = MyIterDataPipe(start=3, end=7)
        >>> print(list(ds))
        [3, 4, 5, 6]
        >>> def collate_fn(batch):
        ...     return torch.tensor(batch, dtype=torch.float)
        ...
        >>> collated_ds = CollatorIterDataPipe(ds, collate_fn=collate_fn)
        >>> print(list(collated_ds))
        [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
    """

    def __init__(
        self,
        datapipe: IterDataPipe,
        conversion: Union[
            Callable[..., Any], Dict[Union[str, Any], Union[Callable, Any]], None
        ] = default_collate,
        collate_fn: Optional[Callable] = None,
    ) -> None:
        # TODO(VitalyFedyunin): Replace `Callable[..., Any]` with `Callable[[IColumn], Any]`
        # TODO(VitalyFedyunin): Replace with `Dict[Union[str, IColumn], Union[Callable, Enum]]`
        if collate_fn is not None:
            # An explicit collate_fn takes precedence over `conversion`.
            super().__init__(datapipe, fn=collate_fn)
        elif callable(conversion):
            super().__init__(datapipe, fn=conversion)
        else:
            # Dict-based per-column conversion: bind it to the DataFrame
            # collate helper.
            # TODO(VitalyFedyunin): Validate passed dictionary
            collate_fn = functools.partial(_collate_helper, conversion)
            super().__init__(datapipe, fn=collate_fn)
242