// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/optim/sgd.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/optim/sgd.h>
2 
3 #include <torch/csrc/autograd/variable.h>
4 #include <torch/optim/optimizer.h>
5 #include <torch/optim/serialize.h>
6 #include <torch/types.h>
7 #include <torch/utils.h>
8 
9 #include <ATen/ATen.h>
10 #include <c10/util/irange.h>
11 
12 #include <functional>
13 
14 namespace torch {
15 namespace optim {
16 
// Constructs SGD options with the (required) learning rate; the remaining
// hyperparameters (momentum, dampening, weight_decay, nesterov) keep the
// defaults declared in sgd.h.
SGDOptions::SGDOptions(double lr) : lr_(lr) {}
18 
operator ==(const SGDOptions & lhs,const SGDOptions & rhs)19 bool operator==(const SGDOptions& lhs, const SGDOptions& rhs) {
20   return (lhs.lr() == rhs.lr()) && (lhs.momentum() == rhs.momentum()) &&
21       (lhs.dampening() == rhs.dampening()) &&
22       (lhs.weight_decay() == rhs.weight_decay()) &&
23       (lhs.nesterov() == rhs.nesterov());
24 }
25 
// Writes every hyperparameter into `archive` under its own key, via the
// shared optimizer serialization macros (see torch/optim/serialize.h).
void SGDOptions::serialize(torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(lr);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(momentum);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(dampening);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(weight_decay);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(nesterov);
}
33 
// Reads every hyperparameter back from `archive`; keys and types must match
// what the OutputArchive overload above wrote.
void SGDOptions::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lr);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, momentum);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, dampening);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, weight_decay);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, nesterov);
}
41 
// Generic OptimizerOptions accessor: exposes the learning rate so that
// optimizer-agnostic code (e.g. LR schedulers) can read it.
double SGDOptions::get_lr() const {
  return lr();
}
45 
// Generic OptimizerOptions mutator: forwards to the fluent lr(double)
// setter so optimizer-agnostic code can adjust the learning rate.
void SGDOptions::set_lr(const double lr) {
  this->lr(lr);
}
49 
operator ==(const SGDParamState & lhs,const SGDParamState & rhs)50 bool operator==(const SGDParamState& lhs, const SGDParamState& rhs) {
51   return torch::equal(lhs.momentum_buffer(), rhs.momentum_buffer());
52 }
53 
// Persists the per-parameter momentum buffer into `archive`.
void SGDParamState::serialize(torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(momentum_buffer);
}
57 
// Restores the per-parameter momentum buffer from `archive`.
void SGDParamState::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, momentum_buffer);
}
61 
// Performs one SGD update over every parameter group.
//
// Mirrors the Python `torch.optim.SGD` update rule: optional weight decay is
// folded into the gradient, an optional momentum buffer is maintained per
// parameter, and Nesterov momentum adds a look-ahead term.
//
// `closure` (may be nullptr) re-evaluates the model and returns the loss; it
// runs with gradients enabled even though the parameter update itself happens
// under NoGradGuard. Returns the closure's loss, or an undefined Tensor when
// no closure was supplied.
Tensor SGD::step(LossClosure closure) {
  NoGradGuard no_grad;
  Tensor loss = {};
  if (closure != nullptr) {
    at::AutoGradMode enable_grad(true);
    loss = closure();
  }
  for (auto& group : param_groups_) {
    // Hyperparameters are per-group; hoisted out of the parameter loop.
    auto& options = static_cast<SGDOptions&>(group.options());
    auto weight_decay = options.weight_decay();
    auto momentum = options.momentum();
    auto dampening = options.dampening();
    auto nesterov = options.nesterov();

    for (auto& p : group.params()) {
      // Parameters that received no gradient this step are left untouched.
      if (!p.grad().defined()) {
        continue;
      }
      auto d_p = p.grad().data();
      if (weight_decay != 0) {
        // d_p <- d_p + weight_decay * p (L2 regularization).
        d_p = d_p.add(p.data(), weight_decay);
      }
      if (momentum != 0) {
        Tensor buf;
        auto param_state = state_.find(p.unsafeGetTensorImpl());
        if (param_state == state_.end()) {
          // First step for this parameter: seed the momentum buffer with the
          // (weight-decayed) gradient. Dampening is deliberately NOT applied
          // on this first step, matching the Python optimizer's behavior.
          buf = torch::clone(d_p).detach();
          auto state = std::make_unique<SGDParamState>();
          state->momentum_buffer(buf);
          state_[p.unsafeGetTensorImpl()] = std::move(state);
        } else {
          // buf <- momentum * buf + (1 - dampening) * d_p, updated in place
          // so the stored state sees the new value.
          buf = static_cast<SGDParamState&>(*param_state->second)
                    .momentum_buffer();
          buf.mul_(momentum).add_(d_p, 1 - dampening);
        }
        if (nesterov) {
          // Nesterov look-ahead: d_p <- d_p + momentum * buf.
          d_p = d_p.add(buf, momentum);
        } else {
          d_p = buf;
        }
      }
      // p <- p - lr * d_p
      p.data().add_(d_p, -1 * options.lr());
    }
  }
  return loss;
}
108 
// Serializes the optimizer (param groups + per-parameter state) in the
// current format via the shared free-function serializer.
void SGD::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}
112 
// Deserializes optimizer state from `archive`.
//
// Archives written by PyTorch >= 1.5.0 contain a "pytorch_version" entry and
// are handled by the shared serializer. Older archives stored only a flat
// list of momentum buffers; those are migrated by attaching buffer i to the
// i-th parameter of the single param group (pre-1.5.0 optimizers had exactly
// one implicit group).
void SGD::load(serialize::InputArchive& archive) {
  IValue pytorch_version;
  if (archive.try_read("pytorch_version", pytorch_version)) {
    serialize(*this, archive);
  } else { // deserializing archives saved in old format (prior to
           // version 1.5.0)
    TORCH_WARN(
        "Your serialized SGD optimizer is still using the old serialization format. "
        "You should re-save your SGD optimizer to use the new serialization format.");
    std::vector<Tensor> momentum_buffers;
    torch::optim::serialize(archive, "momentum_buffers", momentum_buffers);
    // since there were no param_groups prior to version 1.5.0, assuming all
    // tensors are now in one param_group
    std::vector<Tensor> params = param_groups_.at(0).params();
    for (const auto idx : c10::irange(momentum_buffers.size())) {
      // Rebuild per-parameter state keyed by TensorImpl*, as step() expects.
      auto state = std::make_unique<SGDParamState>();
      state->momentum_buffer(momentum_buffers[idx]);
      state_[params[idx].unsafeGetTensorImpl()] = std::move(state);
    }
  }
}
134 } // namespace optim
135 } // namespace torch
136