xref: /aosp_15_r20/external/pytorch/torch/package/_directory_reader.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import os.path
3from glob import glob
4from typing import cast
5
6import torch
7from torch.types import Storage
8
9
10__serialization_id_record_name__ = ".data/serialization_id"
11
12
13# because get_storage_from_record returns a tensor!?
class _HasStorage:
    """Minimal holder that hands back a wrapped storage via ``storage()``."""

    def __init__(self, storage) -> None:
        """Remember ``storage`` so that ``storage()`` can return it later."""
        self._storage = storage

    def storage(self):
        """Return the exact object that was passed to the constructor."""
        return self._storage
20
21
class DirectoryReader:
    """
    Class to allow PackageImporter to operate on unzipped packages. Methods
    copy the behavior of the internal PyTorchFileReader class (which is used for
    accessing packages in all other cases).

    Record names are file paths relative to ``directory``.

    N.B.: ScriptObjects are not depickleable or accessible via this DirectoryReader
    class due to ScriptObjects requiring an actual PyTorchFileReader instance.
    """

    def __init__(self, directory):
        """Create a reader rooted at ``directory`` (the unzipped package)."""
        # Every record name handed to the other methods is resolved against
        # this path.
        self.directory = directory

    def get_record(self, name):
        """Return the full contents of record ``name`` as ``bytes``.

        Raises:
            OSError: if the file cannot be opened (for example, it is missing).
        """
        filename = f"{self.directory}/{name}"
        with open(filename, "rb") as f:
            return f.read()

    def get_storage_from_record(self, name, numel, dtype):
        """Load record ``name`` as a storage of ``numel`` elements of ``dtype``.

        The storage is created with ``torch.UntypedStorage.from_file`` and is
        returned wrapped in ``_HasStorage``; call ``.storage()`` on the result
        to obtain it.
        """
        filename = f"{self.directory}/{name}"
        # from_file is sized in bytes, so convert the element count.
        nbytes = torch._utils._element_size(dtype) * numel
        # cast() is a no-op at runtime; it only tells the type checker how to
        # treat torch.UntypedStorage.
        storage = cast(Storage, torch.UntypedStorage)
        return _HasStorage(storage.from_file(filename=filename, nbytes=nbytes))

    def has_record(self, path):
        """Return True if ``path`` names a regular file under the directory.

        Directories and missing paths both yield False.
        """
        full_path = os.path.join(self.directory, path)
        return os.path.isfile(full_path)

    def get_all_records(
        self,
    ):
        """Return the names of all files below the directory, relative to it.

        Sub-directories are searched recursively, and directories themselves
        are not listed.

        NOTE(review): ``glob`` does not match dot-prefixed entries by default,
        so files under e.g. ``.data/`` are not part of the result. This is the
        pre-existing behavior and is kept as is; confirm whether callers
        expect those records to be listed.
        """
        # Imported locally because the module only imports ``glob`` itself.
        from glob import escape

        # Escape the root so that glob metacharacters in the directory name
        # ("[", "*", "?") are matched literally instead of as a pattern.
        pattern = f"{escape(self.directory)}/**"
        # relpath (rather than slicing off len(directory) + 1 characters)
        # stays correct when the directory carries a trailing separator.
        return [
            os.path.relpath(filename, self.directory)
            for filename in glob(pattern, recursive=True)
            if not os.path.isdir(filename)
        ]

    def serialization_id(
        self,
    ):
        """Return the contents of the serialization-id record, if present.

        When the record exists its raw ``bytes`` are returned (as read by
        ``get_record``); otherwise the empty string ``""`` is returned.
        """
        if self.has_record(__serialization_id_record_name__):
            return self.get_record(__serialization_id_record_name__)
        else:
            return ""
66