xref: /aosp_15_r20/external/pytorch/torch/utils/data/datapipes/_decorator.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import inspect
3from functools import wraps
4from typing import Any, Callable, get_type_hints, Optional, Type, Union
5
6from torch.utils.data.datapipes._typing import _DataPipeMeta
7from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe
8
9
10######################################################
11# Functional API
12######################################################
class functional_datapipe:
    """Class decorator that registers a DataPipe under a functional name.

    ``@functional_datapipe("name")`` on an ``IterDataPipe`` subclass (or on an
    ``IterDataPipe`` wrapped by ``non_deterministic``) registers it through
    ``IterDataPipe.register_datapipe_as_function``. On a ``MapDataPipe``
    subclass it registers it through ``MapDataPipe.register_datapipe_as_function``.
    The decorated object itself is always returned unchanged.
    """

    name: str

    def __init__(self, name: str, enable_df_api_tracing=False) -> None:
        """
        Define a functional datapipe.

        Args:
            name: the functional name the DataPipe is registered under.
            enable_df_api_tracing: if set, any returned DataPipe would accept
                DataFrames API in tracing mode. Only forwarded for
                ``IterDataPipe`` registrations.
        """
        self.name = name
        self.enable_df_api_tracing = enable_df_api_tracing

    def __call__(self, cls):
        """Register ``cls`` under ``self.name`` and return it unchanged.

        Accepted inputs:

        * an ``IterDataPipe`` subclass: registered on ``IterDataPipe``
          (forwarding ``enable_df_api_tracing``);
        * a ``MapDataPipe`` subclass: registered on ``MapDataPipe``;
        * an ``IterDataPipe`` wrapped by ``non_deterministic``, either the
          decorator instance itself or the bound method it returns when it was
          given a function: registered on ``IterDataPipe``;
        * any other class: returned without being registered.

        Raises:
            TypeError: if ``cls`` is not a class and not a ``non_deterministic``
                wrapper, or is an ``IterDataPipe`` subclass whose metaclass is
                not ``_DataPipeMeta``.
        """
        if not isinstance(cls, type):
            # Wrapped objects must be recognised before any ``issubclass`` call:
            # ``issubclass`` raises TypeError when its first argument is not a
            # class, which previously made this path unreachable.
            if not (
                isinstance(cls, non_deterministic)
                or isinstance(getattr(cls, "__self__", None), non_deterministic)
            ):
                raise TypeError(
                    "`functional_datapipe` can only decorate IterDataPipe"
                )
        elif issubclass(cls, IterDataPipe):
            # Safety net: IterDataPipe subclasses are expected to be created by
            # _DataPipeMeta (or a subclass of it).
            if not isinstance(cls, _DataPipeMeta):
                raise TypeError(
                    "`functional_datapipe` can only decorate IterDataPipe"
                )
        elif issubclass(cls, MapDataPipe):
            MapDataPipe.register_datapipe_as_function(self.name, cls)
            return cls
        else:
            # Neither kind of DataPipe: left untouched and unregistered.
            return cls

        IterDataPipe.register_datapipe_as_function(
            self.name, cls, enable_df_api_tracing=self.enable_df_api_tracing
        )
        return cls
50
51
52######################################################
53# Determinism
54######################################################
# Module-wide switch read by ``non_deterministic``: while it is True, building a
# DataPipe flagged as non-deterministic raises ``TypeError``.
_determinism: bool = False


class guaranteed_datapipes_determinism:
    """Context manager that forbids constructing non-deterministic DataPipes.

    The module flag is switched on as soon as the object is constructed (not in
    ``__enter__``). ``__exit__`` restores the value that was current at
    construction time, so nested uses unwind correctly. Exceptions raised in
    the ``with`` body are not suppressed.
    """

    # Flag value captured at construction, restored on exit.
    prev: bool

    def __init__(self) -> None:
        global _determinism
        self.prev, _determinism = _determinism, True

    def __enter__(self) -> None:
        return None

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        global _determinism
        _determinism = self.prev
