# mypy: allow-untyped-defs
# Taking reference from official Python typing
# https://github.com/python/cpython/blob/master/Lib/typing.py

import collections
import functools
import numbers
import sys

# Please check [Note: TypeMeta and TypeAlias]
# In case of metaclass conflict due to ABCMeta or _ProtocolMeta
# For Python 3.9, only Protocol in typing uses metaclass
from abc import ABCMeta

# TODO: Use TypeAlias when Python 3.6 is deprecated
from typing import (  # type: ignore[attr-defined]
    _eval_type,
    _GenericAlias,
    _tp_cache,
    _type_check,
    _type_repr,
    Any,
    Dict,
    ForwardRef,
    Generic,
    get_type_hints,
    Iterator,
    List,
    Set,
    Tuple,
    TypeVar,
    Union,
)

from torch.utils.data.datapipes._hook_iterator import _SnapshotState, hook_iterator


class GenericMeta(ABCMeta):  # type: ignore[no-redef]
    pass


class Integer(numbers.Integral):
    pass


class Boolean(numbers.Integral):
    pass


# Python 'type' object is not subscriptable
# Tuple[int, List, dict] -> valid
# tuple[int, list, dict] -> invalid
# Map Python 'type' to abstract base class
TYPE2ABC = {
    bool: Boolean,
    int: Integer,
    float: numbers.Real,
    complex: numbers.Complex,
    dict: Dict,
    list: List,
    set: Set,
    tuple: Tuple,
    None: type(None),
}


def issubtype(left, right, recursive=True):
    r"""
    Check if the left-side type is a subtype of the right-side type.

    Composite types such as `Union` and `TypeVar` with bounds are expanded
    into a list of member types; every expanded left-side type must then be
    a subtype of at least one expanded right-side type.
    """
    left = TYPE2ABC.get(left, left)
    right = TYPE2ABC.get(right, right)

    # `Any` on the right accepts everything; identical types trivially match.
    if right is Any or left == right:
        return True

    # An unparameterized `Generic` on the right accepts any left-side type.
    if isinstance(right, _GenericAlias):
        if getattr(right, "__origin__", None) is Generic:
            return True

    # Other than NoneType itself (handled by the equality check above),
    # nothing is a subtype of NoneType.
    if right == type(None):
        return False

    # Expand the right-side type into its constraint set.
    constraints = _decompose_type(right)

    if not constraints or Any in constraints:
        return True

    # `Any` on the left only matches the cases already handled above.
    if left is Any:
        return False

    # Expand the left-side type into its variants; an empty expansion
    # (e.g. an unconstrained TypeVar) matches nothing.
    variants = _decompose_type(left)
    if not variants:
        return False

    return all(
        _issubtype_with_constraints(v, constraints, recursive) for v in variants
    )


def _decompose_type(t, to_list=True):
    r"""
    Expand a type into the list of types it stands for.

    `TypeVar`s expand to their bound or constraints and `Union`s to their
    arguments. For any other type: return `[t]` when `to_list` is set,
    otherwise `None` to signal that `t` is not a composite type.
    """
    if isinstance(t, TypeVar):
        # A bound TypeVar stands for its bound; a constrained one for its
        # constraints. For unconstrained TypeVars (e.g. T_co),
        # __constraints__ is (), so the expansion is empty.
        ts = [t.__bound__] if t.__bound__ is not None else list(t.__constraints__)
    elif hasattr(t, "__origin__") and t.__origin__ == Union:
        ts = t.__args__
    elif to_list:
        ts = [t]
    else:
        return None
    # Normalize plain builtins to their abstract counterparts.
    # Ignored: Generator has incompatible item type "object"; expected "Type[Any]"
    return [TYPE2ABC.get(_t, _t) for _t in ts]  # type: ignore[misc]


