xref: /aosp_15_r20/external/pytorch/torch/utils/data/datapipes/utils/decoder.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
# Parts of this file's implementation are adapted from NVIDIA's webdataset:
3# https://github.com/tmbdev/webdataset/blob/master/webdataset/autodecode.py
4
5import io
6import json
7import os.path
8import pickle
9import tempfile
10
11import torch
12from torch.utils.data.datapipes.utils.common import StreamWrapper
13
14
# Public API of this module: the names picked up by ``from ... import *``
# (the Decoder, the handler classes, and the handler factory functions).
__all__ = [
    "Decoder",
    "ImageHandler",
    "MatHandler",
    "audiohandler",
    "basichandlers",
    "extension_extract_fn",
    "handle_extension",
    "imagehandler",
    "mathandler",
    "videohandler",
]
27
28
29################################################################
30# handle basic datatypes
31################################################################
def basichandlers(extension: str, data):
    """Transforms raw data (byte stream) into python objects.

    Looks at the extension and loads the data into a python object supporting
    the corresponding extension.

    Args:
        extension (str): The file extension
        data (byte stream): Data to load into a python object.

    Returns:
        object: The data loaded into a corresponding python object
            supporting the extension, or ``None`` if the extension is not
            handled here.

    Example:
        >>> import pickle
        >>> data = pickle.dumps('some data')
        >>> new_data = basichandlers('pickle', data)
        >>> new_data
        'some data'

    The transformation of data for extensions are:
        - txt, text, transcript: utf-8 decoded data of str format
        - cls, cls2, class, count, index, inx, id: int (``None`` if the data
          does not parse as an integer)
        - json, jsn: json loaded data
        - pickle, pyd: pickle loaded data
        - pt: torch loaded data

    Extensions are matched exactly and case-sensitively; any other value,
    including the empty string, returns ``None``.

    Warning:
        ``pickle``/``pyd`` and ``pt`` payloads are deserialized with
        ``pickle.loads`` / ``torch.load``, which can execute arbitrary code.
        Only decode data from trusted sources.
    """
    # Match whole extensions. A substring test against a space-separated string
    # would also accept fragments such as "", "t", "ext" or "js".
    if extension in ("txt", "text", "transcript"):
        return data.decode("utf-8")

    if extension in ("cls", "cls2", "class", "count", "index", "inx", "id"):
        try:
            return int(data)
        except ValueError:
            return None

    if extension in ("json", "jsn"):
        return json.loads(data)

    if extension in ("pyd", "pickle"):
        # SECURITY: unpickling untrusted data can execute arbitrary code.
        return pickle.loads(data)

    if extension == "pt":
        # SECURITY: torch.load may unpickle arbitrary objects; trusted data only.
        with io.BytesIO(data) as stream:
            return torch.load(stream)

    # NOTE: upstream webdataset also decodes tenbin ("ten", "tb") and msgpack
    # ("mp", "msgpack", "msg") payloads; those are not supported here.
    return None
