#include <torch/optim/adam.h>

#include <torch/csrc/autograd/variable.h>
#include <torch/nn/module.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
#include <c10/util/irange.h>

#include <cmath>
#include <functional>

namespace torch {
namespace optim {

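// Illustrative usage sketch (not part of the original file): how this
// optimizer is typically driven from the C++ frontend. `model`, `input`,
// `target`, and `loss_fn` are placeholders standing in for user code.
//
//   torch::optim::Adam optimizer(
//       model->parameters(),
//       torch::optim::AdamOptions(1e-3)
//           .betas(std::make_tuple(0.9, 0.999))
//           .weight_decay(1e-2));
//   optimizer.zero_grad();
//   auto loss = loss_fn(model->forward(input), target);
//   loss.backward();
//   optimizer.step();
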
AdamOptions::AdamOptions(double lr) : lr_(lr) {}

bool operator==(const AdamOptions& lhs, const AdamOptions& rhs) {
  return (lhs.lr() == rhs.lr()) &&
      (std::get<0>(lhs.betas()) == std::get<0>(rhs.betas())) &&
      (std::get<1>(lhs.betas()) == std::get<1>(rhs.betas())) &&
      (lhs.eps() == rhs.eps()) &&
      (lhs.weight_decay() == rhs.weight_decay() &&
       (lhs.amsgrad() == rhs.amsgrad()));
}

void AdamOptions::serialize(torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(lr);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(betas);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(eps);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(weight_decay);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(amsgrad);
}

void AdamOptions::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lr);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(betas_t, betas);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, eps);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, weight_decay);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, amsgrad);
}

double AdamOptions::get_lr() const {
  return lr();
}

void AdamOptions::set_lr(const double lr) {
  this->lr(lr);
}

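// Note: max_exp_avg_sq is only allocated when amsgrad is enabled, so its
// comparison below goes through equal_if_defined rather than torch::equal.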
bool operator==(const AdamParamState& lhs, const AdamParamState& rhs) {
  return (lhs.step() == rhs.step()) &&
      torch::equal(lhs.exp_avg(), rhs.exp_avg()) &&
      torch::equal(lhs.exp_avg_sq(), rhs.exp_avg_sq()) &&
      torch::equal_if_defined(lhs.max_exp_avg_sq(), rhs.max_exp_avg_sq());
}

void AdamParamState::serialize(torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(step);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg_sq);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(max_exp_avg_sq);
}

void AdamParamState::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, step);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, exp_avg);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, exp_avg_sq);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, max_exp_avg_sq);
}

Tensor Adam::step(LossClosure closure) {
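  // Parameter updates below run with gradient tracking disabled; autograd is
  // temporarily re-enabled only while evaluating the optional loss closure.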
  NoGradGuard no_grad;
  Tensor loss = {};
  if (closure != nullptr) {
    at::AutoGradMode enable_grad(true);
    loss = closure();
  }
  for (auto& group : param_groups_) {
    for (auto& p : group.params()) {
      if (!p.grad().defined()) {
        continue;
      }
      auto grad = p.grad();
      TORCH_CHECK(!grad.is_sparse(), "Adam does not support sparse gradients" /*, please consider SparseAdam instead*/);
      auto param_state = state_.find(p.unsafeGetTensorImpl());
      auto& options = static_cast<AdamOptions&>(group.options());

      // State initialization
      if (param_state == state_.end()) {
        auto state = std::make_unique<AdamParamState>();
        state->step(0);
        // Exponential moving average of gradient values
        state->exp_avg(torch::zeros_like(p, MemoryFormat::Preserve));
        // Exponential moving average of squared gradient values
        state->exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve));
        if (options.amsgrad()) {
          // Maintains max of all exp. moving avg. of sq. grad. values
          state->max_exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve));
        }
        state_[p.unsafeGetTensorImpl()] = std::move(state);
      }

      auto& state =
          static_cast<AdamParamState&>(*state_[p.unsafeGetTensorImpl()]);
      auto& exp_avg = state.exp_avg();
      auto& exp_avg_sq = state.exp_avg_sq();
      auto& max_exp_avg_sq = state.max_exp_avg_sq();

      state.step(state.step() + 1);
      auto beta1 = std::get<0>(options.betas());
      auto beta2 = std::get<1>(options.betas());

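      // Bias corrections: because the moving averages start at zero, they are
      // biased towards zero early in training; dividing by (1 - beta^t)
      // removes that bias, as in the original Adam paper (Kingma & Ba).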
      auto bias_correction1 = 1 - std::pow(beta1, state.step());
      auto bias_correction2 = 1 - std::pow(beta2, state.step());

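      // Weight decay is applied as an L2 penalty folded into the gradient
      // (grad += weight_decay * p), i.e. the coupled form, not the decoupled
      // form used by AdamW.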
      if (options.weight_decay() != 0) {
        grad = grad.add(p, options.weight_decay());
      }

      // Decay the first and second moment running average coefficient
      exp_avg.mul_(beta1).add_(grad, 1 - beta1);
      exp_avg_sq.mul_(beta2).addcmul_(grad, grad, 1 - beta2);

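      // The denominator is sqrt(v_hat) + eps, where v_hat is the bias-corrected
      // second moment (or its running maximum when amsgrad is enabled); the
      // parameter update below is p -= (lr / bias_correction1) * exp_avg / denom.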
      Tensor denom;
      if (options.amsgrad()) {
        // Maintains the maximum of all 2nd moment running avg. till now
        torch::max_out(max_exp_avg_sq, exp_avg_sq, max_exp_avg_sq);
        // Use the max. for normalizing running avg. of gradient
        denom = (max_exp_avg_sq.sqrt() / sqrt(bias_correction2))
                    .add_(options.eps());
      } else {
        denom =
            (exp_avg_sq.sqrt() / sqrt(bias_correction2)).add_(options.eps());
      }

      auto step_size = options.lr() / bias_correction1;
      p.addcdiv_(exp_avg, denom, -step_size);
    }
  }
  return loss;
}

void Adam::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}

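// Archives written by current releases carry a "pytorch_version" field and use
// the common optimizer serialization path; archives written before 1.5.0
// stored flat per-parameter buffers, which are converted to AdamParamState
// entries below.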
void Adam::load(serialize::InputArchive& archive) {
  IValue pytorch_version;
  if (archive.try_read("pytorch_version", pytorch_version)) {
    serialize(*this, archive);
  } else { // deserializing archives saved in old format (prior to
           // version 1.5.0)
    TORCH_WARN(
        "Your serialized Adam optimizer is still using the old serialization format. "
        "You should re-save your Adam optimizer to use the new serialization format.");
    std::vector<int64_t> step_buffers;
    std::vector<at::Tensor> exp_average_buffers;
    std::vector<at::Tensor> exp_average_sq_buffers;
    std::vector<at::Tensor> max_exp_average_sq_buffers;
    torch::optim::serialize(archive, "step_buffers", step_buffers);
    torch::optim::serialize(
        archive, "exp_average_buffers", exp_average_buffers);
    torch::optim::serialize(
        archive, "exp_average_sq_buffers", exp_average_sq_buffers);
    torch::optim::serialize(
        archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
    // since there were no param_groups prior to version 1.5.0, assuming all
    // tensors are now in one param_group
    std::vector<Tensor> params = param_groups_.at(0).params();
    for (const auto idx : c10::irange(step_buffers.size())) {
      auto state = std::make_unique<AdamParamState>();
      state->step(step_buffers.at(idx));
      state->exp_avg(exp_average_buffers.at(idx));
      state->exp_avg_sq(exp_average_sq_buffers.at(idx));
      if (idx < max_exp_average_sq_buffers.size()) {
        state->max_exp_avg_sq(max_exp_average_sq_buffers.at(idx));
      }
      state_[params.at(idx).unsafeGetTensorImpl()] = std::move(state);
    }
  }
}
} // namespace optim
} // namespace torch