1 #pragma once 2 3 #include <ATen/core/ivalue.h> 4 5 namespace torch::jit { 6 7 // Used in torch.package and TorchScript serialization to coordinate 8 // sharing of storages between models. Also used to create deterministic 9 // naming for storages. 10 class TORCH_API SerializationStorageContext { 11 public: 12 explicit SerializationStorageContext() = default; 13 SerializationStorageContext operator=(const SerializationStorageContext&) = 14 delete; 15 SerializationStorageContext(const SerializationStorageContext&) = delete; 16 getOrAddStorage(const c10::Storage & storage)17 uint64_t getOrAddStorage(const c10::Storage& storage) { 18 if (!hasStorage(storage)) { 19 uint64_t size = storage_id_map_.size(); 20 storage_id_map_[storage] = size; 21 } 22 return storage_id_map_[storage]; 23 } 24 hasStorage(const c10::Storage & storage)25 bool hasStorage(const c10::Storage& storage) { 26 return storage_id_map_.find(storage) != storage_id_map_.end(); 27 } 28 29 ~SerializationStorageContext() = default; 30 31 private: 32 class StorageSerializationHash { 33 public: operator()34 size_t operator()(const c10::Storage& storage) const { 35 return std::hash<void*>()( 36 reinterpret_cast<void*>(storage.unsafeGetStorageImpl())); 37 } 38 }; 39 40 class StorageSerializationEqual { 41 public: operator()42 bool operator()(const c10::Storage& lhs, const c10::Storage& rhs) const { 43 return lhs.unsafeGetStorageImpl() == rhs.unsafeGetStorageImpl(); 44 } 45 }; 46 47 std::unordered_map< 48 c10::Storage, 49 uint64_t, 50 StorageSerializationHash, 51 StorageSerializationEqual> 52 storage_id_map_; 53 }; 54 55 // Used in torch.package and TorchScript deserialization to coordinate 56 // sharing of storages between models. 57 class TORCH_API DeserializationStorageContext { 58 public: 59 explicit DeserializationStorageContext() = default; 60 DeserializationStorageContext operator=( 61 const DeserializationStorageContext&) = delete; 62 DeserializationStorageContext(const DeserializationStorageContext&) = delete; 63 addStorage(std::string name,c10::Storage storage)64 void addStorage(std::string name, c10::Storage storage) { 65 TORCH_INTERNAL_ASSERT(!hasStorage(name)); 66 name_storage_map_.emplace(std::move(name), std::move(storage)); 67 } 68 hasStorage(const std::string & name)69 bool hasStorage(const std::string& name) { 70 return name_storage_map_.find(name) != name_storage_map_.end(); 71 } 72 getStorage(const std::string & name)73 c10::Storage getStorage(const std::string& name) { 74 TORCH_INTERNAL_ASSERT(hasStorage(name)); 75 return name_storage_map_.find(name)->second; 76 } 77 ~DeserializationStorageContext() = default; 78 79 private: 80 std::unordered_map<std::string, c10::Storage> name_storage_map_; 81 }; 82 83 } // namespace torch::jit 84