#pragma once

#include <ATen/Tensor.h>
#include <c10/util/Exception.h>
#include <c10/util/flat_hash_map.h>

#include <torch/arg.h>
#include <torch/csrc/Export.h>

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

// Forward declarations confuse Doxygen
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace at {
class Tensor;
} // namespace at

namespace torch {
using at::Tensor;
namespace serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch
#endif // DOXYGEN_SHOULD_SKIP_THIS

namespace torch {
namespace optim {

/// Base class for the per-parameter state an optimizer keeps (e.g. momentum
/// buffers). Concrete optimizers derive their own state type, typically via
/// `OptimizerCloneableParamState` below so that `clone()` is generated for
/// them. Copy/move special members are explicitly defaulted (Rule of Five)
/// because the class declares a virtual destructor.
class TORCH_API OptimizerParamState {
 public:
  OptimizerParamState() = default;
  OptimizerParamState(const OptimizerParamState&) = default;
  OptimizerParamState& operator=(const OptimizerParamState&) = default;
  OptimizerParamState(OptimizerParamState&&) noexcept = default;
  OptimizerParamState& operator=(OptimizerParamState&&) noexcept = default;
  /// Returns a polymorphic copy of this state (implementation not visible in
  /// this header; overridden by `OptimizerCloneableParamState`).
  virtual std::unique_ptr<OptimizerParamState> clone() const;
  /// Reads this state from / writes it to the given archive.
  virtual void serialize(torch::serialize::InputArchive& archive);
  virtual void serialize(torch::serialize::OutputArchive& archive) const;
  virtual ~OptimizerParamState() = default;
};

/// CRTP helper: deriving `MyState : OptimizerCloneableParamState<MyState>`
/// supplies a `clone()` that copy-constructs the most-derived type.
template <typename Derived>
class OptimizerCloneableParamState : public OptimizerParamState {
  std::unique_ptr<OptimizerParamState> clone() const override {
    return std::make_unique<Derived>(static_cast<const Derived&>(*this));
  }
};

/// Base class for optimizer hyperparameters (learning rate etc.). Concrete
/// optimizers derive their own options type, typically via
/// `OptimizerCloneableOptions` below. Copy/move special members are
/// explicitly defaulted (Rule of Five) because the class declares a virtual
/// destructor.
class TORCH_API OptimizerOptions {
 public:
  OptimizerOptions() = default;
  OptimizerOptions(const OptimizerOptions&) = default;
  OptimizerOptions& operator=(const OptimizerOptions&) = default;
  OptimizerOptions(OptimizerOptions&&) noexcept = default;
  OptimizerOptions& operator=(OptimizerOptions&&) noexcept = default;
  /// Returns a polymorphic copy of these options (implementation not visible
  /// in this header; overridden by `OptimizerCloneableOptions`).
  virtual std::unique_ptr<OptimizerOptions> clone() const;
  /// Reads these options from / writes them to the given archive.
  virtual void serialize(torch::serialize::InputArchive& archive);
  virtual void serialize(torch::serialize::OutputArchive& archive) const;
  virtual ~OptimizerOptions() = default;
  /// Gets/sets the learning rate (virtual so each optimizer's options type
  /// can expose its own `lr` field through the base interface).
  virtual double get_lr() const;
  virtual void set_lr(const double lr);
};

/// CRTP helper: deriving `MyOptions : OptimizerCloneableOptions<MyOptions>`
/// supplies a `clone()` that copy-constructs the most-derived type.
template <typename Derived>
class OptimizerCloneableOptions : public OptimizerOptions {
 private:
  std::unique_ptr<OptimizerOptions> clone() const override {
    return std::make_unique<Derived>(static_cast<const Derived&>(*this));
  }
};

/// Stores parameters in the param_group and stores a pointer to the
/// OptimizerOptions
class TORCH_API OptimizerParamGroup {
 public:
  // NOTE: In order to store `OptimizerParamGroup` in a `std::vector`, it has
  // to be copy-constructible.
  // Deep-copies the options (via `clone()`) when present; copy assignment is
  // deleted below, so copies only happen through this constructor.
  OptimizerParamGroup(const OptimizerParamGroup& param_group)
      : params_(param_group.params()),
        options_(
            param_group.has_options() ? param_group.options().clone()
                                      : nullptr) {}
  /// Constructs a group with the given parameters and no options
  /// (intentionally non-explicit so a `std::vector<Tensor>` converts to a
  /// group at the `Optimizer` constructor call sites below).
  OptimizerParamGroup(std::vector<Tensor> params)
      : params_(std::move(params)) {}
  /// Constructs a group with the given parameters and group-specific options.
  OptimizerParamGroup(
      std::vector<Tensor> params,
      std::unique_ptr<OptimizerOptions> options)
      : params_(std::move(params)), options_(std::move(options)) {}