72
73
class non_deterministic:
    """Mark an ``IterDataPipe`` class as (possibly) non-deterministic.

    Two usage forms are supported:

    * ``@non_deterministic`` applied directly to an ``IterDataPipe`` subclass.
      Calling the resulting object constructs the DataPipe. It raises
      ``TypeError`` instead while the module-level ``_determinism`` flag is set
      (see ``guaranteed_datapipes_determinism``).
    * ``@non_deterministic(fn)``. When the class is later decorated, the bound
      ``deterministic_wrapper_fn`` is returned. On every construction it calls
      ``fn`` with the constructor arguments; ``fn`` must return a ``bool``, and
      ``True`` marks that particular instance as non-deterministic.
    """

    # The decorated DataPipe class. Set in ``__init__`` (first form) or in
    # ``__call__`` (second form).
    cls: Optional[Type[IterDataPipe]] = None
    # TODO: Lambda for pickling
    deterministic_fn: Callable[[], bool]

    def __init__(self, arg: Union[Type[IterDataPipe], Callable[[], bool]]) -> None:
        """Capture either the DataPipe class or the determinism predicate.

        Raises:
            TypeError: if ``arg`` is a class that is not an ``IterDataPipe``,
                or is neither a class nor callable.
        """
        # 1. Decorator doesn't have any argument: ``arg`` is the class itself.
        #    Tested before ``Callable`` because classes are callable too.
        if isinstance(arg, Type):  # type: ignore[arg-type]
            if not issubclass(arg, IterDataPipe):  # type: ignore[arg-type]
                raise TypeError(
                    "Only `IterDataPipe` can be decorated with `non_deterministic`"
                    f", but {arg.__name__} is found"
                )
            self.cls = arg  # type: ignore[assignment]
        # 2. Decorator has an argument of a function
        #    This class should behave differently given different inputs. Use this
        #    function to verify the determinism for each instance.
        #    When the function returns True, the instance is non-deterministic. Otherwise,
        #    the instance is a deterministic DataPipe.
        elif isinstance(arg, Callable):  # type:ignore[arg-type]
            self.deterministic_fn = arg  # type: ignore[assignment, misc]
        else:
            raise TypeError(f"{arg} can not be decorated by non_deterministic")

    def __call__(self, *args, **kwargs):
        """Construct the DataPipe (first form) or decorate the class (second form).

        Raises:
            TypeError: if determinism is guaranteed (first form), or if the
                first positional argument is missing or is not an
                ``IterDataPipe`` subclass (second form).
        """
        # Decorate IterDataPipe: this call constructs the decorated class.
        if self.cls is not None:
            if _determinism:
                raise TypeError(
                    f"{self.cls.__name__} is non-deterministic, but you set 'guaranteed_datapipes_determinism'. "
                    "You can turn off determinism for this DataPipe if that is acceptable "
                    "for your application"
                )
            return self.cls(*args, **kwargs)  # type: ignore[call-arg]

        # Decorate with a functional argument: this call receives the class.
        candidate = args[0] if args else None
        if not (
            isinstance(candidate, type)
            and issubclass(candidate, IterDataPipe)  # type: ignore[arg-type]
        ):
            # ``__name__`` only exists on classes/functions. Falling back to repr
            # keeps this a TypeError instead of an AttributeError (or an
            # IndexError when no positional argument was passed).
            candidate_name = getattr(candidate, "__name__", repr(candidate))
            raise TypeError(
                f"Only `IterDataPipe` can be decorated, but {candidate_name} is found"
            )
        self.cls = candidate
        return self.deterministic_wrapper_fn

    def deterministic_wrapper_fn(self, *args, **kwargs) -> IterDataPipe:
        """Construct ``self.cls`` after consulting ``deterministic_fn``.

        Raises:
            TypeError: if ``deterministic_fn`` does not return a ``bool``, or
                returns ``True`` while the ``_determinism`` flag is set.
        """
        res = self.deterministic_fn(*args, **kwargs)  # type: ignore[call-arg, misc]
        if not isinstance(res, bool):
            raise TypeError(
                "deterministic_fn of `non_deterministic` decorator is required "
                f"to return a boolean value, but {type(res)} is found"
            )
        if _determinism and res:
            raise TypeError(
                f"{self.cls.__name__} is non-deterministic with the inputs, but you set "  # type: ignore[union-attr]
                "'guaranteed_datapipes_determinism'. You can turn off determinism "
                "for this DataPipe if that is acceptable for your application"
            )
        return self.cls(*args, **kwargs)  # type: ignore[call-arg, misc]
