// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/train/optim/sgd.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <torch/arg.h>
#include <torch/types.h>

#include <functional>
#include <memory>
#include <utility>
#include <vector>
9 namespace torch::jit::mobile {
10 
11 class SGDParamState {
12   TORCH_ARG(torch::Tensor, momentum_buffer);
13 
14  public:
clone()15   std::unique_ptr<SGDParamState> clone() const {
16     return std::make_unique<SGDParamState>(
17         static_cast<const SGDParamState&>(*this));
18   }
19   friend bool operator==(const SGDParamState& lhs, const SGDParamState& rhs);
20 };
21 
22 struct TORCH_API SGDOptions {
23   /* implicit */ SGDOptions(double lr);
24   TORCH_ARG(double, lr);
25   TORCH_ARG(double, momentum) = 0;
26   TORCH_ARG(double, dampening) = 0;
27   TORCH_ARG(double, weight_decay) = 0;
28   TORCH_ARG(bool, nesterov) = false;
29 
30  public:
cloneSGDOptions31   std::unique_ptr<SGDOptions> clone() const {
32     return std::make_unique<SGDOptions>(static_cast<const SGDOptions&>(*this));
33   }
34   TORCH_API friend bool operator==(
35       const SGDOptions& lhs,
36       const SGDOptions& rhs);
37 };
38 
39 /// Stores parameters in the param_group and stores a pointer to the SGDOptions
40 class TORCH_API SGDParamGroup {
41  public:
42   // NOTE: In order to store `SGDParamGroup` in a `std::vector`, it has to be
43   // copy-constructible.
SGDParamGroup(const SGDParamGroup & param_group)44   SGDParamGroup(const SGDParamGroup& param_group)
45       : params_(param_group.params()),
46         options_(
47             param_group.has_options() ? param_group.options().clone()
48                                       : nullptr) {}
49   SGDParamGroup& operator=(const SGDParamGroup& param_group) {
50     this->params_ = param_group.params();
51     this->options_ =
52         param_group.has_options() ? param_group.options().clone() : nullptr;
53     return *this;
54   }
SGDParamGroup(std::vector<Tensor> params)55   /* implicit */ SGDParamGroup(std::vector<Tensor> params)
56       : params_(std::move(params)) {}
SGDParamGroup(std::vector<Tensor> params,std::unique_ptr<SGDOptions> options)57   SGDParamGroup(std::vector<Tensor> params, std::unique_ptr<SGDOptions> options)
58       : params_(std::move(params)), options_(std::move(options)) {}
59 
60   bool has_options() const;
61   SGDOptions& options();
62   const SGDOptions& options() const;
63   void set_options(std::unique_ptr<SGDOptions> options);
64   std::vector<Tensor>& params();
65   const std::vector<Tensor>& params() const;
66 
67  protected:
68   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
69   std::vector<Tensor> params_;
70   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
71   std::unique_ptr<SGDOptions> options_;
72 };
73 
74 class TORCH_API SGD {
75  public:
SGD(const std::vector<torch::jit::mobile::SGDParamGroup> & param_groups,SGDOptions defaults)76   explicit SGD(
77       const std::vector<torch::jit::mobile::SGDParamGroup>& param_groups,
78       SGDOptions defaults)
79       : defaults_(std::make_unique<SGDOptions>(defaults)) {
80     for (const auto& param_group : param_groups) {
81       add_param_group(param_group);
82     }
83     TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr());
84     TORCH_CHECK(
85         defaults.momentum() >= 0,
86         "Invalid momentum value: ",
87         defaults.momentum());
88     TORCH_CHECK(
89         defaults.weight_decay() >= 0,
90         "Invalid weight_decay value: ",
91         defaults.weight_decay());
92     TORCH_CHECK(
93         !defaults.nesterov() ||
94             (defaults.momentum() > 0 && defaults.dampening() == 0),
95         "Nesterov momentum requires a momentum and zero dampening");
96   }
97 
SGD(std::vector<Tensor> params,SGDOptions defaults)98   explicit SGD(std::vector<Tensor> params, SGDOptions defaults)
99       : SGD({SGDParamGroup(std::move(params))}, defaults) {}
100 
101   /// Adds the given param_group to the optimizer's param_group list.
102   void add_param_group(const SGDParamGroup& param_group);
103 
104   ~SGD() = default;
105 
106   using LossClosure = std::function<Tensor()>;
107   /// A loss function closure, which is expected to return the loss value.
108   torch::Tensor step(const LossClosure& closure = nullptr);
109 
110   /// Zeros out the gradients of all parameters.
111   void zero_grad();
112 
113  protected:
114   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
115   std::vector<SGDParamGroup> param_groups_;
116   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
117   ska::flat_hash_map<void*, std::unique_ptr<SGDParamState>> state_;
118   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
119   std::unique_ptr<SGDOptions> defaults_;
120   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
121   std::vector<Tensor> params_;
122   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
123   std::unique_ptr<SGDOptions> options_;
124 };
125 } // namespace torch::jit::mobile
126