def _issubtype_with_constraints(variant, constraints, recursive=True):
    r"""
    Check if the variant is a subtype of at least one type in constraints.

    Composite types (`Union` and `TypeVar` with bounds) on either side are
    expanded for testing.
    """
    if variant in constraints:
        return True

    # [Note: Subtype for Union and TypeVar]
    # Python typing is able to flatten Union[Union[...]] or Union[TypeVar],
    # but it cannot flatten the following scenarios:
    # - Union[int, TypeVar[Union[...]]]
    # - TypeVar[TypeVar[...]]
    # So the variant and each constraint may still be a TypeVar or a Union;
    # their inner types must be extracted and verified recursively.

    inner = _decompose_type(variant, to_list=False)

    # Variant is itself composite: every inner type must satisfy constraints.
    if inner is not None:
        return all(
            _issubtype_with_constraints(v, constraints, recursive) for v in inner
        )

    # Variant is a concrete (possibly parameterized) type.
    if hasattr(variant, "__origin__") and variant.__origin__ is not None:
        v_origin = variant.__origin__
        # In the Python 3.9 typing library untyped generics do not have args.
        v_args = getattr(variant, "__args__", None)
    else:
        v_origin = variant
        v_args = None

    for constraint in constraints:
        c_inner = _decompose_type(constraint, to_list=False)

        # Composite constraint: match against its inner types.
        if c_inner is not None:
            if _issubtype_with_constraints(variant, c_inner, recursive):
                return True
            continue

        # __origin__ can be None for plain list, tuple, ... in Python 3.6
        if hasattr(constraint, "__origin__") and constraint.__origin__ is not None:
            c_origin = constraint.__origin__
            if v_origin == c_origin:
                if not recursive:
                    return True
                # In the Python 3.9 typing library untyped generics do not have args.
                c_args = getattr(constraint, "__args__", None)
                if c_args is None or len(c_args) == 0:
                    return True
                if (
                    v_args is not None
                    and len(v_args) == len(c_args)
                    and all(
                        issubtype(va, ca) for va, ca in zip(v_args, c_args)
                    )
                ):
                    return True
        # Tuple[int] -> Tuple
        elif v_origin == constraint:
            return True

    return False


def issubinstance(data, data_type):
    r"""
    Check whether `data` is an instance of `data_type`.

    Containers (tuple/list/set/dict) are checked recursively against the
    element types carried by a parameterized `data_type`.
    """
    # Shallow check of the outer type first (non-recursive on purpose;
    # element types are verified against the actual data below).
    if not issubtype(type(data), data_type, recursive=False):
        return False

    # In Python-3.9 typing library __args__ attribute is not defined for untyped generics
    dt_args = getattr(data_type, "__args__", None)
    if isinstance(data, tuple):
        if not dt_args:
            return True
        if len(dt_args) != len(data):
            return False
        return all(issubinstance(d, t) for d, t in zip(data, dt_args))
    if isinstance(data, (list, set)):
        if not dt_args:
            return True
        elem_type = dt_args[0]
        return all(issubinstance(d, elem_type) for d in data)
    if isinstance(data, dict):
        if not dt_args:
            return True
        key_type, value_type = dt_args
        return all(
            issubinstance(k, key_type) and issubinstance(v, value_type)
            for k, v in data.items()
        )

    return True


