// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/optim/serialize.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/irange.h>
4 #include <torch/optim/optimizer.h>
5 #include <torch/serialize/archive.h>
6 #include <torch/types.h>
7 #include <cstddef>
8 #include <cstdint>
9 #include <deque>
10 #include <string>
11 #include <vector>
12 
13 namespace torch {
14 namespace optim {
15 namespace detail {
16 // Utility function to save state
17 template <typename DerivedOptimizerParamState>
serialize(serialize::OutputArchive & archive,const ska::flat_hash_map<void *,std::unique_ptr<OptimizerParamState>> & state)18 void serialize(
19     serialize::OutputArchive& archive,
20     const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
21         state) {
22   for (const auto& item : state) {
23     serialize::OutputArchive param_state_archive(archive.compilation_unit());
24     std::string tensorimpl_key =
25         std::to_string(reinterpret_cast<size_t>(item.first));
26     const DerivedOptimizerParamState& curr_state =
27         static_cast<const DerivedOptimizerParamState&>(*(item.second.get()));
28     curr_state.serialize(param_state_archive);
29     archive.write(tensorimpl_key, param_state_archive);
30   }
31 }
32 
33 // Utility function to load state
34 template <typename DerivedOptimizerParamState>
serialize(serialize::InputArchive & archive,ska::flat_hash_map<void *,std::unique_ptr<OptimizerParamState>> & state)35 void serialize(
36     serialize::InputArchive& archive,
37     ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>& state) {
38   std::vector<std::string> tensorimpl_keys = archive.keys();
39   for (const std::string& tensorimpl_key : tensorimpl_keys) {
40     serialize::InputArchive param_state_archive;
41     archive.read(tensorimpl_key, param_state_archive);
42     DerivedOptimizerParamState param_state;
43     param_state.serialize(param_state_archive);
44     state[reinterpret_cast<void*>(std::stoull(tensorimpl_key))] =
45         std::make_unique<DerivedOptimizerParamState>(param_state);
46   }
47 }
48 
49 // Utility function to save param_groups
50 template <typename DerivedOptimizerParamOptions>
serialize(serialize::OutputArchive & archive,const std::vector<OptimizerParamGroup> & param_groups)51 void serialize(
52     serialize::OutputArchive& archive,
53     const std::vector<OptimizerParamGroup>& param_groups) {
54   archive.write(
55       "param_groups/size",
56       torch::tensor(static_cast<int64_t>(param_groups.size())));
57   for (const auto i : c10::irange(param_groups.size())) {
58     serialize::OutputArchive param_group_archive(archive.compilation_unit());
59     std::vector<Tensor> params = param_groups[i].params();
60     param_group_archive.write(
61         "params/size", torch::tensor(static_cast<int64_t>(params.size())));
62     for (const auto index : c10::irange(params.size())) {
63       param_group_archive.write(
64           "params/" + std::to_string(index),
65           IValue(std::to_string(
66               reinterpret_cast<size_t>(params[index].unsafeGetTensorImpl()))));
67     }
68     const DerivedOptimizerParamOptions& param_group_options =
69         static_cast<const DerivedOptimizerParamOptions&>(
70             param_groups[i].options());
71     serialize::OutputArchive param_group_options_archive(
72         param_group_archive.compilation_unit());
73     param_group_options.serialize(param_group_options_archive);
74     param_group_archive.write("options", param_group_options_archive);
75     archive.write("param_groups/" + std::to_string(i), param_group_archive);
76   }
77 }
78 
79 // Utility function to load param_groups
80 // We take as input vector of pair of string and unique_ptr to optimizer options
81 // so that we can retain the state for each param by using the old tensor impl
82 // keys (saved during serialization) and map the new tensor impl keys to the
83 // correct state for each param
84 template <typename DerivedOptimizerParamOptions>
serialize(serialize::InputArchive & archive,std::vector<std::pair<std::vector<std::string>,std::unique_ptr<OptimizerOptions>>> & param_groups)85 void serialize(
86     serialize::InputArchive& archive,
87     std::vector<
88         std::pair<std::vector<std::string>, std::unique_ptr<OptimizerOptions>>>&
89         param_groups) {
90   torch::Tensor param_groups_size_tensor;
91   archive.read("param_groups/size", param_groups_size_tensor);
92   const int64_t param_groups_size = param_groups_size_tensor.item<int64_t>();
93   for (const auto i : c10::irange(param_groups_size)) {
94     serialize::InputArchive param_group_archive;
95     archive.read("param_groups/" + std::to_string(i), param_group_archive);
96     torch::Tensor size_tensor;
97     param_group_archive.read("params/size", size_tensor);
98     const int64_t size = size_tensor.item<int64_t>();
99     std::vector<std::string> params;
100     for (const auto index : c10::irange(size)) {
101       IValue ivalue;
102       param_group_archive.read("params/" + std::to_string(index), ivalue);
103       std::string element = ivalue.toStringRef();
104       params.emplace_back(element);
105     }
106     serialize::InputArchive param_group_options_archive;
107     param_group_archive.read("options", param_group_options_archive);
108     DerivedOptimizerParamOptions param_group_options(0);
109     param_group_options.serialize(param_group_options_archive);
110     param_groups.emplace_back(std::make_pair(
111         params,
112         std::make_unique<DerivedOptimizerParamOptions>(param_group_options)));
113   }
114 }
115 } // namespace detail
116 
117 // Note: These functions are all called `serialize()` so they can be called
118 // inside a template where the archive type is a template type and can thus be
119 // passed such that the appropriate overload is selected.
120 
121 /// Utility function to save a value of `int64_t` type.
122 void serialize(
123     serialize::OutputArchive& archive,
124     const std::string& key,
125     const int64_t& value);
126 
127 /// Utility function to load a value of `int64_t` type.
128 void serialize(
129     serialize::InputArchive& archive,
130     const std::string& key,
131     int64_t& value);
132 
133 /// Utility function to save a vector of step buffers.
134 void serialize(
135     serialize::OutputArchive& archive,
136     const std::string& key,
137     const std::vector<int64_t>& steps);
138 
139 /// Utility function to load a vector of step buffers.
140 void serialize(
141     serialize::InputArchive& archive,
142     const std::string& key,
143     std::vector<int64_t>& steps);
144 
145 // Utility function to save state and param_groups
146 template <
147     typename DerivedOptimizerParamState,
148     typename DerivedOptimizerParamOptions>
serialize(serialize::OutputArchive & archive,const Optimizer & optimizer)149 void serialize(serialize::OutputArchive& archive, const Optimizer& optimizer) {
150   archive.write("pytorch_version", IValue("1.5.0"));
151   serialize::OutputArchive state_archive(archive.compilation_unit());
152   detail::serialize<DerivedOptimizerParamState>(
153       state_archive, optimizer.state());
154   archive.write("state", state_archive);
155 
156   serialize::OutputArchive param_groups_archive(archive.compilation_unit());
157   detail::serialize<DerivedOptimizerParamOptions>(
158       param_groups_archive, optimizer.param_groups());
159   archive.write("param_groups", param_groups_archive);
160 }
161 
162 // Utility function to load state and param_groups and update state
163 template <
164     typename DerivedOptimizerParamState,
165     typename DerivedOptimizerParamOptions>
serialize(serialize::InputArchive & archive,Optimizer & optimizer)166 void serialize(serialize::InputArchive& archive, Optimizer& optimizer) {
167   IValue pytorch_version;
168   archive.read("pytorch_version", pytorch_version);
169   TORCH_INTERNAL_ASSERT(pytorch_version.toStringRef() == "1.5.0");
170   serialize::InputArchive state_archive;
171   archive.read("state", state_archive);
172   ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> saved_state;
173   detail::serialize<DerivedOptimizerParamState>(state_archive, saved_state);
174 
175   serialize::InputArchive param_groups_archive;
176   archive.read("param_groups", param_groups_archive);
177   std::vector<
178       std::pair<std::vector<std::string>, std::unique_ptr<OptimizerOptions>>>
179       saved_param_groups;
180   detail::serialize<DerivedOptimizerParamOptions>(
181       param_groups_archive, saved_param_groups);
182 
183   // update state and optimizer options
184   TORCH_CHECK(
185       saved_param_groups.size() == optimizer.param_groups().size(),
186       "loaded state dict has a different number of parameter groups");
187   for (const auto i : c10::irange(saved_param_groups.size())) {
188     std::vector<std::string> param_group_old_keys = saved_param_groups[i].first;
189     std::vector<Tensor> params = optimizer.param_groups()[i].params();
190     TORCH_CHECK(
191         param_group_old_keys.size() == params.size(),
192         "loaded state dict contains a parameter group that has a different size than the optimizer's parameter group");
193 
194     for (const auto idx : c10::irange(params.size())) {
195       auto param_group_old_key =
196           reinterpret_cast<void*>(std::stoull(param_group_old_keys[idx]));
197       if (saved_state.find(param_group_old_key) != saved_state.end()) {
198         optimizer.state()[params[idx].unsafeGetTensorImpl()] =
199             std::move(saved_state[param_group_old_key]);
200       }
201     }
202 
203     auto& saved_options = reinterpret_cast<DerivedOptimizerParamOptions&>(
204         *saved_param_groups[i].second);
205     auto& current_options = reinterpret_cast<DerivedOptimizerParamOptions&>(
206         optimizer.param_groups()[i].options());
207     current_options = saved_options;
208   }
209 }
210 
211 /// Utility function to save a vector of buffers.
212 template <typename BufferContainer>
serialize(serialize::OutputArchive & archive,const std::string & key,const BufferContainer & buffers)213 void serialize(
214     serialize::OutputArchive& archive,
215     const std::string& key,
216     const BufferContainer& buffers) {
217   archive.write(
218       key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
219   for (const auto index : c10::irange(buffers.size())) {
220     archive.write(
221         key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true);
222   }
223 }
224 
225 /// Utility function to load a vector of buffers.
226 template <typename BufferContainer>
serialize(serialize::InputArchive & archive,const std::string & key,BufferContainer & buffers)227 void serialize(
228     serialize::InputArchive& archive,
229     const std::string& key,
230     BufferContainer& buffers) {
231   buffers.clear();
232   torch::Tensor size_tensor;
233   archive.read(key + "/size", size_tensor);
234   const size_t size = size_tensor.item<int64_t>();
235   for (const auto index : c10::irange(size)) {
236     buffers.emplace_back();
237     archive.read(
238         key + "/" + std::to_string(index), buffers.back(), /*is_buffer=*/true);
239   }
240 }
241 
242 template <typename T>
deque_to_list(const std::deque<T> & dq)243 c10::List<T> deque_to_list(const std::deque<T>& dq) {
244   c10::List<T> list;
245   list.reserve(dq.size());
246   for (const auto& e : dq) {
247     list.emplace_back(e);
248   }
249   return list;
250 }
251 
252 template <typename T>
list_to_deque(const c10::List<T> & list)253 std::deque<T> list_to_deque(const c10::List<T>& list) {
254   std::deque<T> dq;
255   for (const auto& e : list) {
256     dq.emplace_back(e);
257   }
258   return dq;
259 }
260 
// Serialize the member `self.name` under the archive key "name" via one of
// the torch::optim::serialize overloads above. Expects `archive` and `self`
// in the expansion scope.
#define _TORCH_OPTIM_SERIALIZE(name) \
  torch::optim::serialize(archive, #name, self.name)

// Invoke the optimizer-level serialize() with the state/options types derived
// from the optimizer's name (e.g. Adam -> AdamParamState, AdamOptions).
#define _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(OptimizerName)               \
  torch::optim::serialize<OptimizerName##ParamState, OptimizerName##Options>( \
      archive, self)

// Write the result of the accessor `name()` under key "name", skipping the
// write entirely when the value is an undefined tensor (so absent tensors do
// not appear in the archive).
#define _TORCH_OPTIM_SERIALIZE_TORCH_ARG(name)           \
  {                                                      \
    auto ivalue = torch::IValue(name());                 \
    /* do not serialize if name is an undefined tensor*/ \
    if (!(ivalue.isTensor() &&                           \
          ivalue.unsafeToTensorImpl() ==                 \
              at::UndefinedTensorImpl::singleton())) {   \
      archive.write(#name, ivalue);                      \
    }                                                    \
  }

// Write a std::deque-valued accessor `name()` under key "name" by converting
// it to a c10::List first (see deque_to_list above).
#define _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(name)           \
  {                                                            \
    c10::IValue ivalue = torch::IValue(deque_to_list(name())); \
    archive.write(#name, ivalue);                              \
  }

// Read key "name" into the setter `name(...)` as type T. A missing key is
// tolerated only for tensor-typed fields (which may legitimately have been
// skipped at save time by _TORCH_OPTIM_SERIALIZE_TORCH_ARG); for any other
// type a missing key trips the internal assert.
#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(T, name)                   \
  {                                                                   \
    c10::IValue ivalue;                                               \
    bool exists = archive.try_read(#name, ivalue);                    \
    if (exists) {                                                     \
      name(ivalue.to<T>());                                           \
    } else {                                                          \
      bool is_tensor_type = std::is_base_of<torch::Tensor, T>::value; \
      TORCH_INTERNAL_ASSERT(is_tensor_type);                          \
    }                                                                 \
  }

// Read key "name" into the setter `name(...)` as optional<T>; a missing key
// is silently ignored (the field keeps its current value).
#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(T, name) \
  {                                                          \
    c10::IValue ivalue;                                      \
    bool exists = archive.try_read(#name, ivalue);           \
    if (exists) {                                            \
      name(ivalue.toOptional<T>());                          \
    }                                                        \
  }

// Read key "name" as a c10::List of T's element type and hand it to the
// setter `name(...)` as a std::deque (see list_to_deque above). The key must
// exist; archive.read is the non-optional variant.
#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(T, name) \
  {                                                       \
    c10::IValue ivalue;                                   \
    archive.read(#name, ivalue);                          \
    auto list = ivalue.to<c10::List<T::value_type>>();    \
    name(list_to_deque(list));                            \
  }
313 
314 } // namespace optim
315 } // namespace torch
316