xref: /aosp_15_r20/external/pytorch/torch/utils/data/datapipes/_typing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Taking reference from official Python typing
3# https://github.com/python/cpython/blob/master/Lib/typing.py
4
5import collections
6import functools
7import numbers
8import sys
9
10# Please check [Note: TypeMeta and TypeAlias]
11# In case of metaclass conflict due to ABCMeta or _ProtocolMeta
12# For Python 3.9, only Protocol in typing uses metaclass
13from abc import ABCMeta
14
15# TODO: Use TypeAlias when Python 3.6 is deprecated
16from typing import (  # type: ignore[attr-defined]
17    _eval_type,
18    _GenericAlias,
19    _tp_cache,
20    _type_check,
21    _type_repr,
22    Any,
23    Dict,
24    ForwardRef,
25    Generic,
26    get_type_hints,
27    Iterator,
28    List,
29    Set,
30    Tuple,
31    TypeVar,
32    Union,
33)
34
35from torch.utils.data.datapipes._hook_iterator import _SnapshotState, hook_iterator
36
37
class GenericMeta(ABCMeta):  # type: ignore[no-redef]
    # ABCMeta-based placeholder metaclass, kept so `_DataPipeMeta` has a single
    # metaclass ancestor and avoids the ABCMeta/_ProtocolMeta conflicts
    # mentioned in the note above the imports.
    pass
40
41
class Integer(numbers.Integral):
    # Abstract stand-in for the builtin `int` in TYPE2ABC. Used only for
    # subtype checks; never instantiated, so the inherited abstract methods
    # are intentionally left unimplemented.
    pass
44
45
class Boolean(numbers.Integral):
    # Abstract stand-in for the builtin `bool` in TYPE2ABC (bool is an int
    # subclass, hence `numbers.Integral`). Distinct from `Integer` so that
    # bool and int remain distinguishable during subtype checks.
    pass
48
49
# Python 'type' object is not subscriptable:
#   Tuple[int, List, dict] -> valid
#   tuple[int, list, dict] -> invalid
# Map each concrete Python 'type' to an abstract base class so that subtype
# checks can be performed uniformly on the `numbers`/`typing` ABCs.
TYPE2ABC = {
    bool: Boolean,
    int: Integer,
    float: numbers.Real,
    complex: numbers.Complex,
    dict: Dict,
    list: List,
    set: Set,
    tuple: Tuple,
    None: type(None),  # `None` used as an annotation means NoneType
}
65
66
def issubtype(left, right, recursive=True):
    r"""
    Check whether the left-side type is a subtype of the right-side type.

    Composite types (`Union`, `TypeVar` with bound/constraints) on either
    side are expanded into their member types; the check succeeds when every
    expanded left-side type is a subtype of at least one right-side type.
    """
    # Normalize builtin types to their abstract counterparts first.
    left, right = TYPE2ABC.get(left, left), TYPE2ABC.get(right, right)

    if right is Any or left == right:
        return True

    # A bare `Generic` on the right-hand side accepts anything.
    if isinstance(right, _GenericAlias) and getattr(right, "__origin__", None) is Generic:
        return True

    # Only None itself matches NoneType, and equality was ruled out above.
    if right == type(None):
        return False

    # Expand the right-side type into its constraint set.
    constraints = _decompose_type(right)
    if not constraints or Any in constraints:
        return True

    if left is Any:
        return False

    # Expand the left-side type into its variants.
    variants = _decompose_type(left)
    if not variants:
        # e.g. an unconstrained, unbounded TypeVar decomposes to []; `all()`
        # over an empty list would wrongly return True, so bail out here.
        return False

    return all(
        _issubtype_with_constraints(v, constraints, recursive) for v in variants
    )
108
109
def _decompose_type(t, to_list=True):
    """Expand a `TypeVar`/`Union` into its member types, mapped through TYPE2ABC.

    For any other type: return ``[t]`` when ``to_list`` is True, otherwise
    ``None`` to signal "not a composite type".
    """
    if isinstance(t, TypeVar):
        # A bound TypeVar contributes its bound; otherwise its constraints
        # (the empty tuple for an unconstrained TypeVar such as T_co).
        members = [t.__bound__] if t.__bound__ is not None else list(t.__constraints__)
    elif getattr(t, "__origin__", None) == Union:
        members = t.__args__
    elif to_list:
        members = [t]
    else:
        return None
    # Ignored: Generator has incompatible item type "object"; expected "Type[Any]"
    return [TYPE2ABC.get(m, m) for m in members]  # type: ignore[misc]
