# mypy: allow-untyped-defs
import inspect
from functools import wraps
from typing import Any, Callable, get_type_hints, Optional, Type, Union

from torch.utils.data.datapipes._typing import _DataPipeMeta
from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe


######################################################
# Functional API
######################################################
class functional_datapipe:
    """
    Class decorator that registers a DataPipe under a functional name.

    The decorated object is passed to ``register_datapipe_as_function`` of
    ``IterDataPipe`` or ``MapDataPipe`` and is then returned unchanged.
    """

    name: str

    def __init__(self, name: str, enable_df_api_tracing=False) -> None:
        """
        Define a functional datapipe.

        Args:
            name - the function name the decorated DataPipe is registered under.
            enable_df_api_tracing - if set, any returned DataPipe would accept
            DataFrames API in tracing mode.
        """
        self.name = name
        self.enable_df_api_tracing = enable_df_api_tracing

    def __call__(self, cls):
        """
        Register ``cls`` and return it unchanged.

        Args:
            cls - an ``IterDataPipe`` subclass, a ``MapDataPipe`` subclass, or
            the result of applying ``non_deterministic`` to an ``IterDataPipe``.
            Any other class is returned without being registered.

        Raises:
            TypeError - if ``cls`` is not a class and was not produced by
            ``non_deterministic``.
        """
        if not isinstance(cls, type):
            # with non_deterministic decorator
            # `issubclass` raises TypeError for non-class objects, so this case
            # has to be handled before any `issubclass` call. The decorated
            # object is either a `non_deterministic` instance or its bound
            # `deterministic_wrapper_fn` method.
            from_non_deterministic = isinstance(cls, non_deterministic) or isinstance(
                getattr(cls, "__self__", None), non_deterministic
            )
            if not from_non_deterministic:
                raise TypeError(
                    "`functional_datapipe` can only decorate IterDataPipe"
                )
            IterDataPipe.register_datapipe_as_function(
                self.name, cls, enable_df_api_tracing=self.enable_df_api_tracing
            )
        elif issubclass(cls, IterDataPipe):
            if not isinstance(cls, _DataPipeMeta):
                raise TypeError(
                    "`functional_datapipe` can only decorate IterDataPipe"
                )
            IterDataPipe.register_datapipe_as_function(
                self.name, cls, enable_df_api_tracing=self.enable_df_api_tracing
            )
        elif issubclass(cls, MapDataPipe):
            MapDataPipe.register_datapipe_as_function(self.name, cls)

        return cls


######################################################
# Determinism
######################################################
# Module-wide switch read by `non_deterministic` when a DataPipe is constructed.
_determinism: bool = False


class guaranteed_datapipes_determinism:
    """
    Context manager that turns the module-level ``_determinism`` flag on.

    NOTE: the flag is switched on when the object is constructed (in
    ``__init__``), not in ``__enter__``; ``__exit__`` restores the value that
    was current at construction time.
    """

    # Value of `_determinism` captured at construction, restored on exit.
    prev: bool

    def __init__(self) -> None:
        global _determinism
        self.prev = _determinism
        _determinism = True

    def __enter__(self) -> None:
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        global _determinism
        _determinism = self.prev


class non_deterministic:
    """
    Decorator marking an ``IterDataPipe`` class as non-deterministic.

    Two forms are supported:

    1. ``@non_deterministic`` directly on a class: every construction is
       treated as non-deterministic.
    2. ``@non_deterministic(fn)``: ``fn`` is called with the constructor
       arguments and must return a ``bool``; ``True`` means this particular
       instance is non-deterministic.

    While ``_determinism`` is set, constructing a non-deterministic DataPipe
    raises ``TypeError``.
    """

    cls: Optional[Type[IterDataPipe]] = None
    # TODO: Lambda for picking
    deterministic_fn: Callable[[], bool]

    def __init__(self, arg: Union[Type[IterDataPipe], Callable[[], bool]]) -> None:
        """
        Args:
            arg - either the ``IterDataPipe`` class being decorated, or a
            callable deciding per instance whether it is non-deterministic.

        Raises:
            TypeError - if ``arg`` is a class that is not an ``IterDataPipe``
            subclass, or is neither a class nor a callable.
        """
        # 1. Decorator doesn't have any argument
        if isinstance(arg, Type):  # type: ignore[arg-type]
            if not issubclass(arg, IterDataPipe):  # type: ignore[arg-type]
                raise TypeError(
                    "Only `IterDataPipe` can be decorated with `non_deterministic`"
                    f", but {arg.__name__} is found"
                )
            self.cls = arg  # type: ignore[assignment]
        # 2. Decorator has an argument of a function
        #    This class should behave differently given different inputs. Use this
        #    function to verify the determinism for each instance.
        #    When the function returns True, the instance is non-deterministic. Otherwise,
        #    the instance is a deterministic DataPipe.
        elif isinstance(arg, Callable):  # type:ignore[arg-type]
            self.deterministic_fn = arg  # type: ignore[assignment, misc]
        else:
            raise TypeError(f"{arg} can not be decorated by non_deterministic")

    def __call__(self, *args, **kwargs):
        """
        Construct the wrapped DataPipe, or (function form) bind the class.

        When a class is already bound, the arguments are forwarded to its
        constructor. Otherwise the first positional argument must be the
        ``IterDataPipe`` class to decorate, and ``deterministic_wrapper_fn``
        is returned as its replacement constructor.

        Raises:
            TypeError - if determinism is guaranteed and a class is bound, or
            if the first positional argument is missing or is not an
            ``IterDataPipe`` subclass.
        """
        # Decorate IterDataPipe
        if self.cls is not None:
            if _determinism:
                raise TypeError(
                    f"{self.cls.__name__} is non-deterministic, but you set 'guaranteed_datapipes_determinism'. "
                    "You can turn off determinism for this DataPipe if that is acceptable "
                    "for your application"
                )
            return self.cls(*args, **kwargs)  # type: ignore[call-arg]

        # Decorate with a functional argument
        # Guard against an empty call and against objects without `__name__`,
        # so that misuse always surfaces as the intended TypeError.
        candidate = args[0] if args else None
        if not (
            isinstance(candidate, type)
            and issubclass(candidate, IterDataPipe)  # type: ignore[arg-type]
        ):
            found = getattr(candidate, "__name__", repr(candidate))
            raise TypeError(
                f"Only `IterDataPipe` can be decorated, but {found} is found"
            )
        self.cls = candidate
        return self.deterministic_wrapper_fn

    def deterministic_wrapper_fn(self, *args, **kwargs) -> IterDataPipe:
        """
        Construct the bound DataPipe after consulting ``deterministic_fn``.

        Raises:
            TypeError - if ``deterministic_fn`` does not return a ``bool``, or
            if it returns ``True`` while determinism is guaranteed.
        """
        res = self.deterministic_fn(*args, **kwargs)  # type: ignore[call-arg, misc]
        if not isinstance(res, bool):
            raise TypeError(
                "deterministic_fn of `non_deterministic` decorator is required "
                f"to return a boolean value, but {type(res)} is found"
            )
        if _determinism and res:
            raise TypeError(
                f"{self.cls.__name__} is non-deterministic with the inputs, but you set "  # type: ignore[union-attr]
                "'guaranteed_datapipes_determinism'. You can turn off determinism "
                "for this DataPipe if that is acceptable for your application"
            )
        return self.cls(*args, **kwargs)  # type: ignore[call-arg, misc]


