xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/optim/adam.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/nn/module.h>
4 #include <torch/optim/optimizer.h>
5 #include <torch/optim/serialize.h>
6 
7 #include <utility>
8 #include <vector>
9 
// Forward declarations of the archive types that appear only by reference in
// the serialize()/save()/load() signatures of this header, so their full
// definitions are not needed here.
namespace torch {
namespace serialize {
class InputArchive;
class OutputArchive;
} // namespace serialize
} // namespace torch
16 
17 namespace torch {
18 namespace optim {
19 
20 struct TORCH_API AdamOptions : public OptimizerCloneableOptions<AdamOptions> {
21   AdamOptions(double lr = 1e-3);
22   TORCH_ARG(double, lr) = 1e-3;
23   typedef std::tuple<double, double> betas_t;
24   TORCH_ARG(betas_t, betas) = std::make_tuple(0.9, 0.999);
25   TORCH_ARG(double, eps) = 1e-8;
26   TORCH_ARG(double, weight_decay) = 0;
27   TORCH_ARG(bool, amsgrad) = false;
28 
29  public:
30   void serialize(torch::serialize::InputArchive& archive) override;
31   void serialize(torch::serialize::OutputArchive& archive) const override;
32   TORCH_API friend bool operator==(
33       const AdamOptions& lhs,
34       const AdamOptions& rhs);
35   double get_lr() const override;
36   void set_lr(const double lr) override;
37 };
38 
39 struct TORCH_API AdamParamState
40     : public OptimizerCloneableParamState<AdamParamState> {
41   TORCH_ARG(int64_t, step) = 0;
42   TORCH_ARG(torch::Tensor, exp_avg);
43   TORCH_ARG(torch::Tensor, exp_avg_sq);
44   TORCH_ARG(torch::Tensor, max_exp_avg_sq) = {};
45 
46  public:
47   void serialize(torch::serialize::InputArchive& archive) override;
48   void serialize(torch::serialize::OutputArchive& archive) const override;
49   TORCH_API friend bool operator==(
50       const AdamParamState& lhs,
51       const AdamParamState& rhs);
52 };
53 
54 class TORCH_API Adam : public Optimizer {
55  public:
56   explicit Adam(
57       std::vector<OptimizerParamGroup> param_groups,
58       AdamOptions defaults = {})
Optimizer(std::move (param_groups),std::make_unique<AdamOptions> (defaults))59       : Optimizer(
60             std::move(param_groups),
61             std::make_unique<AdamOptions>(defaults)) {
62     TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr());
63     TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps());
64     auto betas = defaults.betas();
65     TORCH_CHECK(
66         0 <= std::get<0>(betas) && std::get<0>(betas) < 1.0,
67         "Invalid beta parameter at index 0: ",
68         std::get<0>(betas));
69     TORCH_CHECK(
70         0 <= std::get<1>(betas) && std::get<1>(betas) < 1.0,
71         "Invalid beta parameter at index 1: ",
72         std::get<1>(betas));
73     TORCH_CHECK(
74         defaults.weight_decay() >= 0,
75         "Invalid weight_decay value: ",
76         defaults.weight_decay());
77   }
78   explicit Adam(std::vector<Tensor> params, AdamOptions defaults = {})
79       : Adam({OptimizerParamGroup(std::move(params))}, defaults) {}
80 
81   torch::Tensor step(LossClosure closure = nullptr) override;
82   void save(serialize::OutputArchive& archive) const override;
83   void load(serialize::InputArchive& archive) override;
84 
85  private:
86   template <typename Self, typename Archive>
serialize(Self & self,Archive & archive)87   static void serialize(Self& self, Archive& archive) {
88     _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(Adam);
89   }
90 };
91 } // namespace optim
92 } // namespace torch
93