/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

/**
 * SGD (stochastic gradient descent) optimizer to perform on-device training.
 * This uses the gradients calculated in the backwards pass of the loss function
 * and updates the parameters such that it minimizes the loss.
 *
 * This is similar to the Lite Interpreter implementation of the SGD optimizer,
 * but without the dependency on ATen Tensors and autograd.
 */
#pragma once

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <map>
#include <memory>
#include <unordered_map>
#include <vector>

namespace executorch {
namespace extension {
namespace training {
namespace optimizer {

/**
 * SGD optimizer state. This keeps track of the state of a given parameter to
 * be used in later epochs.
 */
class ET_EXPERIMENTAL SGDParamState {
 public:
  /**
   * Constructs a new SGD param state.
   *
   * @param[in] momentum_buffer A tensor that stores the momentum at the last
   * epoch.
   */
  explicit SGDParamState(executorch::aten::Tensor& momentum_buffer)
      : momentum_buffer_(momentum_buffer) {}

  executorch::aten::Tensor& momentum_buffer() {
    return momentum_buffer_;
  }

 private:
  executorch::aten::Tensor momentum_buffer_;
};

/**
 * SGD optimizer options. This contains options for performing training on a
 * param group, such as the learning rate.
 */
class ET_EXPERIMENTAL SGDOptions {
 public:
  /**
   * Constructs a new set of SGD optimizer options.
   *
   * This is used for customizing the SGD optimizer for a given group of
   * parameters.
   *
   * @param[in] lr The learning rate. This is the factor applied to the gradient
   *   calculated from the loss function and used to update the parameters. A
   *   lower learning rate will result in a smaller step towards the minimum of
   *   a loss function, and a higher learning rate will result in a larger step.
   * @param[in] momentum The momentum value. This is used to accelerate the
   *   update step by using the gradients from previous epochs.
   * @param[in] dampening The dampening value. This is used in combination with
   *   momentum, and aims to prevent the optimizer from taking steps that are
   *   too large when using the momentum.
   * @param[in] weight_decay The weight decay value. This is used as a
   *   regularization technique, subtracting a small fraction of the weight's
   *   value from itself at each step.
   * @param[in] nesterov Whether to use Nesterov momentum. If true, the
   *   optimizer uses the momentum of the current step and applies it to the
   *   training update. When false, the optimizer uses the momentum of the
   *   previous step and applies it to the training update.
   */
  explicit SGDOptions(
      double lr,
      double momentum = 0,
      double dampening = 0,
      double weight_decay = 0,
      bool nesterov = false)
      : lr_(lr),
        momentum_(momentum),
        dampening_(dampening),
        weight_decay_(weight_decay),
        nesterov_(nesterov) {}
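
  // For intuition, these options combine in the update step roughly as in the
  // conventional torch.optim.SGD formulation (a reference sketch, not a
  // statement of this header's exact implementation; the momentum buffer is
  // only used when momentum is nonzero):
  //
  //   g   = grad + weight_decay * p                 // weight decay
  //   buf = momentum * buf + (1 - dampening) * g    // momentum buffer
  //   g   = nesterov ? (g + momentum * buf) : buf   // update direction
  //   p   = p - lr * g                              // parameter step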

  std::unique_ptr<SGDOptions> clone() const {
    return std::make_unique<SGDOptions>(static_cast<const SGDOptions&>(*this));
  }

  double lr() const {
    return lr_;
  }

  double momentum() const {
    return momentum_;
  }

  double dampening() const {
    return dampening_;
  }

  double weight_decay() const {
    return weight_decay_;
  }

  bool nesterov() const {
    return nesterov_;
  }

 private:
  double lr_;
  double momentum_;
  double dampening_;
  double weight_decay_;
  bool nesterov_;
};

/**
 * SGD optimizer param group. This contains the parameters and
 * the SGDOptions associated with them.
 */
class ET_EXPERIMENTAL SGDParamGroup {
 public:
  // NOTE: In order to store `SGDParamGroup` in a `std::vector`, it has
  // to be copy-constructible.
  SGDParamGroup(const SGDParamGroup& param_group)
      : named_parameters_(param_group.named_parameters()),
        options_(
            param_group.has_options() ? param_group.options().clone()
                                      : nullptr) {}
  SGDParamGroup& operator=(const SGDParamGroup& param_group) {
    this->named_parameters_ = param_group.named_parameters_;
    this->options_ =
        param_group.has_options() ? param_group.options().clone() : nullptr;
    return *this;
  }

  /**
   * Constructs an SGD param group.
   *
   * @param[in] named_parameters The parameters to be optimized and their fully
   * qualified names.
   */
  /* implicit */ SGDParamGroup(
      const std::map<executorch::aten::string_view, executorch::aten::Tensor>&
          named_parameters)
      : named_parameters_(named_parameters) {}
  SGDParamGroup(
      const std::map<executorch::aten::string_view, executorch::aten::Tensor>&
          named_parameters,
      std::unique_ptr<SGDOptions> options)
      : named_parameters_(named_parameters), options_(std::move(options)) {}

  bool has_options() const;
  SGDOptions& options();
  const SGDOptions& options() const;
  void set_options(std::unique_ptr<SGDOptions> options);
  const std::map<executorch::aten::string_view, executorch::aten::Tensor>&
  named_parameters() const;

 private:
  std::map<executorch::aten::string_view, executorch::aten::Tensor>
      named_parameters_;
  std::unique_ptr<SGDOptions> options_;
};
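
// Example (an illustrative sketch; how the parameter map is populated is an
// assumption, not something this header provides): a param group whose options
// override the optimizer-wide defaults.
//
//   std::map<executorch::aten::string_view, executorch::aten::Tensor> params;
//   // ... populate params with fully qualified name -> parameter tensor ...
//   SGDParamGroup group(params, std::make_unique<SGDOptions>(/* lr */ 0.01));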

/**
 * SGD optimizer class. This is responsible for performing the optimization
 * step.
 */
class ET_EXPERIMENTAL SGD {
 public:
  explicit SGD(
      const std::vector<SGDParamGroup>& param_groups,
      SGDOptions defaults)
      : defaults_(std::make_unique<SGDOptions>(defaults)) {
    for (const auto& param_group : param_groups) {
      add_param_group(param_group);
    }
  }

  explicit SGD(
      const std::map<executorch::aten::string_view, executorch::aten::Tensor>&
          named_parameters,
      SGDOptions defaults)
      : SGD({SGDParamGroup(named_parameters)}, defaults) {}

  // Adds the given param_group to the optimizer's param_group list.
  void add_param_group(const SGDParamGroup& param_group);

  ~SGD();

  /**
   * Performs the optimization step.
   *
   * @param[in] named_gradients The gradients of the tensors specified by the
   * fully qualified name.
   */
  ::executorch::runtime::Error step(
      const std::map<executorch::aten::string_view, executorch::aten::Tensor>&
          named_gradients);

 private:
  std::vector<SGDParamGroup> param_groups_;
  std::unordered_map<void*, std::unique_ptr<SGDParamState>> state_;
  std::unique_ptr<SGDOptions> defaults_;
};
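
// Example usage (an illustrative sketch; how the parameter and gradient maps
// are produced is an assumption, not part of this header):
//
//   std::map<executorch::aten::string_view, executorch::aten::Tensor> params;
//   // ... populate params with fully qualified name -> parameter tensor ...
//   SGD optimizer(params, SGDOptions(0.1, 0.9));  // lr = 0.1, momentum = 0.9
//
//   // After each backward pass, collect the gradients under the same
//   // fully qualified names and apply one update step.
//   std::map<executorch::aten::string_view, executorch::aten::Tensor> grads;
//   // ... populate grads ...
//   ::executorch::runtime::Error err = optimizer.step(grads);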

} // namespace optimizer
} // namespace training
} // namespace extension
} // namespace executorch