# mypy: allow-untyped-defs
# This file takes part of its implementation from NVIDIA's webdataset:
# https://github.com/tmbdev/webdataset/blob/master/webdataset/autodecode.py

import io
import json
import os.path
import pickle
import tempfile

import torch
from torch.utils.data.datapipes.utils.common import StreamWrapper


__all__ = [
    "Decoder",
    "ImageHandler",
    "MatHandler",
    "audiohandler",
    "basichandlers",
    "extension_extract_fn",
    "handle_extension",
    "imagehandler",
    "mathandler",
    "videohandler",
]


################################################################
# handle basic datatypes
################################################################
# Extension groups recognised by ``basichandlers``.  They are real sets so
# that membership is an exact match; testing ``ext in "txt text transcript"``
# would be a *substring* test and would accept "", "ext", "t", ...
_TEXT_EXTENSIONS = frozenset({"txt", "text", "transcript"})
_INT_EXTENSIONS = frozenset({"cls", "cls2", "class", "count", "index", "inx", "id"})
_JSON_EXTENSIONS = frozenset({"json", "jsn"})
_PICKLE_EXTENSIONS = frozenset({"pyd", "pickle"})
_TORCH_EXTENSIONS = frozenset({"pt"})


def basichandlers(extension: str, data):
    """Transforms raw data (byte stream) into python objects.

    Looks at the extension and loads the data into a python object supporting
    the corresponding extension.

    Args:
        extension (str): The file extension (without the leading dot). The
            comparison is an exact, case-sensitive match.
        data (byte stream): Data to load into a python object.

    Returns:
        object: The data loaded into a corresponding python object
            supporting the extension, or ``None`` when the extension is not
            handled here (or when an integer extension holds data that
            ``int`` cannot parse).

    Example:
        >>> import pickle
        >>> data = pickle.dumps('some data')
        >>> new_data = basichandlers('pickle', data)
        >>> new_data
        'some data'

    The transformation of data for extensions are:
        - txt, text, transcript: utf-8 decoded data of str format
        - cls, cls2, class, count, index, inx, id: int
        - json, jsn: json loaded data
        - pickle, pyd: pickle loaded data
        - pt: torch loaded data
    """
    if extension in _TEXT_EXTENSIONS:
        return data.decode("utf-8")

    if extension in _INT_EXTENSIONS:
        try:
            return int(data)
        except ValueError:
            # Unparseable integers are reported as "not decoded".
            return None

    if extension in _JSON_EXTENSIONS:
        return json.loads(data)

    if extension in _PICKLE_EXTENSIONS:
        # SECURITY: unpickling can execute arbitrary code; only feed this
        # handler data from a trusted source.
        return pickle.loads(data)

    if extension in _TORCH_EXTENSIONS:
        # SECURITY: ``torch.load`` may unpickle as well; same caveat applies.
        stream = io.BytesIO(data)
        return torch.load(stream)

    # if extension in "ten tb".split():
    #     from . import tenbin
    #     return tenbin.decode_buffer(data)

    # if extension in "mp msgpack msg".split():
    #     import msgpack
    #     return msgpack.unpackb(data)

    return None


################################################################
# handle images
################################################################
# Maps an image specification name to (array type, element type, PIL mode).
imagespecs = {
    "l8": ("numpy", "uint8", "l"),
    "rgb8": ("numpy", "uint8", "rgb"),
    "rgba8": ("numpy", "uint8", "rgba"),
    "l": ("numpy", "float", "l"),
    "rgb": ("numpy", "float", "rgb"),
    "rgba": ("numpy", "float", "rgba"),
    "torchl8": ("torch", "uint8", "l"),
    "torchrgb8": ("torch", "uint8", "rgb"),
    "torchrgba8": ("torch", "uint8", "rgba"),
    "torchl": ("torch", "float", "l"),
    "torchrgb": ("torch", "float", "rgb"),
    "torch": ("torch", "float", "rgb"),
    "torchrgba": ("torch", "float", "rgba"),
    "pill": ("pil", None, "l"),
    "pil": ("pil", None, "rgb"),
    "pilrgb": ("pil", None, "rgb"),
    "pilrgba": ("pil", None, "rgba"),
}

