// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/optim/adamw.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/optim/adamw.h>

#include <torch/csrc/autograd/variable.h>
#include <torch/nn/module.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
#include <c10/util/irange.h>

#include <cmath>
#include <functional>

namespace torch {
namespace optim {

AdamWOptions::AdamWOptions(double lr) : lr_(lr) {}

bool operator==(const AdamWOptions& lhs, const AdamWOptions& rhs) {
  return (lhs.lr() == rhs.lr()) &&
      (std::get<0>(lhs.betas()) == std::get<0>(rhs.betas())) &&
      (std::get<1>(lhs.betas()) == std::get<1>(rhs.betas())) &&
      (lhs.eps() == rhs.eps()) && (lhs.weight_decay() == rhs.weight_decay()) &&
      (lhs.amsgrad() == rhs.amsgrad());
}

void AdamWOptions::serialize(torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(lr);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(betas);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(eps);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(weight_decay);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(amsgrad);
}

void AdamWOptions::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lr);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(betas_t, betas);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, eps);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, weight_decay);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, amsgrad);
}

double AdamWOptions::get_lr() const {
  return lr();
}

void AdamWOptions::set_lr(const double lr) {
  this->lr(lr);
}

bool operator==(const AdamWParamState& lhs, const AdamWParamState& rhs) {
  return (lhs.step() == rhs.step()) &&
      torch::equal(lhs.exp_avg(), rhs.exp_avg()) &&
      torch::equal(lhs.exp_avg_sq(), rhs.exp_avg_sq()) &&
      torch::equal_if_defined(lhs.max_exp_avg_sq(), rhs.max_exp_avg_sq());
}

void AdamWParamState::serialize(
    torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(step);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg_sq);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(max_exp_avg_sq);
}

void AdamWParamState::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, step);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, exp_avg);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, exp_avg_sq);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, max_exp_avg_sq);
}

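// Performs a single AdamW optimization step over all parameter groups. If a
// loss closure is given, it is re-evaluated with gradients enabled and the
// resulting loss is returned; otherwise an undefined Tensor is returned.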
Tensor AdamW::step(LossClosure closure) {
  NoGradGuard no_grad;
  Tensor loss = {};
  if (closure != nullptr) {
    at::AutoGradMode enable_grad(true);
    loss = closure();
  }
  for (auto& group : param_groups_) {
    for (auto& p : group.params()) {
      if (!p.grad().defined()) {
        continue;
      }
      const auto& grad = p.grad();
      TORCH_CHECK(!grad.is_sparse(), "AdamW does not support sparse gradients" /*, please consider SparseAdamW instead*/);
      auto param_state = state_.find(p.unsafeGetTensorImpl());
      auto& options = static_cast<AdamWOptions&>(group.options());

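      // Decoupled weight decay (the defining difference from plain Adam with
      // L2 regularization): the parameter is shrunk directly,
      //   p <- p * (1 - lr * weight_decay),
      // rather than folding weight_decay * p into the gradient.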
      // Perform stepweight decay
      if (options.weight_decay() != 0) {
        p.mul_(1 - options.lr() * options.weight_decay());
      }

      // State initialization
      if (param_state == state_.end()) {
        auto state = std::make_unique<AdamWParamState>();
        state->step(0);
        // Exponential moving average of gradient values
        state->exp_avg(torch::zeros_like(p, MemoryFormat::Preserve));
        // Exponential moving average of squared gradient values
        state->exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve));
        if (options.amsgrad()) {
          // Maintains max of all exp. moving avg. of sq. grad. values
          state->max_exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve));
        }
        state_[p.unsafeGetTensorImpl()] = std::move(state);
      }

      auto& state =
          static_cast<AdamWParamState&>(*state_[p.unsafeGetTensorImpl()]);
      auto& exp_avg = state.exp_avg();
      auto& exp_avg_sq = state.exp_avg_sq();
      auto& max_exp_avg_sq = state.max_exp_avg_sq();

      state.step(state.step() + 1);
      auto beta1 = std::get<0>(options.betas());
      auto beta2 = std::get<1>(options.betas());

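      // The moment estimates start at zero, so they are biased toward zero
      // for small step counts; dividing by (1 - beta^t) below corrects this.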
      auto bias_correction1 = 1 - std::pow(beta1, state.step());
      auto bias_correction2 = 1 - std::pow(beta2, state.step());

      // Decay the first and second moment running average coefficient
      exp_avg.mul_(beta1).add_(grad, 1 - beta1);
      exp_avg_sq.mul_(beta2).addcmul_(grad, grad, 1 - beta2);

      Tensor denom;
      if (options.amsgrad()) {
        // Maintains the maximum of all 2nd moment running avg. till now
        torch::max_out(max_exp_avg_sq, exp_avg_sq, max_exp_avg_sq);
        // Use the max. for normalizing running avg. of gradient
        denom = (max_exp_avg_sq.sqrt() / sqrt(bias_correction2))
                    .add_(options.eps());
      } else {
        denom =
            (exp_avg_sq.sqrt() / sqrt(bias_correction2)).add_(options.eps());
      }

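      // Final parameter update, written as a fused addcdiv with a negative
      // step size: p <- p - (lr / bias_correction1) * exp_avg / denom.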
      auto step_size = options.lr() / bias_correction1;
      p.addcdiv_(exp_avg, denom, -step_size);
    }
  }
  return loss;
}

void AdamW::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}

void AdamW::load(serialize::InputArchive& archive) {
  IValue pytorch_version;
  if (archive.try_read("pytorch_version", pytorch_version)) {
    serialize(*this, archive);
  } else { // deserializing archives saved in old format (prior to
           // version 1.5.0)
    TORCH_WARN(
        "Your serialized AdamW optimizer is still using the old serialization format. "
        "You should re-save your AdamW optimizer to use the new serialization format.");
    std::vector<int64_t> step_buffers;
    std::vector<at::Tensor> exp_average_buffers;
    std::vector<at::Tensor> exp_average_sq_buffers;
    std::vector<at::Tensor> max_exp_average_sq_buffers;
    torch::optim::serialize(archive, "step_buffers", step_buffers);
    torch::optim::serialize(
        archive, "exp_average_buffers", exp_average_buffers);
    torch::optim::serialize(
        archive, "exp_average_sq_buffers", exp_average_sq_buffers);
    torch::optim::serialize(
        archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
    // since there were no param_groups prior to version 1.5.0, assuming all
    // tensors are now in one param_group
    std::vector<Tensor> params = param_groups_.at(0).params();
    for (const auto idx : c10::irange(step_buffers.size())) {
      auto state = std::make_unique<AdamWParamState>();
      state->step(step_buffers.at(idx));
      state->exp_avg(exp_average_buffers.at(idx));
      state->exp_avg_sq(exp_average_sq_buffers.at(idx));
      if (idx < max_exp_average_sq_buffers.size()) {
        state->max_exp_avg_sq(max_exp_average_sq_buffers.at(idx));
      }
      state_[params.at(idx).unsafeGetTensorImpl()] = std::move(state);
    }
  }
}
} // namespace optim
} // namespace torch
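
// ---------------------------------------------------------------------------
// Usage sketch (not part of this translation unit): a minimal, illustrative
// example of driving AdamW through the C++ frontend. The model, data, and
// hyperparameters below are assumptions chosen for the example, not values
// taken from this file.
//
//   #include <torch/torch.h>
//
//   int main() {
//     auto model = torch::nn::Linear(10, 1);
//     torch::optim::AdamW optimizer(
//         model->parameters(),
//         torch::optim::AdamWOptions(1e-3).weight_decay(1e-2));
//
//     auto input = torch::randn({8, 10});
//     auto target = torch::randn({8, 1});
//     for (int epoch = 0; epoch < 100; ++epoch) {
//       optimizer.zero_grad();
//       auto loss = torch::mse_loss(model->forward(input), target);
//       loss.backward();
//       optimizer.step();
//     }
//   }
// ---------------------------------------------------------------------------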