/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

/**
 * SGD (stochastic gradient descent) optimizer to perform on-device training.
 * This uses the gradients calculated in the backwards pass of the loss
 * function and updates the parameters such that it minimizes the loss.
 *
 * This is similar to the Lite Interpreter implementation of the SGD optimizer,
 * but without the dependency on ATen Tensors and autograd.
 */
#pragma once

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <map>
#include <memory>
#include <unordered_map>
#include <vector>

namespace executorch {
namespace extension {
namespace training {
namespace optimizer {

/**
 * SGD optimizer state. This keeps track of the state of a given parameter to
 * be used in later epochs.
 */
class ET_EXPERIMENTAL SGDParamState {
 public:
  /**
   * Constructs a new SGD param state.
   *
   * @param[in] momentum_buffer A tensor that stores the momentum at the last
   * epoch.
   */
  explicit SGDParamState(executorch::aten::Tensor& momentum_buffer)
      : momentum_buffer_(momentum_buffer) {}

  executorch::aten::Tensor& momentum_buffer() {
    return momentum_buffer_;
  }

 private:
  executorch::aten::Tensor momentum_buffer_;
};

/**
 * SGD optimizer options. This contains options for performing training on a
 * param group, such as the learning rate.
 */
class ET_EXPERIMENTAL SGDOptions {
 public:
  /**
   * Constructs a new set of SGD optimizer options.
   *
   * This is used for customizing the SGD optimizer for a given group of
   * parameters.
   *
   * @param[in] lr The learning rate. This is the factor applied to the
   * gradient calculated from the loss function and used to update the
   * parameters. A lower learning rate will result in a smaller step towards
   * the minimum of a loss function, and a higher learning rate will result in
   * a larger step.
   * @param[in] momentum The momentum value. This is used to accelerate the
   * update step by using the gradients from previous epochs.
   * @param[in] dampening The dampening value. This is used in combination
   * with momentum, and aims to prevent the optimizer from taking steps that
   * are too large when using momentum.
   * @param[in] weight_decay The weight decay value. This is used as a
   * regularization technique that subtracts a small fraction of the weight's
   * value from itself at each step.
   * @param[in] nesterov Whether to use Nesterov momentum. If true, the
   * optimizer uses the momentum of the current step and applies it to the
   * training update. When false, the optimizer uses the momentum of the
   * previous step and applies it to the training update.
   */
  explicit SGDOptions(
      double lr,
      double momentum = 0,
      double dampening = 0,
      double weight_decay = 0,
      bool nesterov = false)
      : lr_(lr),
        momentum_(momentum),
        dampening_(dampening),
        weight_decay_(weight_decay),
        nesterov_(nesterov) {}

  std::unique_ptr<SGDOptions> clone() const {
    return std::make_unique<SGDOptions>(static_cast<const SGDOptions&>(*this));
  }

  double lr() const {
    return lr_;
  }

  double momentum() const {
    return momentum_;
  }

  double dampening() const {
    return dampening_;
  }

  double weight_decay() const {
    return weight_decay_;
  }

  bool nesterov() const {
    return nesterov_;
  }

 private:
  double lr_;
  double momentum_;
  double dampening_;
  double weight_decay_;
  bool nesterov_;
};
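
/*
 * For reference, a minimal sketch of how these options are conventionally
 * combined in an SGD-with-momentum update. This mirrors the standard PyTorch
 * SGD formulation; the actual kernel lives in the corresponding .cpp
 * implementation and may differ in detail. For a parameter `p` with gradient
 * `g` and momentum buffer `b` (on the first step the buffer is typically
 * initialized to `g`):
 *
 *   g = g + weight_decay * p;                // L2 regularization
 *   b = momentum * b + (1 - dampening) * g;  // update the momentum buffer
 *   g = nesterov ? (g + momentum * b) : b;   // effective update direction
 *   p = p - lr * g;                          // take the step
 */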

/**
 * SGD optimizer param group. This contains the parameters and
 * the SGDOptions associated with them.
 */
class ET_EXPERIMENTAL SGDParamGroup {
 public:
  // NOTE: In order to store `SGDParamGroup` in a `std::vector`, it has
  // to be copy-constructible.
  SGDParamGroup(const SGDParamGroup& param_group)
      : named_parameters_(param_group.named_parameters()),
        options_(
            param_group.has_options() ? param_group.options().clone()
                                      : nullptr) {}
  SGDParamGroup& operator=(const SGDParamGroup& param_group) {
    this->named_parameters_ = param_group.named_parameters_;
    this->options_ =
        param_group.has_options() ? param_group.options().clone() : nullptr;
    return *this;
  }

  /**
   * Constructs an SGD param group.
   *
   * @param[in] named_parameters The parameters to be optimized and their
   * fully qualified names.
   */
  /* implicit */ SGDParamGroup(
      const std::map<executorch::aten::string_view, executorch::aten::Tensor>&
          named_parameters)
      : named_parameters_(named_parameters) {}

  SGDParamGroup(
      const std::map<executorch::aten::string_view, executorch::aten::Tensor>&
          named_parameters,
      std::unique_ptr<SGDOptions> options)
      : named_parameters_(named_parameters), options_(std::move(options)) {}

  bool has_options() const;
  SGDOptions& options();
  const SGDOptions& options() const;
  void set_options(std::unique_ptr<SGDOptions> options);
  const std::map<executorch::aten::string_view, executorch::aten::Tensor>&
  named_parameters() const;

 private:
  std::map<executorch::aten::string_view, executorch::aten::Tensor>
      named_parameters_;
  std::unique_ptr<SGDOptions> options_;
};
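
// A minimal construction sketch (illustrative only, not part of this header's
// API): one parameter group gets its own options and is handed to the SGD
// optimizer declared below; `params` is a placeholder for parameter tensors
// owned by the caller.
//
//   std::map<executorch::aten::string_view, executorch::aten::Tensor> params;
//   SGDParamGroup group(
//       params, std::make_unique<SGDOptions>(/*lr=*/0.1, /*momentum=*/0.9));
//   SGD optimizer({group}, /*defaults=*/SGDOptions(0.01));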

/**
 * SGD optimizer class. This is responsible for performing the optimization
 * step.
 */
class ET_EXPERIMENTAL SGD {
 public:
  explicit SGD(
      const std::vector<SGDParamGroup>& param_groups,
      SGDOptions defaults)
      : defaults_(std::make_unique<SGDOptions>(defaults)) {
    for (const auto& param_group : param_groups) {
      add_param_group(param_group);
    }
  }

  explicit SGD(
      const std::map<executorch::aten::string_view, executorch::aten::Tensor>&
          named_parameters,
      SGDOptions defaults)
      : SGD({SGDParamGroup(named_parameters)}, defaults) {}

  // Adds the given param_group to the optimizer's param_group list.
  void add_param_group(const SGDParamGroup& param_group);

  ~SGD();

  /**
   * Performs the optimization step.
   *
   * @param[in] named_gradients The gradients of the tensors specified by the
   * fully qualified name.
   */
  ::executorch::runtime::Error step(
      const std::map<executorch::aten::string_view, executorch::aten::Tensor>&
          named_gradients);

 private:
  std::vector<SGDParamGroup> param_groups_;
  std::unordered_map<void*, std::unique_ptr<SGDParamState>> state_;
  std::unique_ptr<SGDOptions> defaults_;
};

} // namespace optimizer
} // namespace training
} // namespace extension
} // namespace executorch
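
// A minimal step sketch (illustrative only): `grads` is assumed to be
// populated by the caller's backward pass and keyed by the same fully
// qualified names as the parameters passed at construction; `optimizer` is
// the SGD instance from the construction sketch above.
//
//   std::map<executorch::aten::string_view, executorch::aten::Tensor> grads;
//   ::executorch::runtime::Error err = optimizer.step(grads);
//   if (err != ::executorch::runtime::Error::Ok) {
//     // handle the failure
//   }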