# Extensions that ``ImageHandler`` attempts to decode (compared lower-cased).
_IMAGE_EXTENSIONS = frozenset({"jpg", "jpeg", "png", "ppm", "pgm", "pbm", "pnm"})


def handle_extension(extensions, f):
    """
    Return a decoder handler function for the list of extensions.

    Extensions can be a space separated list of extensions.
    Extensions can contain dots, in which case the corresponding number
    of extension components must be present in the key given to f.
    Comparisons are case insensitive.

    Args:
        extensions (str): Space separated extensions, e.g. ``"jpg jpeg"``.
        f (callable): Called as ``f(data)`` for the first matching extension.

    Returns:
        callable: ``g(key, data)`` returning ``f(data)`` when the trailing
            dot-separated components of ``key`` match one of the extensions,
            otherwise ``None``.

    Examples:
        handle_extension("jpg jpeg", my_decode_jpg)  # invoked for any file.jpg
        handle_extension("seg.jpg", special_case_jpg)  # invoked only for file.seg.jpg
    """
    # Split every target into its dot-separated components once, up front.
    targets = [target.split(".") for target in extensions.lower().split()]

    def g(key, data):
        components = key.lower().split(".")

        for target in targets:
            if len(target) > len(components):
                continue

            if components[-len(target):] == target:
                return f(data)
        return None

    return g


class ImageHandler:
    """
    Decode image data using the given `imagespec`.

    The `imagespec` specifies whether the image is decoded
    to numpy/torch/pil, decoded to uint8/float, and decoded
    to l/rgb/rgba:

    - l8: numpy uint8 l
    - rgb8: numpy uint8 rgb
    - rgba8: numpy uint8 rgba
    - l: numpy float l
    - rgb: numpy float rgb
    - rgba: numpy float rgba
    - torchl8: torch uint8 l
    - torchrgb8: torch uint8 rgb
    - torchrgba8: torch uint8 rgba
    - torchl: torch float l
    - torchrgb: torch float rgb
    - torch: torch float rgb
    - torchrgba: torch float rgba
    - pill: pil None l
    - pil: pil None rgb
    - pilrgb: pil None rgb
    - pilrgba: pil None rgba
    """

    def __init__(self, imagespec):
        # Normalise the case *before* validating so that e.g. "RGB8" is
        # accepted; non-string specs fall through to the assertion below.
        spec = imagespec.lower() if isinstance(imagespec, str) else None
        assert spec in imagespecs, f"unknown image specification: {imagespec}"
        self.imagespec = spec

    def __call__(self, extension, data):
        """Decode ``data`` when ``extension`` names a supported image format.

        Args:
            extension (str): File extension, compared case-insensitively.
            data (bytes): Encoded image bytes.

        Returns:
            ``None`` for unsupported extensions; otherwise a PIL image, a
            numpy array, or a torch tensor according to ``self.imagespec``.
            Float outputs are the uint8 values divided by 255. Torch outputs
            are channel-first; grayscale images get a leading channel axis
            of size 1.

        Raises:
            ModuleNotFoundError: If ``numpy`` or ``PIL`` is not installed.
        """
        if extension.lower() not in _IMAGE_EXTENSIONS:
            return None

        try:
            import numpy as np
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "Package `numpy` is required to be installed for default image decoder. "
                "Please use `pip install numpy` to install the package"
            ) from e

        try:
            import PIL.Image
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "Package `PIL` is required to be installed for default image decoder. "
                "Please use `pip install Pillow` to install the package"
            ) from e

        atype, etype, mode = imagespecs[self.imagespec]

        with io.BytesIO(data) as stream:
            img = PIL.Image.open(stream)
            # Force the pixel data to be read while the stream is still open.
            img.load()
            img = img.convert(mode.upper())

            if atype == "pil":
                return img

            if atype == "numpy":
                result = np.asarray(img)
                assert (
                    result.dtype == np.uint8
                ), f"numpy image array should be type uint8, but got {result.dtype}"
                if etype == "uint8":
                    return result
                return result.astype("f") / 255.0

            if atype == "torch":
                result = np.asarray(img)
                assert (
                    result.dtype == np.uint8
                ), f"numpy image array should be type uint8, but got {result.dtype}"

                if result.ndim == 2:
                    # Grayscale arrays have no channel axis, so the HWC->CHW
                    # transpose below would raise; add a leading channel axis.
                    result = result[np.newaxis, ...]
                else:
                    result = result.transpose(2, 0, 1)

                # ``np.array`` copies the transposed view into a fresh array.
                tensor = torch.tensor(np.array(result))
                if etype == "uint8":
                    return tensor
                return tensor / 255.0

            return None


