#include <torch/optim/optimizer.h>

#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/types.h>

#include <string>
#include <utility>
#include <vector>

namespace torch {
namespace optim {

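// An OptimizerParamGroup bundles a list of parameter tensors with an
// optional, group-specific OptimizerOptions object.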
bool OptimizerParamGroup::has_options() const {
  return options_ != nullptr;
}

OptimizerOptions& OptimizerParamGroup::options() {
  TORCH_CHECK(has_options());
  return *options_;
}

const OptimizerOptions& OptimizerParamGroup::options() const {
  TORCH_CHECK(has_options());
  return *options_;
}

void OptimizerParamGroup::set_options(
    std::unique_ptr<OptimizerOptions> options) {
  options_ = std::move(options);
}

std::vector<Tensor>& OptimizerParamGroup::params() {
  return params_;
}

const std::vector<Tensor>& OptimizerParamGroup::params() const {
  return params_;
}

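// The OptimizerParamState base methods below fail loudly at runtime:
// concrete optimizers are expected to derive from
// OptimizerCloneableParamState<Derived>, which supplies clone(), and to
// override the two serialize() hooks.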
std::unique_ptr<OptimizerParamState> OptimizerParamState::clone() const {
  TORCH_CHECK(
      false,
      "clone() has not been implemented for torch::optim::OptimizerParamState. ",
      "Subclass torch::optim::OptimizerCloneableParamState<YourOptimizerParamState> ",
      "instead of torch::optim::OptimizerParamState to inherit the ability to clone.");
}

void OptimizerParamState::serialize(torch::serialize::InputArchive& archive) {
  TORCH_CHECK(
      false,
      "void serialize(torch::serialize::InputArchive& archive) has not been implemented for torch::optim::OptimizerParamState. ",
      "You must override it in your subclass of torch::optim::OptimizerCloneableParamState<YourOptimizerParamState>.");
}

void OptimizerParamState::serialize(
    torch::serialize::OutputArchive& archive) const {
  TORCH_CHECK(
      false,
      "void serialize(torch::serialize::OutputArchive& archive) has not been implemented for torch::optim::OptimizerParamState. ",
      "You must override it in your subclass of torch::optim::OptimizerCloneableParamState<YourOptimizerParamState>.");
}

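// get_lr()/set_lr() are optional at this level: only options types that
// actually carry a learning rate override them.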
double OptimizerOptions::get_lr() const {
  TORCH_CHECK(
      false,
      "double get_lr() has not been overridden and implemented in subclass of torch::optim::OptimizerOptions, you must override it in your subclass.");
}

void OptimizerOptions::set_lr(const double lr) {
  TORCH_CHECK(
      false,
      "void set_lr() has not been overridden and implemented in subclass of torch::optim::OptimizerOptions, you must override it in your subclass.");
}

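// As with OptimizerParamState above, cloning and serialization for
// OptimizerOptions come from deriving OptimizerCloneableOptions<Derived>
// and overriding serialize(); the base versions only raise errors.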
std::unique_ptr<OptimizerOptions> OptimizerOptions::clone() const {
  TORCH_CHECK(
      false,
      "clone() has not been implemented for torch::optim::OptimizerOptions. ",
      "Subclass torch::optim::OptimizerCloneableOptions<YourOptimizerOptions> ",
      "instead of torch::optim::OptimizerOptions to inherit the ability to clone.");
}

void OptimizerOptions::serialize(torch::serialize::InputArchive& archive) {
  TORCH_CHECK(
      false,
      "void serialize(torch::serialize::InputArchive& archive) has not been implemented for torch::optim::OptimizerOptions. ",
      "You must override it in your subclass of torch::optim::OptimizerCloneableOptions<YourOptimizerOptions>.");
}

void OptimizerOptions::serialize(
    torch::serialize::OutputArchive& archive) const {
  TORCH_CHECK(
      false,
      "void serialize(torch::serialize::OutputArchive& archive) has not been implemented for torch::optim::OptimizerOptions. ",
      "You must override it in your subclass of torch::optim::OptimizerCloneableOptions<YourOptimizerOptions>.");
}

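// Adds a parameter group to the optimizer. Every parameter must be a leaf
// tensor and must not already appear in another group; a group passed in
// without options receives a clone of the optimizer-wide defaults.
//
// A minimal usage sketch (hypothetical caller code: `sgd` is an existing
// torch::optim::SGD and `more_params` a std::vector<torch::Tensor> of leaf
// tensors):
//
//   torch::optim::OptimizerParamGroup group(more_params);
//   group.set_options(std::make_unique<torch::optim::SGDOptions>(0.01));
//   sgd.add_param_group(group);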
void Optimizer::add_param_group(const OptimizerParamGroup& param_group) {
  for (const auto& param : param_group.params()) {
    TORCH_CHECK(param.is_leaf(), "can't optimize a non-leaf Tensor");
  }
  TORCH_INTERNAL_ASSERT(defaults_ != nullptr);
  OptimizerParamGroup param_group_(param_group.params());
  if (!param_group.has_options()) {
    param_group_.set_options(defaults_->clone());
  } else {
    param_group_.set_options(param_group.options().clone());
  }
  for (const auto& p : param_group_.params()) {
    TORCH_CHECK(
        state_.count(p.unsafeGetTensorImpl()) == 0,
        "some parameters appear in more than one parameter group");
  }
  param_groups_.emplace_back(std::move(param_group_));
}

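// Deprecated: appends parameters to the first parameter group rather than
// creating a new one; prefer add_param_group().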
void Optimizer::add_parameters(const std::vector<Tensor>& parameters) {
  TORCH_WARN("Optimizer::add_parameters() will be removed in PyTorch 1.6");
  auto& parameters_ = param_groups_[0].params();
  parameters_.insert(parameters_.end(), parameters.begin(), parameters.end());
}

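// Clears the gradients of all optimized tensors: each defined gradient is
// detached from its autograd history and then either released entirely
// (set_to_none == true) or zeroed in place.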
void Optimizer::zero_grad(bool set_to_none) {
  for (auto& group : param_groups_) {
    for (auto& p : group.params()) {
      if (p.mutable_grad().defined()) {
        p.mutable_grad().detach_();
        if (set_to_none)
          p.mutable_grad().reset();
        else
          p.mutable_grad().zero_();
      }
    }
  }
}

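// Deprecated accessors kept for backward compatibility: parameters() only
// exposes the first parameter group, while size() below counts parameters
// across all groups.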
const std::vector<Tensor>& Optimizer::parameters() const noexcept {
  TORCH_WARN("Optimizer::parameters() will be removed in PyTorch 1.6");
  return param_groups_.at(0).params();
}

std::vector<Tensor>& Optimizer::parameters() noexcept {
  TORCH_WARN("Optimizer::parameters() will be removed in PyTorch 1.6");
  return param_groups_.at(0).params();
}

size_t Optimizer::size() const noexcept {
  TORCH_WARN("Optimizer::size() will be removed in PyTorch 1.6");
  size_t count = 0;
  for (const auto& group : param_groups_) {
    count += group.params().size();
  }
  return count;
}

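// Plain accessors: optimizer-wide default options, the parameter groups,
// and the per-parameter state map, which is keyed by each parameter's
// TensorImpl pointer (see unsafeGetTensorImpl() in add_param_group()).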
OptimizerOptions& Optimizer::defaults() noexcept {
  return *defaults_;
}

const OptimizerOptions& Optimizer::defaults() const noexcept {
  return *defaults_;
}

std::vector<OptimizerParamGroup>& Optimizer::param_groups() noexcept {
  return param_groups_;
}

const std::vector<OptimizerParamGroup>& Optimizer::param_groups()
    const noexcept {
  return param_groups_;
}

ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>& Optimizer::
    state() noexcept {
  return state_;
}

const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
Optimizer::state() const noexcept {
  return state_;
}

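// save()/load() are no-ops here; concrete optimizers override them to
// write and read their options and per-parameter state.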
void Optimizer::save(serialize::OutputArchive& archive) const {}
void Optimizer::load(serialize::InputArchive& archive) {}

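// Round-tripping an optimizer through an archive might look like the
// following sketch ("opt.pt" is an arbitrary example path):
//
//   torch::serialize::OutputArchive out;
//   out << optimizer;
//   out.save_to("opt.pt");
//
//   torch::serialize::InputArchive in;
//   in.load_from("opt.pt");
//   in >> optimizer;
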
/// Serializes an `Optimizer` into an `OutputArchive`.
serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const Optimizer& optimizer) {
  optimizer.save(archive);
  return archive;
}

/// Deserializes an `Optimizer` from an `InputArchive`.
serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    Optimizer& optimizer) {
  optimizer.load(archive);
  return archive;
}

} // namespace optim
} // namespace torch