#include <torch/optim/adam.h>

#include <torch/csrc/autograd/variable.h>
#include <torch/nn/module.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
#include <c10/util/irange.h>

#include <cmath>
#include <functional>

namespace torch {
namespace optim {

AdamOptions::AdamOptions(double lr) : lr_(lr) {}

bool operator==(const AdamOptions& lhs, const AdamOptions& rhs) {
  return (lhs.lr() == rhs.lr()) &&
      (std::get<0>(lhs.betas()) == std::get<0>(rhs.betas())) &&
      (std::get<1>(lhs.betas()) == std::get<1>(rhs.betas())) &&
      (lhs.eps() == rhs.eps()) &&
      (lhs.weight_decay() == rhs.weight_decay()) &&
      (lhs.amsgrad() == rhs.amsgrad());
}

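// The _TORCH_OPTIM_SERIALIZE_TORCH_ARG / _TORCH_OPTIM_DESERIALIZE_TORCH_ARG
// macros (see torch/optim/serialize.h) read or write one field under its own
// key in the archive, so the two serialize() overloads below mirror each
// other field by field.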
void AdamOptions::serialize(torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(lr);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(betas);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(eps);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(weight_decay);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(amsgrad);
}

void AdamOptions::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lr);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(betas_t, betas);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, eps);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, weight_decay);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, amsgrad);
}

double AdamOptions::get_lr() const {
  return lr();
}

void AdamOptions::set_lr(const double lr) {
  this->lr(lr);
}

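// AdamParamState holds the per-parameter optimizer state: the step count,
// the first-moment estimate (exp_avg), the second-moment estimate
// (exp_avg_sq), and, when amsgrad is enabled, the running maximum of the
// second moment (max_exp_avg_sq).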
bool operator==(const AdamParamState& lhs, const AdamParamState& rhs) {
  return (lhs.step() == rhs.step()) &&
      torch::equal(lhs.exp_avg(), rhs.exp_avg()) &&
      torch::equal(lhs.exp_avg_sq(), rhs.exp_avg_sq()) &&
      torch::equal_if_defined(lhs.max_exp_avg_sq(), rhs.max_exp_avg_sq());
}

void AdamParamState::serialize(torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(step);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg_sq);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(max_exp_avg_sq);
}

void AdamParamState::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, step);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, exp_avg);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, exp_avg_sq);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, max_exp_avg_sq);
}

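// step() applies one Adam update to every parameter with a defined gradient:
//   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
//   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
//   m_hat = m_t / (1 - beta1^t),  v_hat = v_t / (1 - beta2^t)
//   p <- p - lr * m_hat / (sqrt(v_hat) + eps)
// with v_t replaced by max(v_t, v_max) when amsgrad is enabled.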
Tensor Adam::step(LossClosure closure) {
  NoGradGuard no_grad;
  Tensor loss = {};
  if (closure != nullptr) {
    at::AutoGradMode enable_grad(true);
    loss = closure();
  }
  for (auto& group : param_groups_) {
    for (auto& p : group.params()) {
      if (!p.grad().defined()) {
        continue;
      }
      auto grad = p.grad();
      TORCH_CHECK(
          !grad.is_sparse(),
          "Adam does not support sparse gradients" /*, please consider SparseAdam instead*/);
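      // Per-parameter state lives in state_, keyed by the parameter's
      // TensorImpl pointer, and is created lazily on the first update.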
      auto param_state = state_.find(p.unsafeGetTensorImpl());
      auto& options = static_cast<AdamOptions&>(group.options());

      // State initialization
      if (param_state == state_.end()) {
        auto state = std::make_unique<AdamParamState>();
        state->step(0);
        // Exponential moving average of gradient values
        state->exp_avg(torch::zeros_like(p, MemoryFormat::Preserve));
        // Exponential moving average of squared gradient values
        state->exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve));
        if (options.amsgrad()) {
          // Maintains max of all exp. moving avg. of sq. grad. values
          state->max_exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve));
        }
        state_[p.unsafeGetTensorImpl()] = std::move(state);
      }

      auto& state =
          static_cast<AdamParamState&>(*state_[p.unsafeGetTensorImpl()]);
      auto& exp_avg = state.exp_avg();
      auto& exp_avg_sq = state.exp_avg_sq();
      auto& max_exp_avg_sq = state.max_exp_avg_sq();

      state.step(state.step() + 1);
      auto beta1 = std::get<0>(options.betas());
      auto beta2 = std::get<1>(options.betas());

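      // Bias corrections compensate for the zero-initialization of the
      // moment estimates; 1 - beta^t approaches 1 as the step count grows.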
      auto bias_correction1 = 1 - std::pow(beta1, state.step());
      auto bias_correction2 = 1 - std::pow(beta2, state.step());

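      // Classic (non-decoupled) L2 weight decay: fold weight_decay * p into
      // the gradient before the moment updates (AdamW would instead decay
      // the parameter directly).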
      if (options.weight_decay() != 0) {
        grad = grad.add(p, options.weight_decay());
      }

      // Decay the first and second moment running average coefficient
      exp_avg.mul_(beta1).add_(grad, 1 - beta1);
      exp_avg_sq.mul_(beta2).addcmul_(grad, grad, 1 - beta2);

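      // denom = sqrt(v_hat) + eps; dividing by sqrt(bias_correction2) here
      // avoids rescaling exp_avg_sq itself.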
      Tensor denom;
      if (options.amsgrad()) {
        // Maintains the maximum of all 2nd moment running avg. till now
        torch::max_out(max_exp_avg_sq, exp_avg_sq, max_exp_avg_sq);
        // Use the max. for normalizing running avg. of gradient
        denom = (max_exp_avg_sq.sqrt() / sqrt(bias_correction2))
                    .add_(options.eps());
      } else {
        denom =
            (exp_avg_sq.sqrt() / sqrt(bias_correction2)).add_(options.eps());
      }

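      // Fused parameter update: p -= (lr / bias_correction1) * exp_avg / denom.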
      auto step_size = options.lr() / bias_correction1;
      p.addcdiv_(exp_avg, denom, -step_size);
    }
  }
  return loss;
}
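
// A minimal usage sketch; `model` and `loss` below are assumed to exist in
// the caller's code and are illustrative only:
//
//   torch::optim::Adam optimizer(
//       model->parameters(),
//       torch::optim::AdamOptions(1e-3).betas(std::make_tuple(0.9, 0.999)));
//   optimizer.zero_grad();
//   loss.backward();
//   optimizer.step();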

void Adam::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}

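// load() accepts both the current archive format written by save() (which
// contains a "pytorch_version" entry) and the flat per-buffer layout used
// before PyTorch 1.5.0, which is converted into per-parameter state below.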
void Adam::load(serialize::InputArchive& archive) {
  IValue pytorch_version;
  if (archive.try_read("pytorch_version", pytorch_version)) {
    serialize(*this, archive);
  } else { // deserializing archives saved in old format (prior to
           // version 1.5.0)
    TORCH_WARN(
        "Your serialized Adam optimizer is still using the old serialization format. "
        "You should re-save your Adam optimizer to use the new serialization format.");
    std::vector<int64_t> step_buffers;
    std::vector<at::Tensor> exp_average_buffers;
    std::vector<at::Tensor> exp_average_sq_buffers;
    std::vector<at::Tensor> max_exp_average_sq_buffers;
    torch::optim::serialize(archive, "step_buffers", step_buffers);
    torch::optim::serialize(
        archive, "exp_average_buffers", exp_average_buffers);
    torch::optim::serialize(
        archive, "exp_average_sq_buffers", exp_average_sq_buffers);
    torch::optim::serialize(
        archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
    // since there were no param_groups prior to version 1.5.0, assuming all
    // tensors are now in one param_group
    std::vector<Tensor> params = param_groups_.at(0).params();
    for (const auto idx : c10::irange(step_buffers.size())) {
      auto state = std::make_unique<AdamParamState>();
      state->step(step_buffers.at(idx));
      state->exp_avg(exp_average_buffers.at(idx));
      state->exp_avg_sq(exp_average_sq_buffers.at(idx));
      if (idx < max_exp_average_sq_buffers.size()) {
        state->max_exp_avg_sq(max_exp_average_sq_buffers.at(idx));
      }
      state_[params.at(idx).unsafeGetTensorImpl()] = std::move(state);
    }
  }
}
} // namespace optim
} // namespace torch