def imagehandler(imagespec):
    """Return an ``ImageHandler`` for ``imagespec``."""
    return ImageHandler(imagespec)


def _load_via_tempfile(extension, data, loader):
    """Write ``data`` to a temporary ``file.<extension>`` and return ``loader(path)``.

    The temporary directory (and the file in it) is removed once ``loader``
    has returned.
    """
    with tempfile.TemporaryDirectory() as dirname:
        fname = os.path.join(dirname, f"file.{extension}")
        with open(fname, "wb") as stream:
            stream.write(data)
        return loader(fname)


################################################################
# torch video
################################################################
_VIDEO_EXTENSIONS = frozenset(
    {"mp4", "ogv", "mjpeg", "avi", "mov", "h264", "mpg", "webm", "wmv"}
)


def videohandler(extension, data):
    """Decode video bytes with ``torchvision.io.read_video``.

    Args:
        extension (str): File extension (case-sensitive match).
        data (bytes): Encoded video bytes.

    Returns:
        ``None`` for unsupported extensions, otherwise whatever
        ``torchvision.io.read_video`` returns for the temporary file.

    Raises:
        ModuleNotFoundError: If ``torchvision`` cannot be imported.
    """
    if extension not in _VIDEO_EXTENSIONS:
        return None

    try:
        import torchvision.io
    except ImportError as e:
        raise ModuleNotFoundError(
            "Package `torchvision` is required to be installed for default video file loader. "
            "Please use `pip install torchvision` or `conda install torchvision -c pytorch` "
            "to install the package"
        ) from e

    return _load_via_tempfile(extension, data, torchvision.io.read_video)


################################################################
# torchaudio
################################################################
_AUDIO_EXTENSIONS = frozenset({"flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"})


def audiohandler(extension, data):
    """Decode audio bytes with ``torchaudio.load``.

    Args:
        extension (str): File extension (case-sensitive match).
        data (bytes): Encoded audio bytes.

    Returns:
        ``None`` for unsupported extensions, otherwise whatever
        ``torchaudio.load`` returns for the temporary file.

    Raises:
        ModuleNotFoundError: If ``torchaudio`` cannot be imported.
    """
    if extension not in _AUDIO_EXTENSIONS:
        return None

    try:
        import torchaudio  # type: ignore[import]
    except ImportError as e:
        raise ModuleNotFoundError(
            "Package `torchaudio` is required to be installed for default audio file loader. "
            "Please use `pip install torchaudio` or `conda install torchaudio -c pytorch` "
            "to install the package"
        ) from e

    return _load_via_tempfile(extension, data, torchaudio.load)


