1 #pragma once
2
3 #include <c10/util/irange.h>
4 #include <torch/optim/optimizer.h>
5 #include <torch/serialize/archive.h>
6 #include <torch/types.h>
7 #include <cstddef>
8 #include <cstdint>
9 #include <deque>
10 #include <string>
11 #include <vector>
12
13 namespace torch {
14 namespace optim {
15 namespace detail {
16 // Utility function to save state
17 template <typename DerivedOptimizerParamState>
serialize(serialize::OutputArchive & archive,const ska::flat_hash_map<void *,std::unique_ptr<OptimizerParamState>> & state)18 void serialize(
19 serialize::OutputArchive& archive,
20 const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
21 state) {
22 for (const auto& item : state) {
23 serialize::OutputArchive param_state_archive(archive.compilation_unit());
24 std::string tensorimpl_key =
25 std::to_string(reinterpret_cast<size_t>(item.first));
26 const DerivedOptimizerParamState& curr_state =
27 static_cast<const DerivedOptimizerParamState&>(*(item.second.get()));
28 curr_state.serialize(param_state_archive);
29 archive.write(tensorimpl_key, param_state_archive);
30 }
31 }
32
33 // Utility function to load state
34 template <typename DerivedOptimizerParamState>
serialize(serialize::InputArchive & archive,ska::flat_hash_map<void *,std::unique_ptr<OptimizerParamState>> & state)35 void serialize(
36 serialize::InputArchive& archive,
37 ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>& state) {
38 std::vector<std::string> tensorimpl_keys = archive.keys();
39 for (const std::string& tensorimpl_key : tensorimpl_keys) {
40 serialize::InputArchive param_state_archive;
41 archive.read(tensorimpl_key, param_state_archive);
42 DerivedOptimizerParamState param_state;
43 param_state.serialize(param_state_archive);
44 state[reinterpret_cast<void*>(std::stoull(tensorimpl_key))] =
45 std::make_unique<DerivedOptimizerParamState>(param_state);
46 }
47 }
48
49 // Utility function to save param_groups
50 template <typename DerivedOptimizerParamOptions>
serialize(serialize::OutputArchive & archive,const std::vector<OptimizerParamGroup> & param_groups)51 void serialize(
52 serialize::OutputArchive& archive,
53 const std::vector<OptimizerParamGroup>& param_groups) {
54 archive.write(
55 "param_groups/size",
56 torch::tensor(static_cast<int64_t>(param_groups.size())));
57 for (const auto i : c10::irange(param_groups.size())) {
58 serialize::OutputArchive param_group_archive(archive.compilation_unit());
59 std::vector<Tensor> params = param_groups[i].params();
60 param_group_archive.write(
61 "params/size", torch::tensor(static_cast<int64_t>(params.size())));
62 for (const auto index : c10::irange(params.size())) {
63 param_group_archive.write(
64 "params/" + std::to_string(index),
65 IValue(std::to_string(
66 reinterpret_cast<size_t>(params[index].unsafeGetTensorImpl()))));
67 }
68 const DerivedOptimizerParamOptions& param_group_options =
69 static_cast<const DerivedOptimizerParamOptions&>(
70 param_groups[i].options());
71 serialize::OutputArchive param_group_options_archive(
72 param_group_archive.compilation_unit());
73 param_group_options.serialize(param_group_options_archive);
74 param_group_archive.write("options", param_group_options_archive);
75 archive.write("param_groups/" + std::to_string(i), param_group_archive);
76 }
77 }
78
79 // Utility function to load param_groups
80 // We take as input vector of pair of string and unique_ptr to optimizer options
81 // so that we can retain the state for each param by using the old tensor impl
82 // keys (saved during serialization) and map the new tensor impl keys to the
83 // correct state for each param
84 template <typename DerivedOptimizerParamOptions>
serialize(serialize::InputArchive & archive,std::vector<std::pair<std::vector<std::string>,std::unique_ptr<OptimizerOptions>>> & param_groups)85 void serialize(
86 serialize::InputArchive& archive,
87 std::vector<
88 std::pair<std::vector<std::string>, std::unique_ptr<OptimizerOptions>>>&
89 param_groups) {
90 torch::Tensor param_groups_size_tensor;
91 archive.read("param_groups/size", param_groups_size_tensor);
92 const int64_t param_groups_size = param_groups_size_tensor.item<int64_t>();
93 for (const auto i : c10::irange(param_groups_size)) {
94 serialize::InputArchive param_group_archive;
95 archive.read("param_groups/" + std::to_string(i), param_group_archive);
96 torch::Tensor size_tensor;
97 param_group_archive.read("params/size", size_tensor);
98 const int64_t size = size_tensor.item<int64_t>();
99 std::vector<std::string> params;
100 for (const auto index : c10::irange(size)) {
101 IValue ivalue;
102 param_group_archive.read("params/" + std::to_string(index), ivalue);
103 std::string element = ivalue.toStringRef();
104 params.emplace_back(element);
105 }
106 serialize::InputArchive param_group_options_archive;
107 param_group_archive.read("options", param_group_options_archive);
108 DerivedOptimizerParamOptions param_group_options(0);
109 param_group_options.serialize(param_group_options_archive);
110 param_groups.emplace_back(std::make_pair(
111 params,
112 std::make_unique<DerivedOptimizerParamOptions>(param_group_options)));
113 }
114 }
115 } // namespace detail
116
117 // Note: These functions are all called `serialize()` so they can be called
118 // inside a template where the archive type is a template type and can thus be
119 // passed such that the appropriate overload is selected.
120
121 /// Utility function to save a value of `int64_t` type.
122 void serialize(
123 serialize::OutputArchive& archive,
124 const std::string& key,
125 const int64_t& value);
126
127 /// Utility function to load a value of `int64_t` type.
128 void serialize(
129 serialize::InputArchive& archive,
130 const std::string& key,
131 int64_t& value);
132
133 /// Utility function to save a vector of step buffers.
134 void serialize(
135 serialize::OutputArchive& archive,
136 const std::string& key,
137 const std::vector<int64_t>& steps);
138
139 /// Utility function to load a vector of step buffers.
140 void serialize(
141 serialize::InputArchive& archive,
142 const std::string& key,
143 std::vector<int64_t>& steps);
144
145 // Utility function to save state and param_groups
146 template <
147 typename DerivedOptimizerParamState,
148 typename DerivedOptimizerParamOptions>
serialize(serialize::OutputArchive & archive,const Optimizer & optimizer)149 void serialize(serialize::OutputArchive& archive, const Optimizer& optimizer) {
150 archive.write("pytorch_version", IValue("1.5.0"));
151 serialize::OutputArchive state_archive(archive.compilation_unit());
152 detail::serialize<DerivedOptimizerParamState>(
153 state_archive, optimizer.state());
154 archive.write("state", state_archive);
155
156 serialize::OutputArchive param_groups_archive(archive.compilation_unit());
157 detail::serialize<DerivedOptimizerParamOptions>(
158 param_groups_archive, optimizer.param_groups());
159 archive.write("param_groups", param_groups_archive);
160 }
161
162 // Utility function to load state and param_groups and update state
163 template <
164 typename DerivedOptimizerParamState,
165 typename DerivedOptimizerParamOptions>
serialize(serialize::InputArchive & archive,Optimizer & optimizer)166 void serialize(serialize::InputArchive& archive, Optimizer& optimizer) {
167 IValue pytorch_version;
168 archive.read("pytorch_version", pytorch_version);
169 TORCH_INTERNAL_ASSERT(pytorch_version.toStringRef() == "1.5.0");
170 serialize::InputArchive state_archive;
171 archive.read("state", state_archive);
172 ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> saved_state;
173 detail::serialize<DerivedOptimizerParamState>(state_archive, saved_state);
174
175 serialize::InputArchive param_groups_archive;
176 archive.read("param_groups", param_groups_archive);
177 std::vector<
178 std::pair<std::vector<std::string>, std::unique_ptr<OptimizerOptions>>>
179 saved_param_groups;
180 detail::serialize<DerivedOptimizerParamOptions>(
181 param_groups_archive, saved_param_groups);
182
183 // update state and optimizer options
184 TORCH_CHECK(
185 saved_param_groups.size() == optimizer.param_groups().size(),
186 "loaded state dict has a different number of parameter groups");
187 for (const auto i : c10::irange(saved_param_groups.size())) {
188 std::vector<std::string> param_group_old_keys = saved_param_groups[i].first;
189 std::vector<Tensor> params = optimizer.param_groups()[i].params();
190 TORCH_CHECK(
191 param_group_old_keys.size() == params.size(),
192 "loaded state dict contains a parameter group that has a different size than the optimizer's parameter group");
193
194 for (const auto idx : c10::irange(params.size())) {
195 auto param_group_old_key =
196 reinterpret_cast<void*>(std::stoull(param_group_old_keys[idx]));
197 if (saved_state.find(param_group_old_key) != saved_state.end()) {
198 optimizer.state()[params[idx].unsafeGetTensorImpl()] =
199 std::move(saved_state[param_group_old_key]);
200 }
201 }
202
203 auto& saved_options = reinterpret_cast<DerivedOptimizerParamOptions&>(
204 *saved_param_groups[i].second);
205 auto& current_options = reinterpret_cast<DerivedOptimizerParamOptions&>(
206 optimizer.param_groups()[i].options());
207 current_options = saved_options;
208 }
209 }
210
211 /// Utility function to save a vector of buffers.
212 template <typename BufferContainer>
serialize(serialize::OutputArchive & archive,const std::string & key,const BufferContainer & buffers)213 void serialize(
214 serialize::OutputArchive& archive,
215 const std::string& key,
216 const BufferContainer& buffers) {
217 archive.write(
218 key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
219 for (const auto index : c10::irange(buffers.size())) {
220 archive.write(
221 key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true);
222 }
223 }
224
225 /// Utility function to load a vector of buffers.
226 template <typename BufferContainer>
serialize(serialize::InputArchive & archive,const std::string & key,BufferContainer & buffers)227 void serialize(
228 serialize::InputArchive& archive,
229 const std::string& key,
230 BufferContainer& buffers) {
231 buffers.clear();
232 torch::Tensor size_tensor;
233 archive.read(key + "/size", size_tensor);
234 const size_t size = size_tensor.item<int64_t>();
235 for (const auto index : c10::irange(size)) {
236 buffers.emplace_back();
237 archive.read(
238 key + "/" + std::to_string(index), buffers.back(), /*is_buffer=*/true);
239 }
240 }
241
242 template <typename T>
deque_to_list(const std::deque<T> & dq)243 c10::List<T> deque_to_list(const std::deque<T>& dq) {
244 c10::List<T> list;
245 list.reserve(dq.size());
246 for (const auto& e : dq) {
247 list.emplace_back(e);
248 }
249 return list;
250 }
251
252 template <typename T>
list_to_deque(const c10::List<T> & list)253 std::deque<T> list_to_deque(const c10::List<T>& list) {
254 std::deque<T> dq;
255 for (const auto& e : list) {
256 dq.emplace_back(e);
257 }
258 return dq;
259 }
260
// Saves member `self.name` under the key "name" via the matching
// torch::optim::serialize overload. Expects `archive` and `self` in scope.
#define _TORCH_OPTIM_SERIALIZE(name) \
  torch::optim::serialize(archive, #name, self.name)

// Dispatches to the optimizer-level serialize() with the state/options types
// derived from the optimizer's name (e.g. Adam -> AdamParamState/AdamOptions).
#define _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(OptimizerName)              \
  torch::optim::serialize<OptimizerName##ParamState, OptimizerName##Options>( \
      archive, self)

// Saves the value of accessor `name()` under key "name", skipping undefined
// tensors (an undefined tensor means "no value to save").
#define _TORCH_OPTIM_SERIALIZE_TORCH_ARG(name)         \
  {                                                    \
    auto ivalue = torch::IValue(name());               \
    /* do not serialize if name is an undefined tensor*/ \
    if (!(ivalue.isTensor() &&                         \
          ivalue.unsafeToTensorImpl() ==               \
              at::UndefinedTensorImpl::singleton())) { \
      archive.write(#name, ivalue);                    \
    }                                                  \
  }

// Saves the deque returned by accessor `name()` as a c10::List IValue.
#define _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(name)           \
  {                                                            \
    c10::IValue ivalue = torch::IValue(deque_to_list(name())); \
    archive.write(#name, ivalue);                              \
  }

// Loads key "name" into setter `name(...)` as type T. A missing key is only
// tolerated for tensor-typed args (undefined tensors are skipped on save);
// for any other type a missing key is an internal error.
#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(T, name)                   \
  {                                                                   \
    c10::IValue ivalue;                                               \
    bool exists = archive.try_read(#name, ivalue);                    \
    if (exists) {                                                     \
      name(ivalue.to<T>());                                           \
    } else {                                                          \
      bool is_tensor_type = std::is_base_of<torch::Tensor, T>::value; \
      TORCH_INTERNAL_ASSERT(is_tensor_type);                          \
    }                                                                 \
  }

// Loads key "name" into setter `name(...)` as c10::optional<T>; a missing
// key simply leaves the current value untouched.
#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(T, name) \
  {                                                          \
    c10::IValue ivalue;                                      \
    bool exists = archive.try_read(#name, ivalue);           \
    if (exists) {                                            \
      name(ivalue.toOptional<T>());                          \
    }                                                        \
  }

// Loads key "name" (saved via the DEQUE save macro) back into setter
// `name(...)` as a std::deque; T is the deque type, T::value_type its element.
#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(T, name) \
  {                                                       \
    c10::IValue ivalue;                                   \
    archive.read(#name, ivalue);                          \
    auto list = ivalue.to<c10::List<T::value_type>>();    \
    name(list_to_deque(list));                            \
  }
313
314 } // namespace optim
315 } // namespace torch
316