126
127
def _issubtype_with_constraints(variant, constraints, recursive=True):
    r"""
    Check if the variant is a subtype of either one from constraints.

    For composite types like `Union` and `TypeVar` with bounds, they
    would be expanded for testing.
    """
    # Exact membership in the constraint set is the cheapest success path.
    if variant in constraints:
        return True

    # [Note: Subtype for Union and TypeVar]
    # Python typing is able to flatten Union[Union[...]] or Union[TypeVar].
    # But it couldn't flatten the following scenarios:
    #   - Union[int, TypeVar[Union[...]]]
    #   - TypeVar[TypeVar[...]]
    # So, variant and each constraint may be a TypeVar or a Union.
    # In these cases, all of the inner types from the variant are required to be
    # extracted and verified as a subtype of any constraint. And, all of the
    # inner types from any constraint being a TypeVar or a Union are
    # also required to be extracted and verified if the variant belongs to
    # any of them.

    # Variant: `None` from _decompose_type means "not a TypeVar/Union".
    vs = _decompose_type(variant, to_list=False)

    # Variant is TypeVar or Union: every inner type must satisfy the constraints.
    if vs is not None:
        return all(_issubtype_with_constraints(v, constraints, recursive) for v in vs)

    # Variant is not TypeVar or Union: split it into its origin and type args
    # (e.g. List[int] -> origin=list, args=(int,)).
    if hasattr(variant, "__origin__") and variant.__origin__ is not None:
        v_origin = variant.__origin__
        # In Python-3.9 typing library untyped generics do not have args
        v_args = getattr(variant, "__args__", None)
    else:
        v_origin = variant
        v_args = None

    # Constraints: succeed as soon as the variant matches any one of them.
    for constraint in constraints:
        cs = _decompose_type(constraint, to_list=False)

        # Constraint is TypeVar or Union: recurse with its inner types.
        if cs is not None:
            if _issubtype_with_constraints(variant, cs, recursive):
                return True
        # Constraint is not TypeVar or Union
        else:
            # __origin__ can be None for plain list, tuple, ... in Python 3.6
            if hasattr(constraint, "__origin__") and constraint.__origin__ is not None:
                c_origin = constraint.__origin__
                if v_origin == c_origin:
                    if not recursive:
                        # Shallow mode: matching origins is sufficient.
                        return True
                    # In Python-3.9 typing library untyped generics do not have args
                    c_args = getattr(constraint, "__args__", None)
                    if c_args is None or len(c_args) == 0:
                        # Untyped constraint (e.g. plain `List`) accepts any args.
                        return True
                    # Same arity and every type argument is pairwise a subtype.
                    if (
                        v_args is not None
                        and len(v_args) == len(c_args)
                        and all(
                            issubtype(v_arg, c_arg)
                            for v_arg, c_arg in zip(v_args, c_args)
                        )
                    ):
                        return True
            # Tuple[int] -> Tuple
            else:
                if v_origin == constraint:
                    return True

    return False
201
202
def issubinstance(data, data_type):
    """Check whether ``data`` is an instance of (possibly generic) ``data_type``,
    recursing into the elements of tuples, lists, sets, and dicts.
    """
    # Shallow check on the outer type first (non-recursive on purpose:
    # element types are validated against the actual values below).
    if not issubtype(type(data), data_type, recursive=False):
        return False

    # In Python-3.9 typing library, untyped generics have no `__args__`;
    # with no declared element types there is nothing further to verify.
    dt_args = getattr(data_type, "__args__", None)
    if not dt_args:
        return True

    if isinstance(data, tuple):
        # Tuples are checked position-by-position; arity must match.
        if len(data) != len(dt_args):
            return False
        return all(issubinstance(d, t) for d, t in zip(data, dt_args))
    if isinstance(data, (list, set)):
        elem_type = dt_args[0]
        return all(issubinstance(d, elem_type) for d in data)
    if isinstance(data, dict):
        key_type, val_type = dt_args
        return all(
            issubinstance(k, key_type) and issubinstance(v, val_type)
            for k, v in data.items()
        )

    # Non-container data: the shallow check above is all that applies.
    return True
229
230
# [Note: TypeMeta and TypeAlias]
# In order to keep compatibility with Python 3.6, use a metaclass for the typing.
# TODO: When PyTorch drops support for Python 3.6, this can be converted to
# the alias system, using `__class_getitem__` for DataPipe. The typing
# system will then gain the benefits of better performance and resolved
# metaclass conflicts, as elaborated in https://www.python.org/dev/peps/pep-0560/
237
238
class _DataPipeType:
    r"""Save type annotation in `param`."""

    def __init__(self, param):
        self.param = param

    def __repr__(self):
        return _type_repr(self.param)

    def __eq__(self, other):
        if isinstance(other, _DataPipeType):
            return self.param == other.param
        return NotImplemented

    def __hash__(self):
        return hash(self.param)

    def issubtype(self, other):
        """Check whether this annotation is a subtype of ``other``.

        ``other`` may be another ``_DataPipeType`` or a plain ``type``;
        anything else raises ``TypeError``.
        """
        # Bugfix: check `isinstance(other, _DataPipeType)` BEFORE touching
        # `other.param` — the original accessed `other.param` first, so the
        # documented plain-`type` branch below raised AttributeError instead
        # of ever being reached.
        if isinstance(other, _DataPipeType):
            # A bare `Generic[...]` annotation on the right accepts any type.
            if isinstance(other.param, _GenericAlias):
                if getattr(other.param, "__origin__", None) is Generic:
                    return True
            return issubtype(self.param, other.param)
        if isinstance(other, type):
            return issubtype(self.param, other)
        raise TypeError(f"Expected '_DataPipeType' or 'type', but found {type(other)}")

    def issubtype_of_instance(self, other):
        """Check whether the value ``other`` is an instance of this annotation."""
        return issubinstance(other, self.param)
