#include <torch/optim/optimizer.h>

#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/types.h>

#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
9
10 namespace torch {
11 namespace optim {
12
has_options() const13 bool OptimizerParamGroup::has_options() const {
14 return options_ != nullptr;
15 }
16
options()17 OptimizerOptions& OptimizerParamGroup::options() {
18 TORCH_CHECK(has_options());
19 return *options_.get();
20 }
21
options() const22 const OptimizerOptions& OptimizerParamGroup::options() const {
23 TORCH_CHECK(has_options());
24 return *options_.get();
25 }
26
set_options(std::unique_ptr<OptimizerOptions> options)27 void OptimizerParamGroup::set_options(
28 std::unique_ptr<OptimizerOptions> options) {
29 options_ = std::move(options);
30 }
31
params()32 std::vector<Tensor>& OptimizerParamGroup::params() {
33 return params_;
34 }
35
params() const36 const std::vector<Tensor>& OptimizerParamGroup::params() const {
37 return params_;
38 }
39
clone() const40 std::unique_ptr<OptimizerParamState> OptimizerParamState::clone() const {
41 TORCH_CHECK(
42 false,
43 "clone() has not been implemented for torch::optim::OptimizerParamState. ",
44 "Subclass torch::optim::OptimizerCloneableParamState<YourOptimizerParamState> ",
45 "instead of torch::optim::OptimizerParamState to inherit the ability to clone.");
46 }
47
serialize(torch::serialize::InputArchive & archive)48 void OptimizerParamState::serialize(torch::serialize::InputArchive& archive) {
49 TORCH_CHECK(
50 false,
51 "void serialize(torch::serialize::InputArchive& archive) has not been implemented for torch::optim::OptimizerParamState. ",
52 "You must override it in your subclass of torch::optim::OptimizerCloneableParamState<YourOptimizerParamState>.");
53 }
54
serialize(torch::serialize::OutputArchive & archive) const55 void OptimizerParamState::serialize(
56 torch::serialize::OutputArchive& archive) const {
57 TORCH_CHECK(
58 false,
59 "void serialize(torch::serialize::OutputArchive& archive) has not been implemented for torch::optim::OptimizerParamState. ",
60 "You must override it in your subclass of torch::optim::OptimizerCloneableParamState<YourOptimizerParamState>.");
61 }
62
get_lr() const63 double OptimizerOptions::get_lr() const {
64 TORCH_CHECK(
65 false,
66 "double get_lr() has not been overridden and implemented in subclass of torch::optim::OptimizerOptions, you must override it in your subclass.");
67 }
68
set_lr(const double lr)69 void OptimizerOptions::set_lr(const double lr) {
70 TORCH_CHECK(
71 false,
72 "double set_lr() has not been overridden and implemented in subclass of torch::optim::OptimizerOptions, you must override it in your subclass.");
73 }
74
clone() const75 std::unique_ptr<OptimizerOptions> OptimizerOptions::clone() const {
76 TORCH_CHECK(
77 false,
78 "clone() has not been implemented for torch::optim::OptimizerOptions. ",
79 "Subclass torch::optim::OptimizerCloneableOptions<YourOptimizerOptions> ",
80 "instead of torch::optim::OptimizerOptions to inherit the ability to clone.");
81 }
82
serialize(torch::serialize::InputArchive & archive)83 void OptimizerOptions::serialize(torch::serialize::InputArchive& archive) {
84 TORCH_CHECK(
85 false,
86 "void serialize(torch::serialize::InputArchive& archive) has not been implemented for torch::optim::OptimizerOptions. ",
87 "You must override it in your subclass of torch::optim::OptimizerCloneableOptions<YourOptimizerOptions>.");
88 }
89
serialize(torch::serialize::OutputArchive & archive) const90 void OptimizerOptions::serialize(
91 torch::serialize::OutputArchive& archive) const {
92 TORCH_CHECK(
93 false,
94 "void serialize(torch::serialize::OutputArchive& archive) has not been implemented for torch::optim::OptimizerOptions. ",
95 "You must override it in your subclass of torch::optim::OptimizerCloneableOptions<YourOptimizerOptions>.");
96 }
97
add_param_group(const OptimizerParamGroup & param_group)98 void Optimizer::add_param_group(const OptimizerParamGroup& param_group) {
99 for (const auto& param : param_group.params()) {
100 TORCH_CHECK(param.is_leaf(), "can't optimize a non-leaf Tensor");
101 }
102 TORCH_INTERNAL_ASSERT(defaults_ != nullptr);
103 OptimizerParamGroup param_group_(param_group.params());
104 if (!param_group.has_options()) {
105 param_group_.set_options(defaults_->clone());
106 } else {
107 param_group_.set_options(param_group.options().clone());
108 }
109 for (const auto& p : param_group_.params()) {
110 TORCH_CHECK(
111 state_.count(p.unsafeGetTensorImpl()) == 0,
112 "some parameters appear in more than one parameter group");
113 }
114 param_groups_.emplace_back(std::move(param_group_));
115 }
116
add_parameters(const std::vector<Tensor> & parameters)117 void Optimizer::add_parameters(const std::vector<Tensor>& parameters) {
118 TORCH_WARN("Optimizer::add_parameters() will be removed in PyTorch 1.6");
119 auto& parameters_ = param_groups_[0].params();
120 parameters_.insert(parameters_.end(), parameters.begin(), parameters.end());
121 }
122
zero_grad(bool set_to_none)123 void Optimizer::zero_grad(bool set_to_none) {
124 for (auto& group : param_groups_) {
125 for (auto& p : group.params()) {
126 if (p.mutable_grad().defined()) {
127 p.mutable_grad().detach_();
128 if (set_to_none)
129 p.mutable_grad().reset();
130 else
131 p.mutable_grad().zero_();
132 }
133 }
134 }
135 }
136
parameters() const137 const std::vector<Tensor>& Optimizer::parameters() const noexcept {
138 TORCH_WARN("Optimizer::parameters() will be removed in PyTorch 1.6");
139 return param_groups_.at(0).params();
140 }
141
parameters()142 std::vector<Tensor>& Optimizer::parameters() noexcept {
143 TORCH_WARN("Optimizer::parameters() will be removed in PyTorch 1.6");
144 return param_groups_.at(0).params();
145 }
146
size() const147 size_t Optimizer::size() const noexcept {
148 TORCH_WARN("Optimizer::size() will be removed in PyTorch 1.6");
149 size_t count = 0;
150 for (const auto& group : param_groups_) {
151 count += group.params().size();
152 }
153 return count;
154 }
155
defaults()156 OptimizerOptions& Optimizer::defaults() noexcept {
157 return *defaults_.get();
158 }
159
defaults() const160 const OptimizerOptions& Optimizer::defaults() const noexcept {
161 return *defaults_.get();
162 }
163
param_groups()164 std::vector<OptimizerParamGroup>& Optimizer::param_groups() noexcept {
165 return param_groups_;
166 }
167
param_groups() const168 const std::vector<OptimizerParamGroup>& Optimizer::param_groups()
169 const noexcept {
170 return param_groups_;
171 }
172
173 ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>& Optimizer::
state()174 state() noexcept {
175 return state_;
176 }
177
178 const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
state() const179 Optimizer::state() const noexcept {
180 return state_;
181 }
182
save(serialize::OutputArchive & archive) const183 void Optimizer::save(serialize::OutputArchive& archive) const {}
load(serialize::InputArchive & archive)184 void Optimizer::load(serialize::InputArchive& archive) {}
185
186 /// Serializes an `Optimizer` into an `OutputArchive`.
operator <<(serialize::OutputArchive & archive,const Optimizer & optimizer)187 serialize::OutputArchive& operator<<(
188 serialize::OutputArchive& archive,
189 const Optimizer& optimizer) {
190 optimizer.save(archive);
191 return archive;
192 }
193
194 /// Deserializes a `Tensor` from an `InputArchive`.
operator >>(serialize::InputArchive & archive,Optimizer & optimizer)195 serialize::InputArchive& operator>>(
196 serialize::InputArchive& archive,
197 Optimizer& optimizer) {
198 optimizer.load(archive);
199 return archive;
200 }
201
202 } // namespace optim
203 } // namespace torch
204