######################################################
# Type validation
######################################################
# Validate each argument of DataPipe with hint as a subtype of the hint.
def argument_validation(f):
    """
    Wrap ``f`` so DataPipe-annotated arguments are checked on every call.

    The signature and the resolved type hints of ``f`` are computed once, at
    decoration time. On each call, every bound argument whose hint is a
    ``_DataPipeMeta`` instance must be an ``IterDataPipe`` whose ``type`` is a
    subtype of the hint's ``type``; otherwise ``TypeError`` is raised.
    Arguments with other hints, or without hints, are not inspected.
    """
    sig = inspect.signature(f)
    annotations = get_type_hints(f)

    @wraps(f)
    def wrapper(*args, **kwargs):
        supplied = sig.bind(*args, **kwargs).arguments
        for name, value in supplied.items():
            expected = annotations.get(name)
            # Only hints that are DataPipe classes take part in validation.
            if not isinstance(expected, _DataPipeMeta):
                continue
            if not isinstance(value, IterDataPipe):
                raise TypeError(
                    f"Expected argument '{name}' as a IterDataPipe, but found {type(value)}"
                )
            if value.type.issubtype(expected.type):
                continue
            raise TypeError(
                f"Expected type of argument '{name}' as a subtype of "
                f"hint {expected.type}, but found {value.type}"
            )

        return f(*args, **kwargs)

    return wrapper


# Runtime validation is on unless explicitly switched off.
_runtime_validation_enabled: bool = True


class runtime_validation_disabled:
    """
    Context manager that switches ``_runtime_validation_enabled`` off.

    The flag is cleared when the object is constructed, not on ``__enter__``;
    ``__exit__`` puts back the value observed at construction time.
    """

    # Flag value seen at construction; restored by `__exit__`.
    prev: bool

    def __init__(self) -> None:
        global _runtime_validation_enabled
        # Right-hand side is evaluated first, so `prev` gets the old value.
        self.prev, _runtime_validation_enabled = _runtime_validation_enabled, False

    def __enter__(self) -> None:
        return None

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        global _runtime_validation_enabled
        _runtime_validation_enabled = self.prev


# Runtime checking
# Validate output data is subtype of return hint
def runtime_validation(f):
    """
    Decorate an ``__iter__`` method so that each yielded item is validated.

    While ``_runtime_validation_enabled`` is true, every item produced by
    ``f(self)`` is passed to ``self.type.issubtype_of_instance``; a falsy
    result raises ``RuntimeError``. While it is false, items are passed
    through untouched. The flag is read when iteration starts.

    Raises:
        TypeError - at decoration time, if ``f`` is not named ``__iter__``.
    """
    # TODO:
    # Can be extended to validate '__getitem__' and nonblocking
    if f.__name__ != "__iter__":
        raise TypeError(
            f"Can not decorate function {f.__name__} with 'runtime_validation'"
        )

    @wraps(f)
    def wrapper(self):
        if not _runtime_validation_enabled:
            yield from f(self)
            return

        for item in f(self):
            if not self.type.issubtype_of_instance(item):
                raise RuntimeError(
                    f"Expected an instance as subtype of {self.type}, but found {item}({type(item)})"
                )
            yield item

    return wrapper