################################################################
# mat
################################################################
class MatHandler:
    """Decode ``.mat`` data with ``scipy.io.loadmat``.

    Keyword arguments given to the constructor are forwarded to every
    ``loadmat`` call.
    """

    def __init__(self, **loadmat_kwargs) -> None:
        try:
            import scipy.io as sio
        except ImportError as e:
            raise ModuleNotFoundError(
                "Package `scipy` is required to be installed for mat file. "
                "Please use `pip install scipy` or `conda install scipy` "
                "to install the package"
            ) from e
        self.sio = sio
        self.loadmat_kwargs = loadmat_kwargs

    def __call__(self, extension, data):
        """Return the ``loadmat`` result for ``mat`` data, ``None`` otherwise."""
        if extension != "mat":
            return None
        with io.BytesIO(data) as stream:
            return self.sio.loadmat(stream, **self.loadmat_kwargs)


def mathandler(**loadmat_kwargs):
    """Return a ``MatHandler`` configured with ``loadmat_kwargs``."""
    return MatHandler(**loadmat_kwargs)


################################################################
# a sample decoder
################################################################
def extension_extract_fn(pathname):
    """Extract the extension from ``pathname``, without the leading dot.

    Only the last extension component is returned (``"a.tar.gz"`` gives
    ``"gz"``); a pathname without an extension gives ``""``.
    """
    ext = os.path.splitext(pathname)[1]
    # Remove dot
    if ext:
        ext = ext[1:]
    return ext


class Decoder:
    """
    Decode key/data sets using a list of handlers.

    For each key/data item, this iterates through the list of
    handlers until some handler returns something other than None.
    """

    def __init__(self, *handler, key_fn=extension_extract_fn):
        self.handlers = list(handler)
        self.key_fn = key_fn

    def add_handler(self, *handler):
        """Register additional handlers.

        New handlers are inserted at the beginning of the handlers list so
        that they have the highest priority.
        """
        if not handler:
            return
        self.handlers = list(handler) + self.handlers

    @staticmethod
    def _is_stream_handle(data):
        """Return True when ``data`` (or the object a ``StreamWrapper`` wraps) is a binary stream."""
        obj_to_check = data.file_obj if isinstance(data, StreamWrapper) else data
        return isinstance(obj_to_check, (io.BufferedIOBase, io.RawIOBase))

    def decode1(self, key, data):
        """Decode a single item.

        Args:
            key: Value passed to each handler as its first argument (the
                result of ``key_fn`` when called through ``decode``).
            data: Raw bytes or a binary stream. Streams are read completely
                and closed before the handlers run.

        Returns:
            The first non-``None`` handler result; ``data`` itself when it is
            falsy or when no handler decodes it.
        """
        if not data:
            return data

        # if data is a stream handle, we need to read all the content before decoding
        if Decoder._is_stream_handle(data):
            stream = data
            try:
                # The behavior of .read can differ between streams
                # (e.g. HTTPResponse), hence iteration + join is used instead
                data = b"".join(stream)
            finally:
                # Close the stream even when reading it fails.
                stream.close()

        for handler in self.handlers:
            result = handler(key, data)
            if result is not None:
                return result
        return data

    def decode(self, data):
        """Decode an iterable of ``(pathname, data)`` pairs into a dict.

        A single ``(pathname, data)`` tuple is accepted as well, and ``None``
        yields an empty dict. Keys starting with ``"_"`` are not passed to
        the handlers: their value is stored as-is, UTF-8 decoded first when
        it is ``bytes``.
        """
        result = {}
        # single data tuple(pathname, data stream)
        if isinstance(data, tuple):
            data = [data]

        if data is not None:
            for k, v in data:
                # TODO: xinyu, figure out why Nvidia do this?
                # Slice instead of indexing so an empty key does not raise.
                if k[:1] == "_":
                    if isinstance(v, bytes):
                        v = v.decode("utf-8")
                    result[k] = v
                    continue
                result[k] = self.decode1(self.key_fn(k), v)
        return result

    def __call__(self, data):
        return self.decode(data)