xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/serialize/input-archive.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/serialize/input-archive.h>
2 
3 #include <torch/types.h>
4 #include <torch/utils.h>
5 
6 #include <c10/util/Exception.h>
7 #include <caffe2/serialize/read_adapter_interface.h>
8 #include <torch/csrc/jit/api/module.h>
9 #include <torch/csrc/jit/serialization/import.h>
10 
#include <algorithm>
#include <cstring>
#include <istream>
#include <memory>
#include <string>
#include <utility>
15 
16 namespace torch {
17 namespace serialize {
18 
// Creates an empty archive backed by a fresh module named "Module" that owns
// its own compilation unit. The load_from() overloads assign module_ a
// deserialized module in place of this placeholder.
InputArchive::InputArchive()
    : module_("Module", std::make_shared<jit::CompilationUnit>()) {}
21 
read(const std::string & key,c10::IValue & ivalue)22 void InputArchive::read(const std::string& key, c10::IValue& ivalue) {
23   ivalue = module_.attr(key);
24 }
25 
try_read(const std::string & key,c10::IValue & ivalue)26 bool InputArchive::try_read(const std::string& key, c10::IValue& ivalue) {
27   if (!module_.hasattr(key)) {
28     return false;
29   }
30   ivalue = module_.attr(key);
31   return true;
32 }
33 
try_read(const std::string & key,Tensor & tensor,bool is_buffer)34 bool InputArchive::try_read(
35     const std::string& key,
36     Tensor& tensor,
37     bool is_buffer) {
38   if (!module_.hasattr(key)) {
39     return false;
40   }
41   auto iv = module_.attr(key);
42   if (!iv.isTensor()) {
43     return false;
44   }
45   auto read_tensor = iv.toTensor();
46   // clang-format on
47   if (tensor.defined()) {
48     torch::NoGradGuard guard;
49     if (tensor.device() != read_tensor.device()) {
50       tensor.set_data(read_tensor);
51     } else {
52       tensor.set_(read_tensor);
53     }
54   } else {
55     tensor = std::move(read_tensor);
56   }
57   return true;
58 }
59 
read(const std::string & key,Tensor & tensor,bool is_buffer)60 void InputArchive::read(
61     const std::string& key,
62     Tensor& tensor,
63     bool is_buffer) {
64   TORCH_CHECK(
65       try_read(key, tensor, is_buffer),
66       "No such serialized tensor '",
67       hierarchy_prefix_,
68       key,
69       "'");
70 }
71 
try_read(const std::string & key,InputArchive & archive)72 bool InputArchive::try_read(const std::string& key, InputArchive& archive) {
73   if (!module_.hasattr(key)) {
74     return false;
75   }
76   auto iv = module_.attr(key);
77   if (!iv.isModule()) {
78     return false;
79   }
80   archive.module_ = iv.toModule();
81   archive.hierarchy_prefix_ = hierarchy_prefix_ + key + ".";
82   return true;
83 }
84 
read(const std::string & key,InputArchive & archive)85 void InputArchive::read(const std::string& key, InputArchive& archive) {
86   TORCH_CHECK(
87       try_read(key, archive),
88       "No such serialized submodule: '",
89       hierarchy_prefix_,
90       key,
91       "'");
92 }
93 
load_from(const std::string & filename,std::optional<torch::Device> device)94 void InputArchive::load_from(
95     const std::string& filename,
96     std::optional<torch::Device> device /*= std::nullopt*/) {
97   module_ = torch::jit::load(filename, std::move(device));
98 }
99 
load_from(std::istream & stream,std::optional<torch::Device> device)100 void InputArchive::load_from(
101     std::istream& stream,
102     std::optional<torch::Device> device /*= std::nullopt*/) {
103   module_ = torch::jit::load(stream, std::move(device));
104 }
105 
load_from(const char * data,size_t size,std::optional<torch::Device> device)106 void InputArchive::load_from(
107     const char* data,
108     size_t size,
109     std::optional<torch::Device> device /*= std::nullopt*/) {
110   using caffe2::serialize::ReadAdapterInterface;
111   class OurAdapter : public ReadAdapterInterface {
112    public:
113     OurAdapter(const char* data, size_t size) : data_(data), size_(size) {}
114     size_t size() const override {
115       return size_;
116     }
117     size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
118         const override {
119       (void)what;
120       if (pos >= size_) {
121         return 0;
122       }
123       size_t nread = std::min(static_cast<size_t>(pos) + n, size_) - pos;
124       memcpy(buf, data_ + pos, nread);
125       return nread;
126     }
127 
128    private:
129     const char* data_;
130     size_t size_;
131   };
132   module_ = torch::jit::load(
133       std::make_unique<OurAdapter>(data, size), std::move(device));
134 }
135 
load_from(const std::function<size_t (uint64_t,void *,size_t)> & read_func,const std::function<size_t (void)> & size_func,std::optional<torch::Device> device)136 void InputArchive::load_from(
137     const std::function<size_t(uint64_t, void*, size_t)>& read_func,
138     const std::function<size_t(void)>& size_func,
139     std::optional<torch::Device> device /*= std::nullopt*/) {
140   using caffe2::serialize::ReadAdapterInterface;
141   class OurAdapter : public ReadAdapterInterface {
142    public:
143     OurAdapter(
144         const std::function<size_t(uint64_t, void*, size_t)>& read_func,
145         const std::function<size_t(void)>& size_func)
146         : read_func_(read_func), size_func_(size_func) {}
147     size_t size() const override {
148       return size_func_();
149     }
150     size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
151         const override {
152       (void)what;
153       return read_func_(pos, buf, n);
154     }
155 
156    private:
157     const std::function<size_t(uint64_t, void*, size_t)>& read_func_;
158     const std::function<size_t(void)>& size_func_;
159   };
160   module_ = torch::jit::load(
161       std::make_unique<OurAdapter>(read_func, size_func), std::move(device));
162 }
163 
keys()164 std::vector<std::string> InputArchive::keys() {
165   std::vector<std::string> all_keys;
166   all_keys.reserve(module_.named_attributes(/*recurse=*/false).size());
167 
168   for (const torch::jit::NameValue& s :
169        module_.named_attributes(/*recurse=*/false)) {
170     all_keys.push_back(s.name);
171   }
172 
173   return all_keys;
174 }
175 
176 } // namespace serialize
177 } // namespace torch
178