135        return self.cls(*args, **kwargs)  # type: ignore[call-arg, misc]
136
137
138######################################################
139# Type validation
140######################################################
# Validate DataPipe-typed arguments against their type hints on every call.
def argument_validation(f):
    """Wrap ``f`` so that DataPipe-typed arguments are checked on each call.

    Only arguments whose resolved type hint is an instance of ``_DataPipeMeta``
    are validated. Each such argument must be an ``IterDataPipe`` whose
    ``type`` is a subtype of the hint's ``type``. Arguments the caller omits
    (left at their defaults) are not inspected, because ``Signature.bind``
    does not fill in defaults.

    Raises:
        TypeError: from the wrapper when validation fails, or from
            ``Signature.bind`` when the call does not match ``f``'s signature.
    """
    sig = inspect.signature(f)
    type_hints = get_type_hints(f)

    @wraps(f)
    def wrapper(*args, **kwargs):
        bound_arguments = sig.bind(*args, **kwargs).arguments
        for arg_name, arg_value in bound_arguments.items():
            if arg_name not in type_hints:
                continue
            expected = type_hints[arg_name]
            if not isinstance(expected, _DataPipeMeta):
                continue
            # NOTE(review): any DataPipe-typed hint demands an IterDataPipe
            # value here; confirm this is intended if other DataPipe kinds can
            # appear as hints.
            if not isinstance(arg_value, IterDataPipe):
                raise TypeError(
                    f"Expected argument '{arg_name}' as a IterDataPipe, but found {type(arg_value)}"
                )
            if not arg_value.type.issubtype(expected.type):
                raise TypeError(
                    f"Expected type of argument '{arg_name}' as a subtype of "
                    f"hint {expected.type}, but found {arg_value.type}"
                )

        return f(*args, **kwargs)

    return wrapper
167
168
# Whether ``runtime_validation``-decorated ``__iter__`` methods type-check each
# yielded item. Enabled by default.
_runtime_validation_enabled: bool = True


class runtime_validation_disabled:
    """Context manager that switches off ``runtime_validation`` checks.

    The flag is cleared as soon as the object is constructed (not in
    ``__enter__``). ``__exit__`` restores the value that was current at
    construction time, so nested uses unwind correctly. Exceptions raised in
    the ``with`` body are not suppressed.
    """

    # Flag value captured at construction, restored on exit.
    prev: bool

    def __init__(self) -> None:
        global _runtime_validation_enabled
        self.prev, _runtime_validation_enabled = _runtime_validation_enabled, False

    def __enter__(self) -> None:
        return None

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        global _runtime_validation_enabled
        _runtime_validation_enabled = self.prev
187
188
# Runtime checking: validate that each item yielded by ``__iter__`` is an
# instance of the DataPipe's declared ``type``.
def runtime_validation(f):
    """Decorate a DataPipe ``__iter__`` so each yielded item is type-checked.

    The module flag ``_runtime_validation_enabled`` is read when iteration
    starts (the wrapper is a generator). When the flag is off, items are passed
    through unchecked. When it is on, every item must satisfy
    ``self.type.issubtype_of_instance``.

    Raises:
        TypeError: at decoration time if ``f`` is not named ``__iter__``.
        RuntimeError: during iteration, for the first item that fails the check.
    """
    # TODO: Can be extended to validate '__getitem__' and nonblocking
    if f.__name__ != "__iter__":
        raise TypeError(
            f"Can not decorate function {f.__name__} with 'runtime_validation'"
        )

    @wraps(f)
    def wrapper(self):
        if not _runtime_validation_enabled:
            yield from f(self)
            return

        for item in f(self):
            if not self.type.issubtype_of_instance(item):
                raise RuntimeError(
                    f"Expected an instance as subtype of {self.type}, but found {item}({type(item)})"
                )
            yield item

    return wrapper
213    return wrapper
214