xref: /aosp_15_r20/external/pytorch/torch/utils/data/datapipes/_hook_iterator.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3import inspect
4from enum import Enum
5
6import torch
7
8
9class _SnapshotState(Enum):
10    r"""
11    These are the snapshotting-related states that IterDataPipes can be in.
12
13    `NotStarted` - allows you to restore a snapshot and create an iterator with reset
14    `Restored` - cannot restore again, allows you to create an iterator without resetting the DataPipe
15    `Iterating` - can restore, will reset if you create a new iterator
16    """
17
18    NotStarted = 0
19    Restored = 1
20    Iterating = 2
21
22
23def _simplify_obj_name(obj) -> str:
24    """Simplify the display strings of objects for the purpose of rendering within DataPipe error messages."""
25    if inspect.isfunction(obj):
26        return obj.__name__
27    else:
28        return repr(obj)
29
30
31def _strip_datapipe_from_name(name: str) -> str:
32    return name.replace("IterDataPipe", "").replace("MapDataPipe", "")
33
34
35def _generate_input_args_string(obj):
36    """Generate a string for the input arguments of an object."""
37    signature = inspect.signature(obj.__class__)
38    input_param_names = set(signature.parameters.keys())
39    result = []
40    for name, value in inspect.getmembers(obj):
41        if name in input_param_names:
42            result.append((name, _simplify_obj_name(value)))
43    return ", ".join([f"{name}={value}" for name, value in result])
44
45
def _generate_iterdatapipe_msg(datapipe, simplify_dp_name: bool = False):
    """Build the display string `ClassName(arg=value, ...)` for a DataPipe instance."""
    described = f"{datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})"
    # Optionally drop the IterDataPipe/MapDataPipe suffix for a shorter display name.
    return _strip_datapipe_from_name(described) if simplify_dp_name else described
53
54
def _gen_invalid_iterdatapipe_msg(datapipe):
    """Return the error message used when an iterator is invalidated by a newer one."""
    # Grammar fix: "caused multiple references" -> "caused by multiple references".
    return (
        "This iterator has been invalidated because another iterator has been created "
        f"from the same IterDataPipe: {_generate_iterdatapipe_msg(datapipe)}\n"
        "This may be caused by multiple references to the same IterDataPipe. We recommend "
        "using `.fork()` if that is necessary."
    )