89
90
91################################################################
92# handle images
93################################################################
# Image specification table consumed by ImageHandler, which unpacks each entry
# as ``atype, etype, mode``:
#   - atype: "numpy" or "torch" to return an array/tensor, "pil" to return
#     the converted PIL image itself
#   - etype: "uint8" keeps 0-255 values; any other value ("float") divides by
#     255.0; ``None`` for "pil" outputs, which are never converted to arrays
#   - mode: "l" (grayscale), "rgb" or "rgba"; upper-cased before being passed
#     to ``PIL.Image.Image.convert``
imagespecs = {
    "l8": ("numpy", "uint8", "l"),
    "rgb8": ("numpy", "uint8", "rgb"),
    "rgba8": ("numpy", "uint8", "rgba"),
    "l": ("numpy", "float", "l"),
    "rgb": ("numpy", "float", "rgb"),
    "rgba": ("numpy", "float", "rgba"),
    "torchl8": ("torch", "uint8", "l"),
    "torchrgb8": ("torch", "uint8", "rgb"),
    "torchrgba8": ("torch", "uint8", "rgba"),
    "torchl": ("torch", "float", "l"),
    "torchrgb": ("torch", "float", "rgb"),
    "torch": ("torch", "float", "rgb"),
    "torchrgba": ("torch", "float", "rgba"),
    "pill": ("pil", None, "l"),
    "pil": ("pil", None, "rgb"),
    "pilrgb": ("pil", None, "rgb"),
    "pilrgba": ("pil", None, "rgba"),
}
113
114
def handle_extension(extensions, f):
    """
    Return a decoder handler function for the list of extensions.

    Extensions can be a space separated list of extensions.
    Extensions can contain dots, in which case the corresponding number
    of extension components must be present in the key given to f.
    Comparisons are case insensitive.
    Examples:
    handle_extension("jpg jpeg", my_decode_jpg)  # invoked for any file.jpg
    handle_extension("seg.jpg", special_case_jpg)  # invoked only for file.seg.jpg

    The returned handler is called as ``g(key, data)``. It returns ``f(data)``
    if the dot-separated tail of ``key`` equals any of the target extensions,
    and ``None`` otherwise. Multi-component targets such as "seg.jpg" can only
    match when ``key`` itself contains those components (e.g. a full
    pathname), not when it is just the final suffix.
    """
    # Split every target into its components once, instead of on every call.
    targets = [target.split(".") for target in extensions.lower().split()]

    def g(key, data):
        extension = key.lower().split(".")
        for target in targets:
            if len(target) > len(extension):
                continue
            if extension[-len(target) :] == target:
                return f(data)
        # Give up only after every target extension has been checked.
        return None

    return g
142
143
class ImageHandler:
    """
    Decode image data using the given `imagespec`.

    The `imagespec` specifies whether the image is decoded
    to numpy/torch/pil, decoded to uint8/float, and decoded
    to l/rgb/rgba:

    - l8: numpy uint8 l
    - rgb8: numpy uint8 rgb
    - rgba8: numpy uint8 rgba
    - l: numpy float l
    - rgb: numpy float rgb
    - rgba: numpy float rgba
    - torchl8: torch uint8 l
    - torchrgb8: torch uint8 rgb
    - torchrgba8: torch uint8 rgba
    - torchl: torch float l
    - torchrgb: torch float rgb
    - torch: torch float rgb
    - torchrgba: torch float rgba
    - pill: pil None l
    - pil: pil None rgb
    - pilrgb: pil None rgb
    - pilrgba: pil None rgba

    The spec name is matched case-insensitively. Float outputs are uint8
    values divided by 255. Torch outputs are channel-first: ``(C, H, W)`` for
    rgb/rgba specs and ``(1, H, W)`` for grayscale ("l") specs. Numpy outputs
    keep PIL's layout: ``(H, W)`` for grayscale, ``(H, W, C)`` otherwise.
    """

    # Extensions this handler decodes; for any other extension it returns None
    # so the next handler can try.
    _EXTENSIONS = frozenset("jpg jpeg png ppm pgm pbm pnm".split())

    def __init__(self, imagespec):
        # Validate the normalized (lower-cased) name, which is what is stored
        # and later looked up in ``imagespecs``.
        assert isinstance(imagespec, str) and imagespec.lower() in imagespecs, (
            f"unknown image specification: {imagespec}"
        )
        self.imagespec = imagespec.lower()

    def __call__(self, extension, data):
        """Decode ``data`` (encoded image bytes) if ``extension`` is an image type.

        Returns:
            A PIL image, numpy array or torch tensor as selected by the spec,
            or ``None`` if ``extension`` is not a supported image extension.

        Raises:
            ModuleNotFoundError: if numpy or Pillow is not installed.
        """
        if extension.lower() not in self._EXTENSIONS:
            return None

        try:
            import numpy as np
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "Package `numpy` is required to be installed for default image decoder. "
                "Please use `pip install numpy` to install the package"
            ) from e

        try:
            import PIL.Image
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "Package `PIL` is required to be installed for default image decoder. "
                "Please use `pip install Pillow` to install the package"
            ) from e

        atype, etype, mode = imagespecs[self.imagespec]

        with io.BytesIO(data) as stream:
            img = PIL.Image.open(stream)
            # Decode pixel data now, while the in-memory buffer is still open.
            img.load()
            img = img.convert(mode.upper())

        if atype == "pil":
            return img

        result = np.asarray(img)
        assert (
            result.dtype == np.uint8
        ), f"numpy image array should be type uint8, but got {result.dtype}"

        if atype == "numpy":
            if etype == "uint8":
                return result
            return result.astype("f") / 255.0
        if atype == "torch":
            return self._to_torch(result, etype)
        return None

    @staticmethod
    def _to_torch(result, etype):
        """Convert a decoded uint8 image array into a channel-first torch tensor.

        Args:
            result: uint8 array of shape ``(H, W)`` (grayscale) or ``(H, W, C)``.
            etype: ``"uint8"`` keeps integer values; anything else returns a
                float tensor of the values divided by 255.

        Returns:
            torch.Tensor of shape ``(1, H, W)`` or ``(C, H, W)``.
        """
        if result.ndim == 2:
            # Grayscale ("l") images decode to a 2-D array, on which
            # transpose(2, 0, 1) would raise; add an explicit channel axis.
            chw = result[None, :, :]
        else:
            chw = result.transpose(2, 0, 1)
        # Copy into a C-contiguous buffer before handing it to torch.
        tensor = torch.tensor(chw.copy())
        if etype == "uint8":
            return tensor
        return tensor / 255.0