  OptimizerParamGroup& operator=(const OptimizerParamGroup& param_group) =
      delete;
  /// True if group-specific options have been set (options_ is non-null).
  bool has_options() const;
  OptimizerOptions& options();
  const OptimizerOptions& options() const;
  void set_options(std::unique_ptr<OptimizerOptions> options);
  std::vector<Tensor>& params();
  const std::vector<Tensor>& params() const;

 protected:
  std::vector<Tensor> params_;
  std::unique_ptr<OptimizerOptions> options_; // null => use optimizer defaults
};

/// Base class for all optimizers in the C++ API. Owns one or more
/// `OptimizerParamGroup`s, the default `OptimizerOptions`, and a map of
/// per-parameter state keyed by the parameter tensor.
class TORCH_API Optimizer {
 public:
  // The copy constructor is deleted, because the user should use the
  // `state_dict` / `load_state_dict` API to copy an optimizer instead.
  Optimizer(const Optimizer& optimizer) = delete;
  Optimizer(Optimizer&& optimizer) = default;

  /// Constructs the `Optimizer` from a vector of param_groups; each group is
  /// registered through `add_param_group` so group-level invariants are
  /// applied uniformly.
  explicit Optimizer(
      std::vector<OptimizerParamGroup> param_groups,
      std::unique_ptr<OptimizerOptions> defaults)
      : defaults_(std::move(defaults)) {
    for (const auto& param_group : param_groups) {
      add_param_group(param_group);
    }
  }

  /// Constructs the `Optimizer` from a vector of parameters.
  explicit Optimizer(
      std::vector<Tensor> parameters,
      std::unique_ptr<OptimizerOptions> defaults)
      : Optimizer(
            {OptimizerParamGroup(std::move(parameters))},
            std::move(defaults)){};

  /// Adds the given param_group to the optimizer's param_group list.
  void add_param_group(const OptimizerParamGroup& param_group);

  virtual ~Optimizer() = default;

  using LossClosure = std::function<Tensor()>;
  /// A loss function closure, which is expected to return the loss value.
  /// Performs a single optimization step; the closure may be null.
  virtual Tensor step(LossClosure closure = nullptr) = 0;

  /// Adds the given vector of parameters to the optimizer's parameter list.
  void add_parameters(const std::vector<Tensor>& parameters);

  /// Zeros out the gradients of all parameters.
  void zero_grad(bool set_to_none = true);

  /// Provides a const reference to the parameters in the first param_group
  /// this optimizer holds.
  const std::vector<Tensor>& parameters() const noexcept;

  /// Provides a reference to the parameters in the first param_group this
  /// optimizer holds.
  std::vector<Tensor>& parameters() noexcept;

  /// Returns the number of parameters referenced by the optimizer.
  size_t size() const noexcept;

  /// Provides a reference to the default options this optimizer holds.
  OptimizerOptions& defaults() noexcept;

  /// Provides a const reference to the default options this optimizer holds.
  const OptimizerOptions& defaults() const noexcept;

  /// Provides a reference to the param_groups this optimizer holds.
  std::vector<OptimizerParamGroup>& param_groups() noexcept;

  /// Provides a const reference to the param_groups this optimizer holds.
  const std::vector<OptimizerParamGroup>& param_groups() const noexcept;

  /// Provides a reference to the state this optimizer holds
  /// (keyed by parameter; `void*` keys — presumably the parameter tensor's
  /// storage/impl pointer, derived in the .cpp — TODO confirm).
  ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
  state() noexcept;

  /// Provides a const reference to the state this optimizer holds
  const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>& state()
      const noexcept;

  /// Serializes the optimizer state into the given `archive`.
  virtual void save(serialize::OutputArchive& archive) const;

  /// Deserializes the optimizer state from the given `archive`.
  virtual void load(serialize::InputArchive& archive);

 protected:
  std::vector<OptimizerParamGroup> param_groups_;
  ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> state_;
  std::unique_ptr<OptimizerOptions> defaults_;
};

/* How do we decide whether to serialize undefined tensors or
   std::nullopt values into the output archive?
   Answer: we strictly follow the behavior of Python API. To be more specific:

   For optimizer options:
   a) For undefined tensor: currently no tensor is used as an options argument
   in Python API, so we don't need to worry about it now.
   b) For std::nullopt value: we serialize std::nullopt values into the output
   archive, to follow the exact same behavior as Python API.

   For optimizer param state:
   a) For undefined tensor: in param state, undefined tensor in C++ impl is
   equivalent to missing key in Python impl. Since we don't serialize missing
   keys in Python API, we skip undefined tensors when serializing the param
   state.
   b) For std::nullopt value: in param state, std::nullopt value in C++ impl is
   equivalent to missing key in Python impl. Since we don't serialize missing
   keys in Python API, we skip std::nullopt values when serializing the param
   state. */

/// Serializes an `Optimizer` into an `OutputArchive`.
TORCH_API serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const Optimizer& optimizer);

/// Deserializes an `Optimizer` from an `InputArchive`.
TORCH_API serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    Optimizer& optimizer);

} // namespace optim
} // namespace torch