xref: /aosp_15_r20/external/pytorch/torch/utils/data/datapipes/iter/routeddecoder.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from io import BufferedIOBase
2from typing import Any, Callable, Iterable, Iterator, Sized, Tuple
3
4from torch.utils.data.datapipes._decorator import functional_datapipe
5from torch.utils.data.datapipes.datapipe import IterDataPipe
6from torch.utils.data.datapipes.utils.common import _deprecation_warning
7from torch.utils.data.datapipes.utils.decoder import (
8    basichandlers as decoder_basichandlers,
9    Decoder,
10    extension_extract_fn,
11    imagehandler as decoder_imagehandler,
12)
13
14
15__all__ = ["RoutedDecoderIterDataPipe"]
16
17
@functional_datapipe("routed_decode")
class RoutedDecoderIterDataPipe(IterDataPipe[Tuple[str, Any]]):
    r"""
    Decode the binary streams of a source DataPipe, yielding ``(pathname, decoded data)`` tuples.

    (functional name: ``routed_decode``)

    Args:
        datapipe: Iterable datapipe that provides ``(pathname, binary stream)`` tuples
        handlers: Optional user-defined decoder handlers. When none are passed, the basic
            and image decoder handlers are used as the default. When several handlers are
            passed, their priority follows the order given (the first handler has the
            top priority)
        key_fn: Function used by the decoder to extract a key from the pathname in order
            to dispatch handlers. By default the file extension is extracted

    Note:
        If ``key_fn`` returns anything other than the file extension, the default
        handlers will not work and a custom handler has to be supplied. A custom handler
        could use regex to decide whether it is eligible to handle the data.
    """

    def __init__(
        self,
        datapipe: Iterable[Tuple[str, BufferedIOBase]],
        *handlers: Callable,
        key_fn: Callable = extension_extract_fn,
    ) -> None:
        """Store the source datapipe, build the decoder and report the deprecation."""
        super().__init__()
        self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe

        # The stock handlers are only constructed when the caller passed none.
        active_handlers = handlers or (
            decoder_basichandlers,
            decoder_imagehandler("torch"),
        )
        self.decoder = Decoder(*active_handlers, key_fn=key_fn)

        # The class name and version strings are handed to the shared warning helper.
        own_name = type(self).__name__
        _deprecation_warning(
            own_name,
            deprecation_version="1.12",
            removal_version="1.13",
            old_functional_name="routed_decode",
        )

    def add_handler(self, *handler: Callable) -> None:
        """Forward the given handler(s) to the underlying decoder's ``add_handler``."""
        self.decoder.add_handler(*handler)

    def __iter__(self) -> Iterator[Tuple[str, Any]]:
        """Yield ``(pathname, decoded)`` for every item produced by the source datapipe."""
        for item in self.datapipe:
            # The first element of each item is the pathname; the decoder's result is
            # looked up under that same pathname.
            key = item[0]
            decoded = self.decoder(item)
            yield key, decoded[key]

    def __len__(self) -> int:
        """Return the length of the source datapipe.

        Raises:
            TypeError: If the source datapipe is not ``Sized``.
        """
        source = self.datapipe
        if not isinstance(source, Sized):
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
        return len(source)
70