62
63
# Shared suffix appended to iterator-invalidation errors, pointing users at the
# tracking issue for the single-iterator-per-IterDataPipe constraint.
_feedback_msg = (
    "\nFor feedback regarding this single iterator per IterDataPipe constraint, feel free "
    "to comment on this issue: https://github.com/pytorch/data/issues/45."
)
68
69
70def _check_iterator_valid(datapipe, iterator_id, next_method_exists=False) -> None:
71    r"""
72    Given an instance of a DataPipe and an iterator ID, check if the IDs match, and if not, raises an exception.
73
74    In the case of ChildDataPipe, the ID gets compared to the one stored in `main_datapipe` as well.
75    """
76    if next_method_exists:
77        # This is the case where `IterDataPipe` has both `__iter__` and `__next__`.
78        # The `_valid_iterator_id` should either be never set (`None`), or set by at most one
79        # iterator (`0`). Otherwise, it means there are multiple iterators.
80        if datapipe._valid_iterator_id is not None and datapipe._valid_iterator_id != 0:
81            extra_msg = "\nNote that this exception is raised inside your IterDataPipe's a `__next__` method"
82            raise RuntimeError(
83                _gen_invalid_iterdatapipe_msg(datapipe) + extra_msg + _feedback_msg
84            )
85    elif (
86        hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True
87    ):
88        if hasattr(datapipe, "_check_valid_iterator_id"):
89            if not datapipe._check_valid_iterator_id(iterator_id):
90                raise RuntimeError(
91                    "This iterator has been invalidated, because a new iterator has been created "
92                    f"from one of the ChildDataPipes of "
93                    f"{_generate_iterdatapipe_msg(datapipe.main_datapipe)}."
94                    + _feedback_msg
95                )
96        else:
97            raise RuntimeError(
98                "ChildDataPipe must have method `_check_valid_iterator_id`."
99            )
100    elif datapipe._valid_iterator_id != iterator_id:
101        raise RuntimeError(_gen_invalid_iterdatapipe_msg(datapipe) + _feedback_msg)
102
103
104def _set_datapipe_valid_iterator_id(datapipe):
105    """Given a DataPipe, updates its valid iterator ID and reset the DataPipe."""
106    if hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True:
107        if hasattr(datapipe, "_set_main_datapipe_valid_iterator_id"):
108            datapipe._set_main_datapipe_valid_iterator_id()  # reset() is called within this method when appropriate
109        else:
110            raise RuntimeError(
111                "ChildDataPipe must have method `_set_main_datapipe_valid_iterator_id`."
112            )
113    else:
114        if datapipe._valid_iterator_id is None:
115            datapipe._valid_iterator_id = 0
116        else:
117            datapipe._valid_iterator_id += 1
118        datapipe.reset()
119    return datapipe._valid_iterator_id
120
121
def hook_iterator(namespace):
    r"""
    Define a hook that is applied to all `__iter__` of metaclass `_DataPipeMeta`.

    This is done for the purpose of profiling and checking if an iterator is still valid.
    The `namespace` dict (a class body under construction) is mutated in place:
    `__iter__` (and, when present, `__next__`) are replaced with wrappers.
    """

    def profiler_record_fn_context(datapipe):
        # Cache the rendered profile name on the instance so it is built only once.
        if not hasattr(datapipe, "_profile_name"):
            datapipe._profile_name = _generate_iterdatapipe_msg(
                datapipe, simplify_dp_name=True
            )
        return torch.autograd.profiler.record_function(datapipe._profile_name)

    class IteratorDecorator:
        r"""
        Wrap the iterator and modifying its `__next__` method.

        This decorator is applied to DataPipes of which `__iter__` method is NOT a generator function.
        Those `__iter__` method commonly returns `self` but not necessarily.
        """

        def __init__(self, iterator, datapipe, iterator_id, has_next_method):
            self.iterator = iterator
            self.datapipe = datapipe
            self.iterator_id = iterator_id
            self._profiler_enabled = torch.autograd._profiler_enabled()
            # Check if `__iter__` returns `self` and `DataPipe` has `__next__`
            self.self_and_has_next_method = (
                self.iterator is self.datapipe and has_next_method
            )

        def __iter__(self):
            return self

        def _get_next(self):
            """Return next with logic related to iterator validity, profiler, and incrementation of samples yielded."""
            _check_iterator_valid(self.datapipe, self.iterator_id)
            result = next(self.iterator)
            # When `__iter__` returns `self` and `__next__` exists, `wrap_next`
            # already counts the yielded sample; avoid double counting here.
            if not self.self_and_has_next_method:
                self.datapipe._number_of_samples_yielded += 1
            return result

        def __next__(self):
            # TODO: Add try-except to in-place reduce traceback from the Exception
            # See: https://github.com/pytorch/data/issues/284
            if self._profiler_enabled:
                with profiler_record_fn_context(self.datapipe):
                    return self._get_next()
            else:  # Decided against using `contextlib.nullcontext` for performance reasons
                return self._get_next()

        def __getattr__(self, name):
            # Delegate any other attribute access to the wrapped iterator.
            return getattr(self.iterator, name)

    func = namespace["__iter__"]

    # ``__iter__`` of IterDataPipe is a generator function
    if inspect.isgeneratorfunction(func):

        @functools.wraps(func)
        def wrap_generator(*args, **kwargs):
            gen = func(*args, **kwargs)
            datapipe = args[0]
            if datapipe._fast_forward_iterator:
                # A restored snapshot supplies a pre-advanced iterator; drain it
                # instead of consuming the freshly created generator.
                it = datapipe._fast_forward_iterator
                datapipe._fast_forward_iterator = None
                datapipe._snapshot_state = _SnapshotState.Iterating
                while True:
                    try:
                        yield next(it)
                    except StopIteration:
                        return
            iterator_id = _set_datapipe_valid_iterator_id(
                datapipe
            )  # This ID is tied to each created iterator
            _profiler_enabled = torch.autograd._profiler_enabled()
            try:
                # Prime the generator: advance to (and yield) its first value.
                if _profiler_enabled:
                    with profiler_record_fn_context(datapipe):
                        response = gen.send(None)
                else:
                    response = gen.send(None)

                while True:
                    datapipe._number_of_samples_yielded += 1
                    request = yield response
                    # Pass through here every time `__next__` is called
                    if _profiler_enabled:
                        with profiler_record_fn_context(datapipe):
                            _check_iterator_valid(datapipe, iterator_id)
                            response = gen.send(request)
                    else:  # Decided against using `contextlib.nullcontext` for performance reasons
                        _check_iterator_valid(datapipe, iterator_id)
                        response = gen.send(request)
            except StopIteration:  # unused `as e` binding removed
                return
            except Exception as e:
                # TODO: Simplify the traceback message to skip over `response = gen.send(None)`
                #       Part of https://github.com/pytorch/data/issues/284
                # Note: `datapipe` is already bound above; the original redundant
                # re-assignment here has been removed.
                msg = "thrown by __iter__ of"
                single_iterator_msg = "single iterator per IterDataPipe constraint"
                # Append the originating DataPipe to the exception message, but
                # only once (skip if a wrapper below already annotated it).
                if hasattr(e.args, "__len__"):
                    full_msg = f"{msg} {datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})"
                    if len(e.args) == 0 or not isinstance(
                        e.args[0], str
                    ):  # If an exception message doesn't exist
                        e.args = (f"\nThis exception is {full_msg}",)
                    elif msg not in e.args[0] and single_iterator_msg not in e.args[0]:
                        e.args = (
                            e.args[0] + f"\nThis exception is {full_msg}",
                        ) + e.args[1:]
                raise

        namespace["__iter__"] = wrap_generator
    else:  # ``__iter__`` of IterDataPipe is NOT a generator function
        # IterDataPipe is an iterator with both ``__iter__`` and ``__next__``
        # And ``__iter__`` may or may not return `self`
        if "__next__" in namespace:  # If `__next__` exists, put a wrapper around it
            next_func = namespace["__next__"]

            @functools.wraps(next_func)
            def wrap_next(*args, **kwargs):
                datapipe = args[0]
                if torch.autograd._profiler_enabled():
                    with profiler_record_fn_context(datapipe):
                        result = next_func(*args, **kwargs)
                else:
                    result = next_func(*args, **kwargs)
                datapipe._number_of_samples_yielded += 1
                return result

            namespace["__next__"] = wrap_next

            # Note that if the `__next__` and `__iter__` do something completely unrelated. It may cause issue but
            # the user will be violating the iterator protocol. Potential issue:
            # 1. Valid iterator ID may not update or checked properly
            # 2. The number of samples yielded will be miscounted

        # Regardless if `__next__` exists or not, `__iter__` needs a wrapper to track the number of valid iterators
        @functools.wraps(func)
        def wrap_iter(*args, **kwargs):
            iter_ret = func(*args, **kwargs)
            datapipe = args[0]
            datapipe._snapshot_state = _SnapshotState.Iterating
            if datapipe._fast_forward_iterator:
                # Resume from a restored snapshot instead of the fresh iterator.
                iter_ret = datapipe._fast_forward_iterator
                datapipe._fast_forward_iterator = None
                return iter_ret
            iterator_id = _set_datapipe_valid_iterator_id(
                datapipe
            )  # This ID is tied to each created iterator
            return IteratorDecorator(
                iter_ret, datapipe, iterator_id, "__next__" in namespace
            )

        namespace["__iter__"] = wrap_iter
280