# mypy: allow-untyped-defs
import os.path
from glob import glob
from typing import cast

import torch
from torch.types import Storage


# Relative record path checked by ``DirectoryReader.serialization_id``.
__serialization_id_record_name__ = ".data/serialization_id"


# because get_storage_from_record returns a tensor!?
class _HasStorage:
    """Thin wrapper that exposes a storage object through a ``storage()`` method.

    ``DirectoryReader.get_storage_from_record`` returns instances of this class.
    Presumably this mirrors the ``.storage()`` accessor of the tensor that the
    reader being imitated returns -- confirm against callers.
    """

    def __init__(self, storage):
        self._storage = storage

    def storage(self):
        """Return the wrapped storage object."""
        return self._storage


class DirectoryReader:
    """
    Class to allow PackageImporter to operate on unzipped packages. Methods
    copy the behavior of the internal PyTorchFileReader class (which is used for
    accessing packages in all other cases).

    N.B.: ScriptObjects are not depickleable or accessible via this DirectoryReader
    class due to ScriptObjects requiring an actual PyTorchFileReader instance.
    """

    def __init__(self, directory):
        """Remember ``directory``, the root under which records are looked up."""
        self.directory = directory

    def get_record(self, name):
        """Return the raw bytes of the record ``name`` under the directory.

        Raises:
            OSError: if the file cannot be opened (e.g. it does not exist).
        """
        filename = f"{self.directory}/{name}"
        with open(filename, "rb") as f:
            return f.read()

    def get_storage_from_record(self, name, numel, dtype):
        """Load the record ``name`` as an untyped storage backed by the file.

        Args:
            name: record path relative to the directory.
            numel: number of elements of ``dtype`` stored in the record.
            dtype: element dtype, used only to compute the byte size.

        Returns:
            A ``_HasStorage`` whose ``storage()`` yields the loaded storage.
        """
        filename = f"{self.directory}/{name}"
        nbytes = torch._utils._element_size(dtype) * numel
        # The cast only informs the type checker; at runtime this is the
        # ``torch.UntypedStorage`` class itself.
        storage = cast(Storage, torch.UntypedStorage)
        return _HasStorage(storage.from_file(filename=filename, nbytes=nbytes))

    def has_record(self, path):
        """Return True if ``path`` (relative to the directory) is a regular file."""
        full_path = os.path.join(self.directory, path)
        return os.path.isfile(full_path)

    def get_all_records(
        self,
    ):
        """Return the paths of all files under the directory, relative to it.

        Directories themselves are not included. ``os.path.relpath`` is used to
        strip the root rather than slicing by ``len(self.directory) + 1``, so
        the result is also correct when ``self.directory`` ends with a path
        separator or is a path-like object instead of a ``str``.

        NOTE: ``glob``'s ``**`` does not match names beginning with a dot by
        default, so files inside dot-directories (such as ``.data/``) are not
        listed here.
        """
        return [
            os.path.relpath(filename, self.directory)
            for filename in glob(f"{self.directory}/**", recursive=True)
            if not os.path.isdir(filename)
        ]

    def serialization_id(
        self,
    ):
        """Return the contents of the serialization-id record, or ``""`` if absent.

        NOTE(review): when the record exists the value comes from ``get_record``
        and is therefore ``bytes``, while the fallback is an empty ``str``.
        Confirm which type callers expect before unifying the two.
        """
        if self.has_record(__serialization_id_record_name__):
            return self.get_record(__serialization_id_record_name__)
        return ""