# mypy: allow-untyped-defs
import fnmatch
import functools
import inspect
import os
import warnings
from io import IOBase
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

from torch.utils._import_utils import dill_available


__all__ = [
    "validate_input_col",
    "StreamWrapper",
    "get_file_binaries_from_pathnames",
    "get_file_pathnames_from_root",
    "match_masks",
    "validate_pathname_binary_tuple",
]


# BC for torchdata
DILL_AVAILABLE = dill_available()


def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]]):
    """
    Check that the function used in a callable DataPipe works with the given input column.

    This simply ensures that the number of positional arguments matches the size
    of the input column. The function must not contain any non-default
    keyword-only arguments.

    Examples:
        >>> # xdoctest: +SKIP("Failing on some CI machines")
        >>> def f(a, b, *, c=1):
        ...     return a + b + c
        >>> def f_def(a, b=1, *, c=1):
        ...     return a + b + c
        >>> validate_input_col(f, [1, 2])
        >>> validate_input_col(f_def, 1)
        >>> validate_input_col(f_def, [1, 2])

    Notes:
        If the function contains variable positional (`inspect.VAR_POSITIONAL`)
        arguments, for example, f(a, *args), the validator will accept any size
        of input column greater than or equal to the number of positional
        arguments (in this case, 1).

    Args:
        fn: The function to check.
        input_col: The input column to check.

    Raises:
        ValueError: If the function is not compatible with the input column.
    """
    try:
        sig = inspect.signature(fn)
    except ValueError:
        # Signature cannot be inspected, likely a built-in fn or one written in C
        return
    if isinstance(input_col, (list, tuple)):
        input_col_size = len(input_col)
    else:
        input_col_size = 1

    pos = []
    var_positional = False
    non_default_kw_only = []

    for p in sig.parameters.values():
        if p.kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        ):
            pos.append(p)
        elif p.kind is inspect.Parameter.VAR_POSITIONAL:
            var_positional = True
        elif p.kind is inspect.Parameter.KEYWORD_ONLY:
            if p.default is p.empty:
                non_default_kw_only.append(p)

    if isinstance(fn, functools.partial):
        fn_name = getattr(fn.func, "__name__", repr(fn.func))
    else:
        fn_name = getattr(fn, "__name__", repr(fn))

    if len(non_default_kw_only) > 0:
        raise ValueError(
            f"The function {fn_name} takes {len(non_default_kw_only)} "
            f"non-default keyword-only parameters, which is not allowed."
        )

    if len(sig.parameters) < input_col_size:
        if not var_positional:
            raise ValueError(
                f"The function {fn_name} takes {len(sig.parameters)} "
                f"parameters, but {input_col_size} are required."
            )
    else:
        if len(pos) > input_col_size:
            if any(p.default is p.empty for p in pos[input_col_size:]):
                raise ValueError(
                    f"The function {fn_name} takes {len(pos)} "
                    f"positional parameters, but {input_col_size} are required."
                )
        elif len(pos) < input_col_size:
            if not var_positional:
                raise ValueError(
                    f"The function {fn_name} takes {len(pos)} "
                    f"positional parameters, but {input_col_size} are required."
                )
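
# A quick failure/acceptance sketch for `validate_input_col` (illustrative
# comments only; `f` and `g` are hypothetical user functions):
#
#   >>> def f(a, b, *, c):  # `c` is a non-default keyword-only parameter
#   ...     return a + b + c
#   >>> validate_input_col(f, [1, 2])  # raises ValueError
#   >>> def g(a, *args):  # VAR_POSITIONAL: any column size >= 1 is accepted
#   ...     return a
#   >>> validate_input_col(g, [1, 2, 3])  # passes silently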


def _is_local_fn(fn):
    # Functions or methods
    if hasattr(fn, "__code__"):
        return fn.__code__.co_flags & inspect.CO_NESTED
    # Callable objects
    else:
        if hasattr(fn, "__qualname__"):
            return "<locals>" in fn.__qualname__
        fn_type = type(fn)
        if hasattr(fn_type, "__qualname__"):
            return "<locals>" in fn_type.__qualname__
    return False
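
# Illustrative behavior (comments only; `outer`/`inner` are hypothetical):
#
#   >>> def outer():
#   ...     def inner():  # compiled with the CO_NESTED flag set
#   ...         pass
#   ...     return inner
#   >>> bool(_is_local_fn(outer()))
#   True
#   >>> bool(_is_local_fn(match_masks))  # module-level function
#   False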


def _check_unpickable_fn(fn: Callable):
    """
    Check whether the function is picklable.

    A UserWarning is raised if the function is a lambda or a local function
    (and ``dill`` is not available); a TypeError is raised if it is not callable.
    """
    if not callable(fn):
        raise TypeError(f"A callable function is expected, but {type(fn)} is provided.")

    # Extract the function from a partial object
    # Nested partial functions are automatically expanded into a single partial object
    if isinstance(fn, functools.partial):
        fn = fn.func

    # Local function
    if _is_local_fn(fn) and not dill_available():
        warnings.warn(
            "Local function is not supported by pickle, please use "
            "a regular Python function or functools.partial instead."
        )
        return

    # Lambda function
    if hasattr(fn, "__name__") and fn.__name__ == "<lambda>" and not dill_available():
        warnings.warn(
            "Lambda function is not supported by pickle, please use "
            "a regular Python function or functools.partial instead."
        )
        return
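
# Usage sketch (comments only; assumes `dill` is not available):
#
#   >>> _check_unpickable_fn(lambda x: x)  # emits a UserWarning
#   >>> _check_unpickable_fn(functools.partial(match_masks, masks="*.txt"))  # OK: the partial is unwrapped
#   >>> _check_unpickable_fn(42)  # raises TypeError: not callable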


def match_masks(name: str, masks: Union[str, List[str]]) -> bool:
    # An empty mask matches any input name
    if not masks:
        return True

    if isinstance(masks, str):
        return fnmatch.fnmatch(name, masks)

    for mask in masks:
        if fnmatch.fnmatch(name, mask):
            return True
    return False
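
# Examples of fnmatch-style mask matching (comments only):
#
#   >>> match_masks("a.tar.gz", "*.tar.gz")
#   True
#   >>> match_masks("a.txt", ["*.csv", "*.txt"])
#   True
#   >>> match_masks("a.txt", [])  # empty mask matches everything
#   True
#   >>> match_masks("a.txt", "*.csv")
#   False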


def get_file_pathnames_from_root(
    root: str,
    masks: Union[str, List[str]],
    recursive: bool = False,
    abspath: bool = False,
    non_deterministic: bool = False,
) -> Iterable[str]:
    # Warn with the error message, then re-raise the error
    def onerror(err: OSError):
        warnings.warn(err.filename + " : " + err.strerror)
        raise err

    if os.path.isfile(root):
        path = root
        if abspath:
            path = os.path.abspath(path)
        fname = os.path.basename(path)
        if match_masks(fname, masks):
            yield path
    else:
        for path, dirs, files in os.walk(root, onerror=onerror):
            if abspath:
                path = os.path.abspath(path)
            if not non_deterministic:
                files.sort()
            for f in files:
                if match_masks(f, masks):
                    yield os.path.join(path, f)
            if not recursive:
                break
            if not non_deterministic:
                # Note that this modifies the internal list from `os.walk` in place.
                # This only works because `os.walk` doesn't shallow-copy the list
                # before the next iteration.
                # https://github.com/python/cpython/blob/f4c03484da59049eb62a9bf7777b963e2267d187/Lib/os.py#L407
                dirs.sort()
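
# Usage sketch (comments only; the directory "./data" is hypothetical):
#
#   >>> # Deterministically yield every CSV file under ./data, including subdirectories
#   >>> for pathname in get_file_pathnames_from_root("./data", "*.csv", recursive=True):
#   ...     print(pathname)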


def get_file_binaries_from_pathnames(
    pathnames: Iterable, mode: str, encoding: Optional[str] = None
):
    if not isinstance(pathnames, Iterable):
        pathnames = [pathnames]

    if mode in ("b", "t"):
        mode = "r" + mode

    for pathname in pathnames:
        if not isinstance(pathname, str):
            raise TypeError(
                f"Expected string type for pathname, but got {type(pathname)}"
            )
        yield pathname, StreamWrapper(open(pathname, mode, encoding=encoding))
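
# Usage sketch (comments only; the file names are hypothetical):
#
#   >>> # "b" is expanded to "rb" above, so each file is opened in binary read mode
#   >>> for pathname, stream in get_file_binaries_from_pathnames(["a.bin", "b.bin"], "b"):
#   ...     data = stream.read()
#   ...     stream.close()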


def validate_pathname_binary_tuple(data: Tuple[str, IOBase]):
    if not isinstance(data, tuple):
        raise TypeError(
            f"pathname binary data should be tuple type, but it is type {type(data)}"
        )
    if len(data) != 2:
        raise TypeError(
            f"pathname binary stream tuple length should be 2, but got {len(data)}"
        )
    if not isinstance(data[0], str):
        raise TypeError(
            f"pathname within the tuple should be a string, but it is type {type(data[0])}"
        )
    if not isinstance(data[1], (IOBase, StreamWrapper)):
        raise TypeError(
            f"binary stream within the tuple should have IOBase or "
            f"its subclasses as type, but it is type {type(data[1])}"
        )
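
# Usage sketch (comments only; "a.bin" is hypothetical):
#
#   >>> pair = ("a.bin", StreamWrapper(open("a.bin", "rb")))
#   >>> validate_pathname_binary_tuple(pair)  # passes silently
#   >>> validate_pathname_binary_tuple(("a.bin",))  # raises TypeError: length should be 2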


# Deprecated function names and their corresponding DataPipe types and kwargs for the `_deprecation_warning` function
_iter_deprecated_functional_names: Dict[str, Dict] = {}
_map_deprecated_functional_names: Dict[str, Dict] = {}


def _deprecation_warning(
    old_class_name: str,
    *,
    deprecation_version: str,
    removal_version: str,
    old_functional_name: str = "",
    old_argument_name: str = "",
    new_class_name: str = "",
    new_functional_name: str = "",
    new_argument_name: str = "",
    deprecate_functional_name_only: bool = False,
) -> None:
    """Warn (``FutureWarning``) that a DataPipe class, its functional API, or one of its arguments is deprecated."""
    if new_functional_name and not old_functional_name:
        raise ValueError(
            "Old functional API needs to be specified for the deprecation warning."
        )
    if new_argument_name and not old_argument_name:
        raise ValueError(
            "Old argument name needs to be specified for the deprecation warning."
        )

    if old_functional_name and old_argument_name:
        raise ValueError(
            "Deprecation warnings for a functional API and an argument should be issued separately."
        )

    msg = f"`{old_class_name}()`"
    if deprecate_functional_name_only and old_functional_name:
        msg = f"{msg}'s functional API `.{old_functional_name}()` is"
    elif old_functional_name:
        msg = f"{msg} and its functional API `.{old_functional_name}()` are"
    elif old_argument_name:
        msg = f"The argument `{old_argument_name}` of {msg} is"
    else:
        msg = f"{msg} is"
    msg = (
        f"{msg} deprecated since {deprecation_version} and will be removed in {removal_version}."
        f"\nSee https://github.com/pytorch/data/issues/163 for details."
    )

    if new_class_name or new_functional_name:
        msg = f"{msg}\nPlease use"
        if new_class_name:
            msg = f"{msg} `{new_class_name}()`"
        if new_class_name and new_functional_name:
            msg = f"{msg} or"
        if new_functional_name:
            msg = f"{msg} `.{new_functional_name}()`"
        msg = f"{msg} instead."

    if new_argument_name:
        msg = f"{msg}\nPlease use `{old_class_name}({new_argument_name}=)` instead."

    warnings.warn(msg, FutureWarning)
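
# Message sketch (comments only; `MyDataPipe` and `.my_fn()` are hypothetical names):
#
#   >>> _deprecation_warning(
#   ...     "MyDataPipe",
#   ...     deprecation_version="1.12",
#   ...     removal_version="1.13",
#   ...     old_functional_name="my_fn",
#   ... )
#
# warns with roughly:
#
#   FutureWarning: `MyDataPipe()` and its functional API `.my_fn()` are deprecated
#   since 1.12 and will be removed in 1.13.
#   See https://github.com/pytorch/data/issues/163 for details.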


class StreamWrapper:
    """
    StreamWrapper is introduced to wrap the file handles generated by DataPipe operations like `FileOpener`.

    StreamWrapper guarantees that the wrapped file handle is closed when it goes out of scope.
    """

    session_streams: Dict[Any, int] = {}
    debug_unclosed_streams: bool = False

    def __init__(self, file_obj, parent_stream=None, name=None):
        self.file_obj = file_obj
        self.child_counter = 0
        self.parent_stream = parent_stream
        self.close_on_last_child = False
        self.name = name
        self.closed = False
        if parent_stream is not None:
            if not isinstance(parent_stream, StreamWrapper):
                raise RuntimeError(
                    f"Parent stream should be StreamWrapper, but {type(parent_stream)} was given"
                )
            parent_stream.child_counter += 1
        if StreamWrapper.debug_unclosed_streams:
            StreamWrapper.session_streams[self] = 1

    @classmethod
    def close_streams(cls, v, depth=0):
        """Traverse the structure and attempt to close all found StreamWrappers on a best-effort basis."""
        if depth > 10:
            return
        if isinstance(v, StreamWrapper):
            v.close()
        else:
            # Traverse only simple structures
            if isinstance(v, dict):
                for vv in v.values():
                    cls.close_streams(vv, depth=depth + 1)
            elif isinstance(v, (list, tuple)):
                for vv in v:
                    cls.close_streams(vv, depth=depth + 1)

    def __getattr__(self, name):
        file_obj = self.__dict__["file_obj"]
        return getattr(file_obj, name)

    def close(self, *args, **kwargs):
        if self.closed:
            return
        if StreamWrapper.debug_unclosed_streams:
            del StreamWrapper.session_streams[self]
        if hasattr(self, "parent_stream") and self.parent_stream is not None:
            self.parent_stream.child_counter -= 1
            if (
                not self.parent_stream.child_counter
                and self.parent_stream.close_on_last_child
            ):
                self.parent_stream.close()
        try:
            self.file_obj.close(*args, **kwargs)
        except AttributeError:
            pass
        self.closed = True

    def autoclose(self):
        """Automatically close the stream when all child streams are closed, or immediately if there are none."""
        self.close_on_last_child = True
        if self.child_counter == 0:
            self.close()

    def __dir__(self):
        attrs = list(self.__dict__.keys()) + list(StreamWrapper.__dict__.keys())
        attrs += dir(self.file_obj)
        return list(set(attrs))

    def __del__(self):
        if not self.closed:
            self.close()

    def __iter__(self):
        yield from self.file_obj

    def __next__(self):
        return next(self.file_obj)

    def __repr__(self):
        if self.name is None:
            return f"StreamWrapper<{self.file_obj!r}>"
        else:
            return f"StreamWrapper<{self.name},{self.file_obj!r}>"

    def __getstate__(self):
        return self.file_obj

    def __setstate__(self, obj):
        self.file_obj = obj
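
# Usage sketch (comments only; "a.bin" is hypothetical):
#
#   >>> parent = StreamWrapper(open("a.bin", "rb"))
#   >>> child = StreamWrapper(parent.file_obj, parent_stream=parent)
#   >>> parent.autoclose()  # defer closing until the last child stream closes
#   >>> data = child.read()  # attribute access is forwarded to the wrapped file object
#   >>> child.close()  # closing the last child also closes the parent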
412