# [Note: TypeMeta and TypeAlias]
# In order to keep compatibility for Python 3.6, use Meta for the typing.
# TODO: When PyTorch drops the support for Python 3.6, it can be converted
# into the Alias system and using `__class_getitem__` for DataPipe. The
The 235# typing system will gain benefit of performance and resolving metaclass 236# conflicts as elaborated in https://www.python.org/dev/peps/pep-0560/ 237 238 239class _DataPipeType: 240 r"""Save type annotation in `param`.""" 241 242 def __init__(self, param): 243 self.param = param 244 245 def __repr__(self): 246 return _type_repr(self.param) 247 248 def __eq__(self, other): 249 if isinstance(other, _DataPipeType): 250 return self.param == other.param 251 return NotImplemented 252 253 def __hash__(self): 254 return hash(self.param) 255 256 def issubtype(self, other): 257 if isinstance(other.param, _GenericAlias): 258 if getattr(other.param, "__origin__", None) is Generic: 259 return True 260 if isinstance(other, _DataPipeType): 261 return issubtype(self.param, other.param) 262 if isinstance(other, type): 263 return issubtype(self.param, other) 264 raise TypeError(f"Expected '_DataPipeType' or 'type', but found {type(other)}") 265 266 def issubtype_of_instance(self, other): 267 return issubinstance(other, self.param) 268 269 270# Default type for DataPipe without annotation 271_T_co = TypeVar("_T_co", covariant=True) 272_DEFAULT_TYPE = _DataPipeType(Generic[_T_co]) 273 274 275class _DataPipeMeta(GenericMeta): 276 r""" 277 Metaclass for `DataPipe`. 278 279 Add `type` attribute and `__init_subclass__` based on the type, and validate the return hint of `__iter__`. 280 281 Note that there is subclass `_IterDataPipeMeta` specifically for `IterDataPipe`. 282 """ 283 284 type: _DataPipeType 285 286 def __new__(cls, name, bases, namespace, **kwargs): 287 return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] 288 289 # TODO: the statements below are not reachable by design as there is a bug and typing is low priority for now. 
290 cls.__origin__ = None 291 if "type" in namespace: 292 return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] 293 294 namespace["__type_class__"] = False 295 # For plain derived class without annotation 296 for base in bases: 297 if isinstance(base, _DataPipeMeta): 298 return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] 299 300 namespace.update( 301 {"type": _DEFAULT_TYPE, "__init_subclass__": _dp_init_subclass} 302 ) 303 return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] 304 305 def __init__(self, name, bases, namespace, **kwargs): 306 super().__init__(name, bases, namespace, **kwargs) # type: ignore[call-overload] 307 308 # TODO: Fix isinstance bug 309 @_tp_cache 310 def _getitem_(self, params): 311 if params is None: 312 raise TypeError(f"{self.__name__}[t]: t can not be None") 313 if isinstance(params, str): 314 params = ForwardRef(params) 315 if not isinstance(params, tuple): 316 params = (params,) 317 318 msg = f"{self.__name__}[t]: t must be a type" 319 params = tuple(_type_check(p, msg) for p in params) 320 321 if isinstance(self.type.param, _GenericAlias): 322 orig = getattr(self.type.param, "__origin__", None) 323 if isinstance(orig, type) and orig is not Generic: 324 p = self.type.param[params] # type: ignore[index] 325 t = _DataPipeType(p) 326 l = len(str(self.type)) + 2 327 name = self.__name__[:-l] 328 name = name + "[" + str(t) + "]" 329 bases = (self,) + self.__bases__ 330 return self.__class__( 331 name, 332 bases, 333 { 334 "__init_subclass__": _dp_init_subclass, 335 "type": t, 336 "__type_class__": True, 337 }, 338 ) 339 340 if len(params) > 1: 341 raise TypeError( 342 f"Too many parameters for {self} actual {len(params)}, expected 1" 343 ) 344 345 t = _DataPipeType(params[0]) 346 347 if not t.issubtype(self.type): 348 raise TypeError( 349 f"Can not subclass a DataPipe[{t}] from DataPipe[{self.type}]" 350 ) 351 352 # Types are 
equal, fast path for inheritance 353 if self.type == t: 354 return self 355 356 name = self.__name__ + "[" + str(t) + "]" 357 bases = (self,) + self.__bases__ 358 359 return self.__class__( 360 name, 361 bases, 362 {"__init_subclass__": _dp_init_subclass, "__type_class__": True, "type": t}, 363 ) 364 365 # TODO: Fix isinstance bug 366 def _eq_(self, other): 367 if not isinstance(other, _DataPipeMeta): 368 return NotImplemented 369 if self.__origin__ is None or other.__origin__ is None: # type: ignore[has-type] 370 return self is other 371 return ( 372 self.__origin__ == other.__origin__ # type: ignore[has-type] 373 and self.type == other.type 374 ) 375 376 # TODO: Fix isinstance bug 377 def _hash_(self): 378 return hash((self.__name__, self.type)) 379 380 381class _IterDataPipeMeta(_DataPipeMeta): 382 r""" 383 Metaclass for `IterDataPipe` and inherits from `_DataPipeMeta`. 384 385 Add various functions for behaviors specific to `IterDataPipe`. 386 """ 387 388 def __new__(cls, name, bases, namespace, **kwargs): 389 if "reset" in namespace: 390 reset_func = namespace["reset"] 391 392 @functools.wraps(reset_func) 393 def conditional_reset(*args, **kwargs): 394 r""" 395 Only execute DataPipe's `reset()` method if `_SnapshotState` is `Iterating` or `NotStarted`. 396 397 This allows recently restored DataPipe to preserve its restored state during the initial `__iter__` call. 398 """ 399 datapipe = args[0] 400 if datapipe._snapshot_state in ( 401 _SnapshotState.Iterating, 402 _SnapshotState.NotStarted, 403 ): 404 # Reset `NotStarted` is necessary because the `source_datapipe` of a DataPipe might have 405 # already begun iterating. 
406 datapipe._number_of_samples_yielded = 0 407 datapipe._fast_forward_iterator = None 408 reset_func(*args, **kwargs) 409 datapipe._snapshot_state = _SnapshotState.Iterating 410 411 namespace["reset"] = conditional_reset 412 413 if "__iter__" in namespace: 414 hook_iterator(namespace) 415 return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] 416 417 418def _dp_init_subclass(sub_cls, *args, **kwargs): 419 # Add function for datapipe instance to reinforce the type 420 sub_cls.reinforce_type = reinforce_type 421 422 # TODO: 423 # - add global switch for type checking at compile-time 424 425 # Ignore internal type class 426 if getattr(sub_cls, "__type_class__", False): 427 return 428 429 # Check if the string type is valid 430 if isinstance(sub_cls.type.param, ForwardRef): 431 base_globals = sys.modules[sub_cls.__module__].__dict__ 432 try: 433 param = _eval_type(sub_cls.type.param, base_globals, locals()) 434 sub_cls.type.param = param 435 except TypeError as e: 436 raise TypeError( 437 f"{sub_cls.type.param.__forward_arg__} is not supported by Python typing" 438 ) from e 439 440 if "__iter__" in sub_cls.__dict__: 441 iter_fn = sub_cls.__dict__["__iter__"] 442 hints = get_type_hints(iter_fn) 443 if "return" in hints: 444 return_hint = hints["return"] 445 # Plain Return Hint for Python 3.6 446 if return_hint == Iterator: 447 return 448 if not ( 449 hasattr(return_hint, "__origin__") 450 and ( 451 return_hint.__origin__ == Iterator 452 or return_hint.__origin__ == collections.abc.Iterator 453 ) 454 ): 455 raise TypeError( 456 "Expected 'Iterator' as the return annotation for `__iter__` of {}" 457 ", but found {}".format( 458 sub_cls.__name__, _type_repr(hints["return"]) 459 ) 460 ) 461 data_type = return_hint.__args__[0] 462 if not issubtype(data_type, sub_cls.type.param): 463 raise TypeError( 464 f"Expected return type of '__iter__' as a subtype of {sub_cls.type}," 465 f" but found {_type_repr(data_type)} for {sub_cls.__name__}" 
466 ) 467 468 469def reinforce_type(self, expected_type): 470 r""" 471 Reinforce the type for DataPipe instance. 472 473 And the 'expected_type' is required to be a subtype of the original type 474 hint to restrict the type requirement of DataPipe instance. 475 """ 476 if isinstance(expected_type, tuple): 477 expected_type = Tuple[expected_type] 478 _type_check(expected_type, msg="'expected_type' must be a type") 479 480 if not issubtype(expected_type, self.type.param): 481 raise TypeError( 482 f"Expected 'expected_type' as subtype of {self.type}, but found {_type_repr(expected_type)}" 483 ) 484 485 self.type = _DataPipeType(expected_type) 486 return self 487