1 #pragma once 2 3 #include <torch/arg.h> 4 #include <torch/types.h> 5 6 #include <utility> 7 #include <vector> 8 9 namespace torch::jit::mobile { 10 11 class SGDParamState { 12 TORCH_ARG(torch::Tensor, momentum_buffer); 13 14 public: clone()15 std::unique_ptr<SGDParamState> clone() const { 16 return std::make_unique<SGDParamState>( 17 static_cast<const SGDParamState&>(*this)); 18 } 19 friend bool operator==(const SGDParamState& lhs, const SGDParamState& rhs); 20 }; 21 22 struct TORCH_API SGDOptions { 23 /* implicit */ SGDOptions(double lr); 24 TORCH_ARG(double, lr); 25 TORCH_ARG(double, momentum) = 0; 26 TORCH_ARG(double, dampening) = 0; 27 TORCH_ARG(double, weight_decay) = 0; 28 TORCH_ARG(bool, nesterov) = false; 29 30 public: cloneSGDOptions31 std::unique_ptr<SGDOptions> clone() const { 32 return std::make_unique<SGDOptions>(static_cast<const SGDOptions&>(*this)); 33 } 34 TORCH_API friend bool operator==( 35 const SGDOptions& lhs, 36 const SGDOptions& rhs); 37 }; 38 39 /// Stores parameters in the param_group and stores a pointer to the SGDOptions 40 class TORCH_API SGDParamGroup { 41 public: 42 // NOTE: In order to store `SGDParamGroup` in a `std::vector`, it has to be 43 // copy-constructible. SGDParamGroup(const SGDParamGroup & param_group)44 SGDParamGroup(const SGDParamGroup& param_group) 45 : params_(param_group.params()), 46 options_( 47 param_group.has_options() ? param_group.options().clone() 48 : nullptr) {} 49 SGDParamGroup& operator=(const SGDParamGroup& param_group) { 50 this->params_ = param_group.params(); 51 this->options_ = 52 param_group.has_options() ? param_group.options().clone() : nullptr; 53 return *this; 54 } SGDParamGroup(std::vector<Tensor> params)55 /* implicit */ SGDParamGroup(std::vector<Tensor> params) 56 : params_(std::move(params)) {} SGDParamGroup(std::vector<Tensor> params,std::unique_ptr<SGDOptions> options)57 SGDParamGroup(std::vector<Tensor> params, std::unique_ptr<SGDOptions> options) 58 : params_(std::move(params)), options_(std::move(options)) {} 59 60 bool has_options() const; 61 SGDOptions& options(); 62 const SGDOptions& options() const; 63 void set_options(std::unique_ptr<SGDOptions> options); 64 std::vector<Tensor>& params(); 65 const std::vector<Tensor>& params() const; 66 67 protected: 68 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 69 std::vector<Tensor> params_; 70 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 71 std::unique_ptr<SGDOptions> options_; 72 }; 73 74 class TORCH_API SGD { 75 public: SGD(const std::vector<torch::jit::mobile::SGDParamGroup> & param_groups,SGDOptions defaults)76 explicit SGD( 77 const std::vector<torch::jit::mobile::SGDParamGroup>& param_groups, 78 SGDOptions defaults) 79 : defaults_(std::make_unique<SGDOptions>(defaults)) { 80 for (const auto& param_group : param_groups) { 81 add_param_group(param_group); 82 } 83 TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); 84 TORCH_CHECK( 85 defaults.momentum() >= 0, 86 "Invalid momentum value: ", 87 defaults.momentum()); 88 TORCH_CHECK( 89 defaults.weight_decay() >= 0, 90 "Invalid weight_decay value: ", 91 defaults.weight_decay()); 92 TORCH_CHECK( 93 !defaults.nesterov() || 94 (defaults.momentum() > 0 && defaults.dampening() == 0), 95 "Nesterov momentum requires a momentum and zero dampening"); 96 } 97 SGD(std::vector<Tensor> params,SGDOptions defaults)98 explicit SGD(std::vector<Tensor> params, SGDOptions defaults) 99 : SGD({SGDParamGroup(std::move(params))}, defaults) {} 100 101 /// Adds the given param_group to the optimizer's param_group list. 102 void add_param_group(const SGDParamGroup& param_group); 103 104 ~SGD() = default; 105 106 using LossClosure = std::function<Tensor()>; 107 /// A loss function closure, which is expected to return the loss value. 108 torch::Tensor step(const LossClosure& closure = nullptr); 109 110 /// Zeros out the gradients of all parameters. 111 void zero_grad(); 112 113 protected: 114 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 115 std::vector<SGDParamGroup> param_groups_; 116 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 117 ska::flat_hash_map<void*, std::unique_ptr<SGDParamState>> state_; 118 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 119 std::unique_ptr<SGDOptions> defaults_; 120 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 121 std::vector<Tensor> params_; 122 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 123 std::unique_ptr<SGDOptions> options_; 124 }; 125 } // namespace torch::jit::mobile 126