#pragma once

#include <torch/nn/module.h>
#include <torch/optim/optimizer.h>
#include <torch/optim/serialize.h>

#include <utility>
#include <vector>

namespace torch {
namespace serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch

namespace torch {
namespace optim {

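// AdamWOptions collects the hyperparameters for AdamW (Adam with decoupled
// weight decay, Loshchilov & Hutter, 2019). Roughly, each step updates a
// parameter p as
//   p <- p - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * p),
// where m_hat and v_hat are the bias-corrected first- and second-moment
// estimates driven by betas; with amsgrad enabled, the running maximum of the
// second-moment estimate is used in the denominator instead. See the step()
// implementation for the exact update.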
struct TORCH_API AdamWOptions : public OptimizerCloneableOptions<AdamWOptions> {
  AdamWOptions(double lr = 1e-3);
  TORCH_ARG(double, lr) = 1e-3;
  typedef std::tuple<double, double> betas_t;
  TORCH_ARG(betas_t, betas) = std::make_tuple(0.9, 0.999);
  TORCH_ARG(double, eps) = 1e-8;
  TORCH_ARG(double, weight_decay) = 1e-2;
  TORCH_ARG(bool, amsgrad) = false;

 public:
  void serialize(torch::serialize::InputArchive& archive) override;
  void serialize(torch::serialize::OutputArchive& archive) const override;
  TORCH_API friend bool operator==(
      const AdamWOptions& lhs,
      const AdamWOptions& rhs);
  double get_lr() const override;
  void set_lr(const double lr) override;
};

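// Per-parameter state kept by AdamW: `step` counts the number of updates
// applied so far, `exp_avg` and `exp_avg_sq` are exponential moving averages
// of the gradient and of its square (the first- and second-moment estimates),
// and `max_exp_avg_sq` is only populated when `amsgrad` is enabled, tracking
// the running maximum of `exp_avg_sq`.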
struct TORCH_API AdamWParamState
    : public OptimizerCloneableParamState<AdamWParamState> {
  TORCH_ARG(int64_t, step) = 0;
  TORCH_ARG(torch::Tensor, exp_avg);
  TORCH_ARG(torch::Tensor, exp_avg_sq);
  TORCH_ARG(torch::Tensor, max_exp_avg_sq) = {};

 public:
  void serialize(torch::serialize::InputArchive& archive) override;
  void serialize(torch::serialize::OutputArchive& archive) const override;
  TORCH_API friend bool operator==(
      const AdamWParamState& lhs,
      const AdamWParamState& rhs);
};

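// A minimal usage sketch (assuming a `model` torch::nn module and a computed
// `loss` tensor; these names are illustrative, not part of this header):
//
//   torch::optim::AdamW optimizer(
//       model->parameters(),
//       torch::optim::AdamWOptions(1e-3).weight_decay(1e-2));
//   optimizer.zero_grad();
//   loss.backward();
//   optimizer.step();
//
// Per-group hyperparameters can later be adjusted through
// optimizer.param_groups(), e.g. by casting a group's options() back to
// AdamWOptions.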
class TORCH_API AdamW : public Optimizer {
 public:
  explicit AdamW(
      std::vector<OptimizerParamGroup> param_groups,
      AdamWOptions defaults = {})
      : Optimizer(
            std::move(param_groups),
            std::make_unique<AdamWOptions>(defaults)) {
    TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr());
    TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps());
    auto betas = defaults.betas();
    TORCH_CHECK(
        0 <= std::get<0>(betas) && std::get<0>(betas) < 1.0,
        "Invalid beta parameter at index 0: ",
        std::get<0>(betas));
    TORCH_CHECK(
        0 <= std::get<1>(betas) && std::get<1>(betas) < 1.0,
        "Invalid beta parameter at index 1: ",
        std::get<1>(betas));
    TORCH_CHECK(
        defaults.weight_decay() >= 0,
        "Invalid weight_decay value: ",
        defaults.weight_decay());
  }
  explicit AdamW(std::vector<Tensor> params, AdamWOptions defaults = {})
      : AdamW({OptimizerParamGroup(std::move(params))}, defaults) {}

  torch::Tensor step(LossClosure closure = nullptr) override;
  void save(serialize::OutputArchive& archive) const override;
  void load(serialize::InputArchive& archive) override;

 private:
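  // Shared serialization path for save()/load(): delegates to the common
  // optimizer serialization macro, instantiated with AdamW's option and
  // per-parameter state types.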
  template <typename Self, typename Archive>
  static void serialize(Self& self, Archive& archive) {
    _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(AdamW);
  }
};
} // namespace optim
} // namespace torch