268
269
# Default type for DataPipe without annotation
_T_co = TypeVar("_T_co", covariant=True)
# `Generic[_T_co]` acts as the unconstrained default: `_DataPipeType.issubtype`
# treats a bare Generic annotation as accepting any type.
_DEFAULT_TYPE = _DataPipeType(Generic[_T_co])
273
274
class _DataPipeMeta(GenericMeta):
    r"""
    Metaclass for `DataPipe`.

    Add `type` attribute and `__init_subclass__` based on the type, and validate the return hint of `__iter__`.

    Note that there is subclass `_IterDataPipeMeta` specifically for `IterDataPipe`.
    """

    # Type annotation attached to the DataPipe class (e.g. `DataPipe[int]`).
    type: _DataPipeType

    def __new__(cls, name, bases, namespace, **kwargs):
        # Immediately delegate to `super().__new__`; everything after this
        # `return` is intentionally dead code (see TODO below).
        return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]

        # TODO: the statements below are not reachable by design as there is a bug and typing is low priority for now.
        cls.__origin__ = None
        if "type" in namespace:
            return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]

        namespace["__type_class__"] = False
        #  For plain derived class without annotation
        for base in bases:
            if isinstance(base, _DataPipeMeta):
                return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]

        namespace.update(
            {"type": _DEFAULT_TYPE, "__init_subclass__": _dp_init_subclass}
        )
        return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]

    def __init__(self, name, bases, namespace, **kwargs):
        super().__init__(name, bases, namespace, **kwargs)  # type: ignore[call-overload]

    # TODO: Fix isinstance bug
    # NOTE(review): named `_getitem_` rather than the real `__getitem__` dunder,
    # apparently on purpose until the isinstance bug above is fixed — confirm
    # before renaming.
    @_tp_cache
    def _getitem_(self, params):
        if params is None:
            raise TypeError(f"{self.__name__}[t]: t can not be None")
        if isinstance(params, str):
            # A bare string is treated as a forward reference to a type.
            params = ForwardRef(params)
        if not isinstance(params, tuple):
            params = (params,)

        msg = f"{self.__name__}[t]: t must be a type"
        params = tuple(_type_check(p, msg) for p in params)

        if isinstance(self.type.param, _GenericAlias):
            orig = getattr(self.type.param, "__origin__", None)
            if isinstance(orig, type) and orig is not Generic:
                # Re-subscript the existing generic annotation with `params`
                # and build a new subscripted class around the result.
                p = self.type.param[params]  # type: ignore[index]
                t = _DataPipeType(p)
                # Strip the old "[<type>]" suffix (annotation length plus the
                # two brackets) from the class name before appending the new one.
                l = len(str(self.type)) + 2
                name = self.__name__[:-l]
                name = name + "[" + str(t) + "]"
                bases = (self,) + self.__bases__
                return self.__class__(
                    name,
                    bases,
                    {
                        "__init_subclass__": _dp_init_subclass,
                        "type": t,
                        "__type_class__": True,
                    },
                )

        if len(params) > 1:
            raise TypeError(
                f"Too many parameters for {self} actual {len(params)}, expected 1"
            )

        t = _DataPipeType(params[0])

        # Subscripting may only narrow the declared type, never widen it.
        if not t.issubtype(self.type):
            raise TypeError(
                f"Can not subclass a DataPipe[{t}] from DataPipe[{self.type}]"
            )

        # Types are equal, fast path for inheritance
        if self.type == t:
            return self

        name = self.__name__ + "[" + str(t) + "]"
        bases = (self,) + self.__bases__

        return self.__class__(
            name,
            bases,
            {"__init_subclass__": _dp_init_subclass, "__type_class__": True, "type": t},
        )

    # TODO: Fix isinstance bug
    # NOTE(review): `_eq_` is not the real `__eq__` dunder (see TODO).
    def _eq_(self, other):
        if not isinstance(other, _DataPipeMeta):
            return NotImplemented
        # Un-subscripted classes (`__origin__` is None) compare by identity.
        if self.__origin__ is None or other.__origin__ is None:  # type: ignore[has-type]
            return self is other
        return (
            self.__origin__ == other.__origin__  # type: ignore[has-type]
            and self.type == other.type
        )

    # TODO: Fix isinstance bug
    # NOTE(review): `_hash_` is not the real `__hash__` dunder (see TODO).
    def _hash_(self):
        return hash((self.__name__, self.type))
