#include <torch/optim/adamw.h>

#include <torch/csrc/autograd/variable.h>
#include <torch/nn/module.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
#include <c10/util/irange.h>

#include <cmath>
#include <functional>

namespace torch {
namespace optim {

AdamWOptions::AdamWOptions(double lr) : lr_(lr) {}

bool operator==(const AdamWOptions& lhs, const AdamWOptions& rhs) {
  return (lhs.lr() == rhs.lr()) &&
      (std::get<0>(lhs.betas()) == std::get<0>(rhs.betas())) &&
      (std::get<1>(lhs.betas()) == std::get<1>(rhs.betas())) &&
      (lhs.eps() == rhs.eps()) && (lhs.weight_decay() == rhs.weight_decay()) &&
      (lhs.amsgrad() == rhs.amsgrad());
}

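// The _TORCH_OPTIM_SERIALIZE_TORCH_ARG / _TORCH_OPTIM_DESERIALIZE_TORCH_ARG
// macros below write/read each option field under an archive key that matches
// the field's name.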
void AdamWOptions::serialize(torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(lr);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(betas);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(eps);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(weight_decay);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(amsgrad);
}

void AdamWOptions::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lr);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(betas_t, betas);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, eps);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, weight_decay);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, amsgrad);
}

double AdamWOptions::get_lr() const {
  return lr();
}

void AdamWOptions::set_lr(const double lr) {
  this->lr(lr);
}

bool operator==(const AdamWParamState& lhs, const AdamWParamState& rhs) {
  return (lhs.step() == rhs.step()) &&
      torch::equal(lhs.exp_avg(), rhs.exp_avg()) &&
      torch::equal(lhs.exp_avg_sq(), rhs.exp_avg_sq()) &&
      torch::equal_if_defined(lhs.max_exp_avg_sq(), rhs.max_exp_avg_sq());
}

void AdamWParamState::serialize(
    torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(step);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg_sq);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(max_exp_avg_sq);
}

void AdamWParamState::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, step);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, exp_avg);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, exp_avg_sq);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, max_exp_avg_sq);
}

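// A minimal usage sketch (illustrative only; `model`, `input`, `target`, and
// `loss_fn` are assumed to exist elsewhere):
//
//   torch::optim::AdamW optimizer(
//       model->parameters(),
//       torch::optim::AdamWOptions(1e-3).weight_decay(1e-2));
//   optimizer.zero_grad();
//   auto loss = loss_fn(model->forward(input), target);
//   loss.backward();
//   optimizer.step();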
Tensor AdamW::step(LossClosure closure) {
  NoGradGuard no_grad;
  Tensor loss = {};
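  // If the caller provides a closure, temporarily re-enable autograd so the
  // closure can recompute the loss before the parameter update.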
  if (closure != nullptr) {
    at::AutoGradMode enable_grad(true);
    loss = closure();
  }
  for (auto& group : param_groups_) {
    for (auto& p : group.params()) {
      if (!p.grad().defined()) {
        continue;
      }
      const auto& grad = p.grad();
      TORCH_CHECK(
          !grad.is_sparse(),
          "AdamW does not support sparse gradients" /*, please consider SparseAdamW instead*/);
      auto param_state = state_.find(p.unsafeGetTensorImpl());
      auto& options = static_cast<AdamWOptions&>(group.options());

      // Perform stepweight decay
      if (options.weight_decay() != 0) {
        p.mul_(1 - options.lr() * options.weight_decay());
      }
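      // Note: this is AdamW's decoupled weight decay. The parameter is scaled
      // by (1 - lr * weight_decay) directly, instead of adding
      // weight_decay * p to the gradient as classic L2 regularization would.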

      // State initialization
      if (param_state == state_.end()) {
        auto state = std::make_unique<AdamWParamState>();
        state->step(0);
        // Exponential moving average of gradient values
        state->exp_avg(torch::zeros_like(p, MemoryFormat::Preserve));
        // Exponential moving average of squared gradient values
        state->exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve));
        if (options.amsgrad()) {
          // Maintains max of all exp. moving avg. of sq. grad. values
          state->max_exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve));
        }
        state_[p.unsafeGetTensorImpl()] = std::move(state);
      }

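      // Fetch the (possibly just-created) per-parameter state; state_ is
      // keyed by the parameter's underlying TensorImpl pointer.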
      auto& state =
          static_cast<AdamWParamState&>(*state_[p.unsafeGetTensorImpl()]);
      auto& exp_avg = state.exp_avg();
      auto& exp_avg_sq = state.exp_avg_sq();
      auto& max_exp_avg_sq = state.max_exp_avg_sq();

      state.step(state.step() + 1);
      auto beta1 = std::get<0>(options.betas());
      auto beta2 = std::get<1>(options.betas());

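      // Bias corrections compensate for the zero initialization of the moment
      // estimates: 1 - beta^t approaches 1 as the step count t grows.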
      auto bias_correction1 = 1 - std::pow(beta1, state.step());
      auto bias_correction2 = 1 - std::pow(beta2, state.step());

      // Decay the first and second moment running average coefficient
      exp_avg.mul_(beta1).add_(grad, 1 - beta1);
      exp_avg_sq.mul_(beta2).addcmul_(grad, grad, 1 - beta2);
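      // i.e. exp_avg    = beta1 * exp_avg    + (1 - beta1) * grad
      //      exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad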

      Tensor denom;
      if (options.amsgrad()) {
        // Maintains the maximum of all 2nd moment running avg. till now
        torch::max_out(max_exp_avg_sq, exp_avg_sq, max_exp_avg_sq);
        // Use the max. for normalizing running avg. of gradient
        denom = (max_exp_avg_sq.sqrt() / sqrt(bias_correction2))
                    .add_(options.eps());
      } else {
        denom =
            (exp_avg_sq.sqrt() / sqrt(bias_correction2)).add_(options.eps());
      }

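      // Parameter update: p -= (lr / bias_correction1) * exp_avg / denom,
      // where denom is sqrt(exp_avg_sq / bias_correction2) + eps (or its
      // running maximum when amsgrad is enabled).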
      auto step_size = options.lr() / bias_correction1;
      p.addcdiv_(exp_avg, denom, -step_size);
    }
  }
  return loss;
}

void AdamW::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}

void AdamW::load(serialize::InputArchive& archive) {
  IValue pytorch_version;
  if (archive.try_read("pytorch_version", pytorch_version)) {
    serialize(*this, archive);
  } else { // deserializing archives saved in old format (prior to
           // version 1.5.0)
    TORCH_WARN(
        "Your serialized AdamW optimizer is still using the old serialization format. "
        "You should re-save your AdamW optimizer to use the new serialization format.");
    std::vector<int64_t> step_buffers;
    std::vector<at::Tensor> exp_average_buffers;
    std::vector<at::Tensor> exp_average_sq_buffers;
    std::vector<at::Tensor> max_exp_average_sq_buffers;
    torch::optim::serialize(archive, "step_buffers", step_buffers);
    torch::optim::serialize(
        archive, "exp_average_buffers", exp_average_buffers);
    torch::optim::serialize(
        archive, "exp_average_sq_buffers", exp_average_sq_buffers);
    torch::optim::serialize(
        archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
    // since there were no param_groups prior to version 1.5.0, assuming all
    // tensors are now in one param_group
    std::vector<Tensor> params = param_groups_.at(0).params();
    for (const auto idx : c10::irange(step_buffers.size())) {
      auto state = std::make_unique<AdamWParamState>();
      state->step(step_buffers.at(idx));
      state->exp_avg(exp_average_buffers.at(idx));
      state->exp_avg_sq(exp_average_sq_buffers.at(idx));
      if (idx < max_exp_average_sq_buffers.size()) {
        state->max_exp_avg_sq(max_exp_average_sq_buffers.at(idx));
      }
      state_[params.at(idx).unsafeGetTensorImpl()] = std::move(state);
    }
  }
}
} // namespace optim
} // namespace torch