228
229
def imagehandler(imagespec):
    """Factory returning an :class:`ImageHandler` configured with ``imagespec``."""
    handler = ImageHandler(imagespec)
    return handler
232
233
234################################################################
235# torch video
236################################################################
def videohandler(extension, data):
    """Decode video bytes via ``torchvision.io.read_video``.

    The bytes are written to a temporary file named ``file.<extension>``,
    because ``read_video`` takes a filename; the file is removed afterwards.

    Args:
        extension (str): file extension (without the dot), matched case-sensitively.
        data (bytes): raw content of the video file.

    Returns:
        The result of ``torchvision.io.read_video`` for the file, or ``None``
        if ``extension`` is not a supported video extension.

    Raises:
        ModuleNotFoundError: if ``torchvision`` cannot be imported.
    """
    if extension not in ("mp4", "ogv", "mjpeg", "avi", "mov", "h264", "mpg", "webm", "wmv"):
        return None

    try:
        import torchvision.io
    except ImportError as e:
        raise ModuleNotFoundError(
            "Package `torchvision` is required to be installed for default video file loader. "
            "Please use `pip install torchvision` or `conda install torchvision -c pytorch` "
            "to install the package"
        ) from e

    with tempfile.TemporaryDirectory() as dirname:
        fname = os.path.join(dirname, f"file.{extension}")
        with open(fname, "wb") as stream:
            stream.write(data)
        # Read only after the writer is closed: inside the ``with open`` block,
        # buffered bytes may not have reached the file yet.
        return torchvision.io.read_video(fname)
255
256
257################################################################
258# torchaudio
259################################################################
def audiohandler(extension, data):
    """Decode audio bytes via ``torchaudio.load``.

    The bytes are written to a temporary file named ``file.<extension>``,
    because ``torchaudio.load`` is given a filename here; the file is removed
    afterwards.

    Args:
        extension (str): file extension (without the dot), matched case-sensitively.
        data (bytes): raw content of the audio file.

    Returns:
        The result of ``torchaudio.load`` for the file, or ``None`` if
        ``extension`` is not a supported audio extension.

    Raises:
        ModuleNotFoundError: if ``torchaudio`` cannot be imported.
    """
    if extension not in ("flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"):
        return None

    try:
        import torchaudio  # type: ignore[import]
    except ImportError as e:
        raise ModuleNotFoundError(
            "Package `torchaudio` is required to be installed for default audio file loader. "
            "Please use `pip install torchaudio` or `conda install torchaudio -c pytorch` "
            "to install the package"
        ) from e

    with tempfile.TemporaryDirectory() as dirname:
        fname = os.path.join(dirname, f"file.{extension}")
        with open(fname, "wb") as stream:
            stream.write(data)
        # Load only after the writer is closed: inside the ``with open`` block,
        # buffered bytes may not have reached the file yet.
        return torchaudio.load(fname)
