xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/storage_context.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/ivalue.h>
4 
5 namespace torch::jit {
6 
7 // Used in torch.package and TorchScript serialization to coordinate
8 // sharing of storages between models. Also used to create deterministic
9 // naming for storages.
10 class TORCH_API SerializationStorageContext {
11  public:
12   explicit SerializationStorageContext() = default;
13   SerializationStorageContext operator=(const SerializationStorageContext&) =
14       delete;
15   SerializationStorageContext(const SerializationStorageContext&) = delete;
16 
getOrAddStorage(const c10::Storage & storage)17   uint64_t getOrAddStorage(const c10::Storage& storage) {
18     if (!hasStorage(storage)) {
19       uint64_t size = storage_id_map_.size();
20       storage_id_map_[storage] = size;
21     }
22     return storage_id_map_[storage];
23   }
24 
hasStorage(const c10::Storage & storage)25   bool hasStorage(const c10::Storage& storage) {
26     return storage_id_map_.find(storage) != storage_id_map_.end();
27   }
28 
29   ~SerializationStorageContext() = default;
30 
31  private:
32   class StorageSerializationHash {
33    public:
operator()34     size_t operator()(const c10::Storage& storage) const {
35       return std::hash<void*>()(
36           reinterpret_cast<void*>(storage.unsafeGetStorageImpl()));
37     }
38   };
39 
40   class StorageSerializationEqual {
41    public:
operator()42     bool operator()(const c10::Storage& lhs, const c10::Storage& rhs) const {
43       return lhs.unsafeGetStorageImpl() == rhs.unsafeGetStorageImpl();
44     }
45   };
46 
47   std::unordered_map<
48       c10::Storage,
49       uint64_t,
50       StorageSerializationHash,
51       StorageSerializationEqual>
52       storage_id_map_;
53 };
54 
55 // Used in torch.package and TorchScript deserialization to coordinate
56 // sharing of storages between models.
57 class TORCH_API DeserializationStorageContext {
58  public:
59   explicit DeserializationStorageContext() = default;
60   DeserializationStorageContext operator=(
61       const DeserializationStorageContext&) = delete;
62   DeserializationStorageContext(const DeserializationStorageContext&) = delete;
63 
addStorage(std::string name,c10::Storage storage)64   void addStorage(std::string name, c10::Storage storage) {
65     TORCH_INTERNAL_ASSERT(!hasStorage(name));
66     name_storage_map_.emplace(std::move(name), std::move(storage));
67   }
68 
hasStorage(const std::string & name)69   bool hasStorage(const std::string& name) {
70     return name_storage_map_.find(name) != name_storage_map_.end();
71   }
72 
getStorage(const std::string & name)73   c10::Storage getStorage(const std::string& name) {
74     TORCH_INTERNAL_ASSERT(hasStorage(name));
75     return name_storage_map_.find(name)->second;
76   }
77   ~DeserializationStorageContext() = default;
78 
79  private:
80   std::unordered_map<std::string, c10::Storage> name_storage_map_;
81 };
82 
83 } // namespace torch::jit
84