379
380
class _IterDataPipeMeta(_DataPipeMeta):
    r"""
    Metaclass for `IterDataPipe` and inherits from `_DataPipeMeta`.

    Add various functions for behaviors specific to `IterDataPipe`.
    """

    def __new__(cls, name, bases, namespace, **kwargs):
        # Wrap a user-defined `reset()` so it only runs in the appropriate
        # snapshot states; the wrapped version replaces it in the namespace.
        if "reset" in namespace:
            reset_func = namespace["reset"]

            @functools.wraps(reset_func)
            def conditional_reset(*args, **kwargs):
                r"""
                Only execute DataPipe's `reset()` method if `_SnapshotState` is `Iterating` or `NotStarted`.

                This allows recently restored DataPipe to preserve its restored state during the initial `__iter__` call.
                """
                datapipe = args[0]  # bound method: args[0] is the instance
                if datapipe._snapshot_state in (
                    _SnapshotState.Iterating,
                    _SnapshotState.NotStarted,
                ):
                    # Reset `NotStarted` is necessary because the `source_datapipe` of a DataPipe might have
                    # already begun iterating.
                    datapipe._number_of_samples_yielded = 0
                    datapipe._fast_forward_iterator = None
                    reset_func(*args, **kwargs)
                # Whether or not reset ran, the DataPipe is now considered iterating.
                datapipe._snapshot_state = _SnapshotState.Iterating

            namespace["reset"] = conditional_reset

        # `hook_iterator` (from _hook_iterator) instruments `__iter__` in the
        # namespace — presumably to track iterator state; see that module.
        if "__iter__" in namespace:
            hook_iterator(namespace)
        return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]
416
417
def _dp_init_subclass(sub_cls, *args, **kwargs):
    """``__init_subclass__`` hook for DataPipe subclasses.

    Attaches ``reinforce_type`` to the subclass and validates that the return
    annotation of ``__iter__`` (if any) is an ``Iterator`` whose element type
    is a subtype of the subclass's declared type.
    """
    # Add function for datapipe instance to reinforce the type
    sub_cls.reinforce_type = reinforce_type

    # TODO:
    # - add global switch for type checking at compile-time

    # Internal type classes (created by subscripting) skip validation.
    if getattr(sub_cls, "__type_class__", False):
        return

    # Resolve a string annotation (ForwardRef) against the defining module.
    if isinstance(sub_cls.type.param, ForwardRef):
        base_globals = sys.modules[sub_cls.__module__].__dict__
        try:
            sub_cls.type.param = _eval_type(sub_cls.type.param, base_globals, locals())
        except TypeError as e:
            raise TypeError(
                f"{sub_cls.type.param.__forward_arg__} is not supported by Python typing"
            ) from e

    # Only validate an `__iter__` defined directly on this subclass.
    if "__iter__" not in sub_cls.__dict__:
        return
    hints = get_type_hints(sub_cls.__dict__["__iter__"])
    if "return" not in hints:
        return

    return_hint = hints["return"]
    # A bare `Iterator` (no element type) is accepted as-is.
    if return_hint == Iterator:
        return
    origin = getattr(return_hint, "__origin__", None)
    if origin != Iterator and origin != collections.abc.Iterator:
        raise TypeError(
            "Expected 'Iterator' as the return annotation for `__iter__` of {}"
            ", but found {}".format(
                sub_cls.__name__, _type_repr(hints["return"])
            )
        )

    # The yielded element type must narrow (or equal) the declared type.
    data_type = return_hint.__args__[0]
    if not issubtype(data_type, sub_cls.type.param):
        raise TypeError(
            f"Expected return type of '__iter__' as a subtype of {sub_cls.type},"
            f" but found {_type_repr(data_type)} for {sub_cls.__name__}"
        )
467
468
def reinforce_type(self, expected_type):
    r"""
    Reinforce the type for this DataPipe instance.

    ``expected_type`` must be a subtype of the instance's original type hint;
    the instance's type is narrowed to it and ``self`` is returned.
    """
    # A bare tuple of types is shorthand for Tuple[...].
    normalized = Tuple[expected_type] if isinstance(expected_type, tuple) else expected_type
    _type_check(normalized, msg="'expected_type' must be a type")

    if not issubtype(normalized, self.type.param):
        raise TypeError(
            f"Expected 'expected_type' as subtype of {self.type}, but found {_type_repr(normalized)}"
        )

    self.type = _DataPipeType(normalized)
    return self
487