278
279
280################################################################
281# mat
282################################################################
class MatHandler:
    """Handler that decodes MATLAB ``.mat`` payloads with ``scipy.io.loadmat``.

    Keyword arguments passed to the constructor are stored and forwarded
    unchanged to every ``loadmat`` call. scipy is imported eagerly so a
    missing dependency is reported when the handler is built.
    """

    def __init__(self, **loadmat_kwargs) -> None:
        try:
            from scipy import io as scipy_io
        except ImportError as err:
            raise ModuleNotFoundError(
                "Package `scipy` is required to be installed for mat file."
                "Please use `pip install scipy` or `conda install scipy`"
                "to install the package"
            ) from err
        self.sio = scipy_io
        self.loadmat_kwargs = loadmat_kwargs

    def __call__(self, extension, data):
        """Return the loaded ``.mat`` contents, or ``None`` for any other extension."""
        if extension == "mat":
            with io.BytesIO(data) as buffer:
                return self.sio.loadmat(buffer, **self.loadmat_kwargs)
        return None
301
302
def mathandler(**loadmat_kwargs):
    """Factory returning a :class:`MatHandler` that forwards ``loadmat_kwargs`` to ``loadmat``."""
    handler = MatHandler(**loadmat_kwargs)
    return handler
305
306
307################################################################
308# a sample decoder
309################################################################
def extension_extract_fn(pathname):
    """Return the final extension of ``pathname`` without its leading dot.

    Only the last suffix is kept (``"a.tar.gz"`` -> ``"gz"``). Paths without
    an extension, including dot-files such as ``".bashrc"``, give an empty
    result, as determined by ``os.path.splitext``.
    """
    _, suffix = os.path.splitext(pathname)
    # splitext keeps the dot; slicing an empty suffix leaves it empty.
    return suffix[1:]
317
318
class Decoder:
    """
    Decode key/data sets using a list of handlers.

    For each key/data item, this iterates through the list of
    handlers until some handler returns something other than None.
    If no handler claims an item, its raw data (read into bytes if it
    was a binary stream) is returned unchanged.
    """

    def __init__(self, *handler, key_fn=extension_extract_fn):
        # Handlers are tried in list order; each is called as
        # ``handler(key_fn(pathname), data)``.
        self.handlers = list(handler)
        self.key_fn = key_fn

    def add_handler(self, *handler):
        """Prepend ``handler``(s) so that newly added handlers take priority."""
        if not handler:
            return
        self.handlers = list(handler) + self.handlers

    @staticmethod
    def _is_stream_handle(data):
        """Return True if ``data`` (or the ``file_obj`` of a StreamWrapper) is a binary stream."""
        obj_to_check = data.file_obj if isinstance(data, StreamWrapper) else data
        return isinstance(obj_to_check, (io.BufferedIOBase, io.RawIOBase))

    def decode1(self, key, data):
        """Decode one item with the first handler that returns a non-None result.

        Falsy ``data`` (e.g. ``b""`` or ``None``) is returned as-is without
        consulting any handler. Binary stream handles are read fully into
        bytes and closed before the handlers are called.
        """
        if not data:
            return data

        # if data is a stream handle, we need to read all the content before decoding
        if Decoder._is_stream_handle(data):
            stream = data
            try:
                # The behavior of .read can differ between streams (e.g. HTTPResponse),
                # hence iteration is used instead.
                data = b"".join(stream)
            finally:
                # Close even if reading fails, so the handle is not leaked.
                stream.close()

        for f in self.handlers:
            result = f(key, data)
            if result is not None:
                return result
        return data

    def decode(self, data):
        """Decode a ``(pathname, data)`` pair or an iterable of such pairs.

        Returns:
            dict mapping each pathname to its decoded value; ``None`` input
            yields an empty dict.
        """
        result = {}
        # single data tuple(pathname, data stream)
        if isinstance(data, tuple):
            data = [data]

        if data is not None:
            for k, v in data:
                # Keys starting with "_" holding bytes are stored as UTF-8 text
                # and skip the handlers. NOTE(review): mirrors webdataset, where
                # such keys are presumably sample metadata (e.g. "__key__") —
                # confirm. ``k[:1]`` (not ``k[0]``) keeps an empty key from
                # raising IndexError.
                if k[:1] == "_" and isinstance(v, bytes):
                    result[k] = v.decode("utf-8")
                    continue
                result[k] = self.decode1(self.key_fn(k), v)
        return result

    def __call__(self, data):
        """Alias for :meth:`decode`."""
        return self.decode(data)
379