xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/optim/optimizer.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/Tensor.h>
4 #include <c10/util/Exception.h>
5 #include <c10/util/flat_hash_map.h>
6 
7 #include <torch/arg.h>
8 #include <torch/csrc/Export.h>
9 
#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
16 
17 // Forward declarations confuse Doxygen
18 #ifndef DOXYGEN_SHOULD_SKIP_THIS
19 namespace at {
20 class Tensor;
21 } // namespace at
22 
23 namespace torch {
24 using at::Tensor;
25 namespace serialize {
26 class OutputArchive;
27 class InputArchive;
28 } // namespace serialize
29 } // namespace torch
30 #endif // DOXYGEN_SHOULD_SKIP_THIS
31 
32 namespace torch {
33 namespace optim {
34 
/// Base class for per-parameter optimizer state (e.g. momentum buffers,
/// step counters). Concrete optimizers subclass this — typically via
/// `OptimizerCloneableParamState` — to add their own fields. Definitions of
/// the virtual functions live in the corresponding .cpp file.
class TORCH_API OptimizerParamState {
 public:
  OptimizerParamState() = default;
  OptimizerParamState(const OptimizerParamState&) = default;
  OptimizerParamState& operator=(const OptimizerParamState&) = default;
  OptimizerParamState(OptimizerParamState&&) noexcept = default;
  OptimizerParamState& operator=(OptimizerParamState&&) noexcept = default;
  /// Returns a polymorphic deep copy of this state; overridden by subclasses.
  virtual std::unique_ptr<OptimizerParamState> clone() const;
  /// Reads this state from `archive`.
  virtual void serialize(torch::serialize::InputArchive& archive);
  /// Writes this state to `archive`.
  virtual void serialize(torch::serialize::OutputArchive& archive) const;
  virtual ~OptimizerParamState() = default;
};
47 
48 template <typename Derived>
49 class OptimizerCloneableParamState : public OptimizerParamState {
clone()50   std::unique_ptr<OptimizerParamState> clone() const override {
51     return std::make_unique<Derived>(static_cast<const Derived&>(*this));
52   }
53 };
54 
/// Base class for optimizer hyperparameters (learning rate, weight decay,
/// etc.). Concrete optimizers subclass this — typically via
/// `OptimizerCloneableOptions` — to add their own fields. Definitions of the
/// virtual functions live in the corresponding .cpp file.
class TORCH_API OptimizerOptions {
 public:
  OptimizerOptions() = default;
  OptimizerOptions(const OptimizerOptions&) = default;
  OptimizerOptions& operator=(const OptimizerOptions&) = default;
  OptimizerOptions(OptimizerOptions&&) noexcept = default;
  OptimizerOptions& operator=(OptimizerOptions&&) noexcept = default;
  /// Returns a polymorphic deep copy of these options; overridden by
  /// subclasses.
  virtual std::unique_ptr<OptimizerOptions> clone() const;
  /// Reads the options from `archive`.
  virtual void serialize(torch::serialize::InputArchive& archive);
  /// Writes the options to `archive`.
  virtual void serialize(torch::serialize::OutputArchive& archive) const;
  virtual ~OptimizerOptions() = default;
  /// Returns the learning rate. NOTE(review): the base implementation
  /// presumably errors for optimizers that have no notion of an lr — confirm
  /// against the definition in optimizer.cpp.
  virtual double get_lr() const;
  /// Sets the learning rate; same caveat as `get_lr()` for the base class.
  virtual void set_lr(const double lr);
};
69 
70 template <typename Derived>
71 class OptimizerCloneableOptions : public OptimizerOptions {
72  private:
clone()73   std::unique_ptr<OptimizerOptions> clone() const override {
74     return std::make_unique<Derived>(static_cast<const Derived&>(*this));
75   }
76 };
77 
78 /// Stores parameters in the param_group and stores a pointer to the
79 /// OptimizerOptions
80 class TORCH_API OptimizerParamGroup {
81  public:
82   // NOTE: In order to store `OptimizerParamGroup` in a `std::vector`, it has to
83   // be copy-constructible.
OptimizerParamGroup(const OptimizerParamGroup & param_group)84   OptimizerParamGroup(const OptimizerParamGroup& param_group)
85       : params_(param_group.params()),
86         options_(
87             param_group.has_options() ? param_group.options().clone()
88                                       : nullptr) {}
OptimizerParamGroup(std::vector<Tensor> params)89   OptimizerParamGroup(std::vector<Tensor> params)
90       : params_(std::move(params)) {}
OptimizerParamGroup(std::vector<Tensor> params,std::unique_ptr<OptimizerOptions> options)91   OptimizerParamGroup(
92       std::vector<Tensor> params,
93       std::unique_ptr<OptimizerOptions> options)
94       : params_(std::move(params)), options_(std::move(options)) {}
95 
96   OptimizerParamGroup& operator=(const OptimizerParamGroup& param_group) =
97       delete;
98   bool has_options() const;
99   OptimizerOptions& options();
100   const OptimizerOptions& options() const;
101   void set_options(std::unique_ptr<OptimizerOptions> options);
102   std::vector<Tensor>& params();
103   const std::vector<Tensor>& params() const;
104 
105  protected:
106   std::vector<Tensor> params_;
107   std::unique_ptr<OptimizerOptions> options_;
108 };
109 
110 class TORCH_API Optimizer {
111  public:
112   // The copy constructor is deleted, because the user should use the
113   // `state_dict` / `load_state_dict` API to copy an optimizer instead.
114   Optimizer(const Optimizer& optimizer) = delete;
115   Optimizer(Optimizer&& optimizer) = default;
116 
Optimizer(std::vector<OptimizerParamGroup> param_groups,std::unique_ptr<OptimizerOptions> defaults)117   explicit Optimizer(
118       std::vector<OptimizerParamGroup> param_groups,
119       std::unique_ptr<OptimizerOptions> defaults)
120       : defaults_(std::move(defaults)) {
121     for (const auto& param_group : param_groups) {
122       add_param_group(param_group);
123     }
124   }
125 
126   /// Constructs the `Optimizer` from a vector of parameters.
Optimizer(std::vector<Tensor> parameters,std::unique_ptr<OptimizerOptions> defaults)127   explicit Optimizer(
128       std::vector<Tensor> parameters,
129       std::unique_ptr<OptimizerOptions> defaults)
130       : Optimizer(
131             {OptimizerParamGroup(std::move(parameters))},
132             std::move(defaults)){};
133 
134   /// Adds the given param_group to the optimizer's param_group list.
135   void add_param_group(const OptimizerParamGroup& param_group);
136 
137   virtual ~Optimizer() = default;
138 
139   using LossClosure = std::function<Tensor()>;
140   /// A loss function closure, which is expected to return the loss value.
141   virtual Tensor step(LossClosure closure = nullptr) = 0;
142 
143   /// Adds the given vector of parameters to the optimizer's parameter list.
144   void add_parameters(const std::vector<Tensor>& parameters);
145 
146   /// Zeros out the gradients of all parameters.
147   void zero_grad(bool set_to_none = true);
148 
149   /// Provides a const reference to the parameters in the first param_group this
150   /// optimizer holds.
151   const std::vector<Tensor>& parameters() const noexcept;
152 
153   /// Provides a reference to the parameters in the first param_group this
154   /// optimizer holds.
155   std::vector<Tensor>& parameters() noexcept;
156 
157   /// Returns the number of parameters referenced by the optimizer.
158   size_t size() const noexcept;
159 
160   OptimizerOptions& defaults() noexcept;
161 
162   const OptimizerOptions& defaults() const noexcept;
163 
164   /// Provides a reference to the param_groups this optimizer holds.
165   std::vector<OptimizerParamGroup>& param_groups() noexcept;
166 
167   /// Provides a const reference to the param_groups this optimizer holds.
168   const std::vector<OptimizerParamGroup>& param_groups() const noexcept;
169 
170   /// Provides a reference to the state this optimizer holds
171   ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
172   state() noexcept;
173 
174   /// Provides a const reference to the state this optimizer holds
175   const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>& state()
176       const noexcept;
177 
178   /// Serializes the optimizer state into the given `archive`.
179   virtual void save(serialize::OutputArchive& archive) const;
180 
181   /// Deserializes the optimizer state from the given `archive`.
182   virtual void load(serialize::InputArchive& archive);
183 
184  protected:
185   std::vector<OptimizerParamGroup> param_groups_;
186   ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> state_;
187   std::unique_ptr<OptimizerOptions> defaults_;
188 };
189 
/* How do we decide whether to serialize undefined tensors or std::nullopt
values into the output archive?
Answer: we strictly follow the behavior of the Python API. To be more specific:

For optimizer options:
a) For an undefined tensor: currently no tensor is used as an options argument
in the Python API, so we don't need to worry about it now.
b) For a std::nullopt value: we serialize std::nullopt values into the output
archive, to follow the exact same behavior as the Python API.

For optimizer param state:
a) For an undefined tensor: in param state, an undefined tensor in the C++
impl is equivalent to a missing key in the Python impl. Since we don't
serialize missing keys in the Python API, we skip undefined tensors when
serializing the param state.
b) For a std::nullopt value: in param state, a std::nullopt value in the C++
impl is equivalent to a missing key in the Python impl. Since we don't
serialize missing keys in the Python API, we skip std::nullopt values when
serializing the param state. */
207 
/// Serializes an `Optimizer` into an `OutputArchive`.
TORCH_API serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const Optimizer& optimizer);

/// Deserializes an `Optimizer` from an `InputArchive`.
/// (The comment previously said `Tensor` — a copy-paste error; the operand is
/// an `Optimizer`.)
TORCH_API serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    Optimizer& optimizer);
217 
218 } // namespace optim
219 } // namespace torch
220