1# mypy: allow-untyped-defs 2import fnmatch 3import functools 4import inspect 5import os 6import warnings 7from io import IOBase 8from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union 9 10from torch.utils._import_utils import dill_available 11 12 13__all__ = [ 14 "validate_input_col", 15 "StreamWrapper", 16 "get_file_binaries_from_pathnames", 17 "get_file_pathnames_from_root", 18 "match_masks", 19 "validate_pathname_binary_tuple", 20] 21 22 23# BC for torchdata 24DILL_AVAILABLE = dill_available() 25 26 27def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]]): 28 """ 29 Check that function used in a callable datapipe works with the input column. 30 31 This simply ensures that the number of positional arguments matches the size 32 of the input column. The function must not contain any non-default 33 keyword-only arguments. 34 35 Examples: 36 >>> # xdoctest: +SKIP("Failing on some CI machines") 37 >>> def f(a, b, *, c=1): 38 >>> return a + b + c 39 >>> def f_def(a, b=1, *, c=1): 40 >>> return a + b + c 41 >>> assert validate_input_col(f, [1, 2]) 42 >>> assert validate_input_col(f_def, 1) 43 >>> assert validate_input_col(f_def, [1, 2]) 44 45 Notes: 46 If the function contains variable positional (`inspect.VAR_POSITIONAL`) arguments, 47 for example, f(a, *args), the validator will accept any size of input column 48 greater than or equal to the number of positional arguments. 49 (in this case, 1). 50 51 Args: 52 fn: The function to check. 53 input_col: The input column to check. 54 55 Raises: 56 ValueError: If the function is not compatible with the input column. 57 """ 58 try: 59 sig = inspect.signature(fn) 60 except ( 61 ValueError 62 ): # Signature cannot be inspected, likely it is a built-in fn or written in C 63 return 64 if isinstance(input_col, (list, tuple)): 65 input_col_size = len(input_col) 66 else: 67 input_col_size = 1 68 69 pos = [] 70 var_positional = False 71 non_default_kw_only = [] 72 73 for p in sig.parameters.values(): 74 if p.kind in ( 75 inspect.Parameter.POSITIONAL_ONLY, 76 inspect.Parameter.POSITIONAL_OR_KEYWORD, 77 ): 78 pos.append(p) 79 elif p.kind is inspect.Parameter.VAR_POSITIONAL: 80 var_positional = True 81 elif p.kind is inspect.Parameter.KEYWORD_ONLY: 82 if p.default is p.empty: 83 non_default_kw_only.append(p) 84 else: 85 continue 86 87 if isinstance(fn, functools.partial): 88 fn_name = getattr(fn.func, "__name__", repr(fn.func)) 89 else: 90 fn_name = getattr(fn, "__name__", repr(fn)) 91 92 if len(non_default_kw_only) > 0: 93 raise ValueError( 94 f"The function {fn_name} takes {len(non_default_kw_only)} " 95 f"non-default keyword-only parameters, which is not allowed." 96 ) 97 98 if len(sig.parameters) < input_col_size: 99 if not var_positional: 100 raise ValueError( 101 f"The function {fn_name} takes {len(sig.parameters)} " 102 f"parameters, but {input_col_size} are required." 103 ) 104 else: 105 if len(pos) > input_col_size: 106 if any(p.default is p.empty for p in pos[input_col_size:]): 107 raise ValueError( 108 f"The function {fn_name} takes {len(pos)} " 109 f"positional parameters, but {input_col_size} are required." 110 ) 111 elif len(pos) < input_col_size: 112 if not var_positional: 113 raise ValueError( 114 f"The function {fn_name} takes {len(pos)} " 115 f"positional parameters, but {input_col_size} are required." 116 ) 117 118 119def _is_local_fn(fn): 120 # Functions or Methods 121 if hasattr(fn, "__code__"): 122 return fn.__code__.co_flags & inspect.CO_NESTED 123 # Callable Objects 124 else: 125 if hasattr(fn, "__qualname__"): 126 return "<locals>" in fn.__qualname__ 127 fn_type = type(fn) 128 if hasattr(fn_type, "__qualname__"): 129 return "<locals>" in fn_type.__qualname__ 130 return False 131 132 133def _check_unpickable_fn(fn: Callable): 134 """ 135 Check function is pickable or not. 136 137 If it is a lambda or local function, a UserWarning will be raised. If it's not a callable function, a TypeError will be raised. 138 """ 139 if not callable(fn): 140 raise TypeError(f"A callable function is expected, but {type(fn)} is provided.") 141 142 # Extract function from partial object 143 # Nested partial function is automatically expanded as a single partial object 144 if isinstance(fn, functools.partial): 145 fn = fn.func 146 147 # Local function 148 if _is_local_fn(fn) and not dill_available(): 149 warnings.warn( 150 "Local function is not supported by pickle, please use " 151 "regular python function or functools.partial instead." 152 ) 153 return 154 155 # Lambda function 156 if hasattr(fn, "__name__") and fn.__name__ == "<lambda>" and not dill_available(): 157 warnings.warn( 158 "Lambda function is not supported by pickle, please use " 159 "regular python function or functools.partial instead." 160 ) 161 return 162 163 164def match_masks(name: str, masks: Union[str, List[str]]) -> bool: 165 # empty mask matches any input name 166 if not masks: 167 return True 168 169 if isinstance(masks, str): 170 return fnmatch.fnmatch(name, masks) 171 172 for mask in masks: 173 if fnmatch.fnmatch(name, mask): 174 return True 175 return False 176 177 178def get_file_pathnames_from_root( 179 root: str, 180 masks: Union[str, List[str]], 181 recursive: bool = False, 182 abspath: bool = False, 183 non_deterministic: bool = False, 184) -> Iterable[str]: 185 # print out an error message and raise the error out 186 def onerror(err: OSError): 187 warnings.warn(err.filename + " : " + err.strerror) 188 raise err 189 190 if os.path.isfile(root): 191 path = root 192 if abspath: 193 path = os.path.abspath(path) 194 fname = os.path.basename(path) 195 if match_masks(fname, masks): 196 yield path 197 else: 198 for path, dirs, files in os.walk(root, onerror=onerror): 199 if abspath: 200 path = os.path.abspath(path) 201 if not non_deterministic: 202 files.sort() 203 for f in files: 204 if match_masks(f, masks): 205 yield os.path.join(path, f) 206 if not recursive: 207 break 208 if not non_deterministic: 209 # Note that this is in-place modifying the internal list from `os.walk` 210 # This only works because `os.walk` doesn't shallow copy before turn 211 # https://github.com/python/cpython/blob/f4c03484da59049eb62a9bf7777b963e2267d187/Lib/os.py#L407 212 dirs.sort() 213 214 215def get_file_binaries_from_pathnames( 216 pathnames: Iterable, mode: str, encoding: Optional[str] = None 217): 218 if not isinstance(pathnames, Iterable): 219 pathnames = [ 220 pathnames, 221 ] 222 223 if mode in ("b", "t"): 224 mode = "r" + mode 225 226 for pathname in pathnames: 227 if not isinstance(pathname, str): 228 raise TypeError( 229 f"Expected string type for pathname, but got {type(pathname)}" 230 ) 231 yield pathname, StreamWrapper(open(pathname, mode, encoding=encoding)) 232 233 234def validate_pathname_binary_tuple(data: Tuple[str, IOBase]): 235 if not isinstance(data, tuple): 236 raise TypeError( 237 f"pathname binary data should be tuple type, but it is type {type(data)}" 238 ) 239 if len(data) != 2: 240 raise TypeError( 241 f"pathname binary stream tuple length should be 2, but got {len(data)}" 242 ) 243 if not isinstance(data[0], str): 244 raise TypeError( 245 f"pathname within the tuple should have string type pathname, but it is type {type(data[0])}" 246 ) 247 if not isinstance(data[1], IOBase) and not isinstance(data[1], StreamWrapper): 248 raise TypeError( 249 f"binary stream within the tuple should have IOBase or" 250 f"its subclasses as type, but it is type {type(data[1])}" 251 ) 252 253 254# Deprecated function names and its corresponding DataPipe type and kwargs for the `_deprecation_warning` function 255_iter_deprecated_functional_names: Dict[str, Dict] = {} 256_map_deprecated_functional_names: Dict[str, Dict] = {} 257 258 259def _deprecation_warning( 260 old_class_name: str, 261 *, 262 deprecation_version: str, 263 removal_version: str, 264 old_functional_name: str = "", 265 old_argument_name: str = "", 266 new_class_name: str = "", 267 new_functional_name: str = "", 268 new_argument_name: str = "", 269 deprecate_functional_name_only: bool = False, 270) -> None: 271 if new_functional_name and not old_functional_name: 272 raise ValueError( 273 "Old functional API needs to be specified for the deprecation warning." 274 ) 275 if new_argument_name and not old_argument_name: 276 raise ValueError( 277 "Old argument name needs to be specified for the deprecation warning." 278 ) 279 280 if old_functional_name and old_argument_name: 281 raise ValueError( 282 "Deprecating warning for functional API and argument should be separated." 283 ) 284 285 msg = f"`{old_class_name}()`" 286 if deprecate_functional_name_only and old_functional_name: 287 msg = f"{msg}'s functional API `.{old_functional_name}()` is" 288 elif old_functional_name: 289 msg = f"{msg} and its functional API `.{old_functional_name}()` are" 290 elif old_argument_name: 291 msg = f"The argument `{old_argument_name}` of {msg} is" 292 else: 293 msg = f"{msg} is" 294 msg = ( 295 f"{msg} deprecated since {deprecation_version} and will be removed in {removal_version}." 296 f"\nSee https://github.com/pytorch/data/issues/163 for details." 297 ) 298 299 if new_class_name or new_functional_name: 300 msg = f"{msg}\nPlease use" 301 if new_class_name: 302 msg = f"{msg} `{new_class_name}()`" 303 if new_class_name and new_functional_name: 304 msg = f"{msg} or" 305 if new_functional_name: 306 msg = f"{msg} `.{new_functional_name}()`" 307 msg = f"{msg} instead." 308 309 if new_argument_name: 310 msg = f"{msg}\nPlease use `{old_class_name}({new_argument_name}=)` instead." 311 312 warnings.warn(msg, FutureWarning) 313 314 315class StreamWrapper: 316 """ 317 StreamWrapper is introduced to wrap file handler generated by DataPipe operation like `FileOpener`. 318 319 StreamWrapper would guarantee the wrapped file handler is closed when it's out of scope. 320 """ 321 322 session_streams: Dict[Any, int] = {} 323 debug_unclosed_streams: bool = False 324 325 def __init__(self, file_obj, parent_stream=None, name=None): 326 self.file_obj = file_obj 327 self.child_counter = 0 328 self.parent_stream = parent_stream 329 self.close_on_last_child = False 330 self.name = name 331 self.closed = False 332 if parent_stream is not None: 333 if not isinstance(parent_stream, StreamWrapper): 334 raise RuntimeError( 335 f"Parent stream should be StreamWrapper, {type(parent_stream)} was given" 336 ) 337 parent_stream.child_counter += 1 338 self.parent_stream = parent_stream 339 if StreamWrapper.debug_unclosed_streams: 340 StreamWrapper.session_streams[self] = 1 341 342 @classmethod 343 def close_streams(cls, v, depth=0): 344 """Traverse structure and attempts to close all found StreamWrappers on best effort basis.""" 345 if depth > 10: 346 return 347 if isinstance(v, StreamWrapper): 348 v.close() 349 else: 350 # Traverse only simple structures 351 if isinstance(v, dict): 352 for vv in v.values(): 353 cls.close_streams(vv, depth=depth + 1) 354 elif isinstance(v, (list, tuple)): 355 for vv in v: 356 cls.close_streams(vv, depth=depth + 1) 357 358 def __getattr__(self, name): 359 file_obj = self.__dict__["file_obj"] 360 return getattr(file_obj, name) 361 362 def close(self, *args, **kwargs): 363 if self.closed: 364 return 365 if StreamWrapper.debug_unclosed_streams: 366 del StreamWrapper.session_streams[self] 367 if hasattr(self, "parent_stream") and self.parent_stream is not None: 368 self.parent_stream.child_counter -= 1 369 if ( 370 not self.parent_stream.child_counter 371 and self.parent_stream.close_on_last_child 372 ): 373 self.parent_stream.close() 374 try: 375 self.file_obj.close(*args, **kwargs) 376 except AttributeError: 377 pass 378 self.closed = True 379 380 def autoclose(self): 381 """Automatically close stream when all child streams are closed or if there are none.""" 382 self.close_on_last_child = True 383 if self.child_counter == 0: 384 self.close() 385 386 def __dir__(self): 387 attrs = list(self.__dict__.keys()) + list(StreamWrapper.__dict__.keys()) 388 attrs += dir(self.file_obj) 389 return list(set(attrs)) 390 391 def __del__(self): 392 if not self.closed: 393 self.close() 394 395 def __iter__(self): 396 yield from self.file_obj 397 398 def __next__(self): 399 return next(self.file_obj) 400 401 def __repr__(self): 402 if self.name is None: 403 return f"StreamWrapper<{self.file_obj!r}>" 404 else: 405 return f"StreamWrapper<{self.name},{self.file_obj!r}>" 406 407 def __getstate__(self): 408 return self.file_obj 409 410 def __setstate__(self, obj): 411 self.file_obj = obj 412