#pragma once

#include <torch/nn/module.h>
#include <torch/optim/optimizer.h>
#include <torch/optim/serialize.h>

#include <utility>
#include <vector>

namespace torch {
namespace serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch

namespace torch {
namespace optim {

// Hyperparameters for AdamW, i.e. Adam with decoupled weight decay
// (Loshchilov & Hutter, "Decoupled Weight Decay Regularization").
struct TORCH_API AdamWOptions : public OptimizerCloneableOptions<AdamWOptions> {
  AdamWOptions(double lr = 1e-3);
  TORCH_ARG(double, lr) = 1e-3;
  typedef std::tuple<double, double> betas_t;
  TORCH_ARG(betas_t, betas) = std::make_tuple(0.9, 0.999);
  TORCH_ARG(double, eps) = 1e-8;
  TORCH_ARG(double, weight_decay) = 1e-2;
  TORCH_ARG(bool, amsgrad) = false;

 public:
  void serialize(torch::serialize::InputArchive& archive) override;
  void serialize(torch::serialize::OutputArchive& archive) const override;
  TORCH_API friend bool operator==(
      const AdamWOptions& lhs,
      const AdamWOptions& rhs);
  double get_lr() const override;
  void set_lr(const double lr) override;
};

// Per-parameter state: step count, first and second moment estimates, and
// the running maximum of the second moment used when `amsgrad` is enabled.
struct TORCH_API AdamWParamState
    : public OptimizerCloneableParamState<AdamWParamState> {
  TORCH_ARG(int64_t, step) = 0;
  TORCH_ARG(torch::Tensor, exp_avg);
  TORCH_ARG(torch::Tensor, exp_avg_sq);
  TORCH_ARG(torch::Tensor, max_exp_avg_sq) = {};

 public:
  void serialize(torch::serialize::InputArchive& archive) override;
  void serialize(torch::serialize::OutputArchive& archive) const override;
  TORCH_API friend bool operator==(
      const AdamWParamState& lhs,
      const AdamWParamState& rhs);
};

class TORCH_API AdamW : public Optimizer {
 public:
  explicit AdamW(
      std::vector<OptimizerParamGroup> param_groups,
      AdamWOptions defaults = {})
      : Optimizer(
            std::move(param_groups),
            std::make_unique<AdamWOptions>(defaults)) {
    // Validate the default hyperparameters up front.
    TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr());
    TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps());
    auto betas = defaults.betas();
    TORCH_CHECK(
        0 <= std::get<0>(betas) && std::get<0>(betas) < 1.0,
        "Invalid beta parameter at index 0: ",
        std::get<0>(betas));
    TORCH_CHECK(
        0 <= std::get<1>(betas) && std::get<1>(betas) < 1.0,
        "Invalid beta parameter at index 1: ",
        std::get<1>(betas));
    TORCH_CHECK(
        defaults.weight_decay() >= 0,
        "Invalid weight_decay value: ",
        defaults.weight_decay());
  }
  explicit AdamW(std::vector<Tensor> params, AdamWOptions defaults = {})
      : AdamW({OptimizerParamGroup(std::move(params))}, defaults) {}

  torch::Tensor step(LossClosure closure = nullptr) override;
  void save(serialize::OutputArchive& archive) const override;
  void load(serialize::InputArchive& archive) override;

 private:
  template <typename Self, typename Archive>
  static void serialize(Self& self, Archive& archive) {
    _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(AdamW);
  }
};
} // namespace optim
} // namespace torch
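
// -----------------------------------------------------------------------------
// Usage sketch (added for illustration; not part of the upstream header).
// A minimal example of constructing and stepping the AdamW optimizer declared
// above. It assumes a LibTorch build; the toy torch::nn::Linear model, tensor
// shapes, and MSE loss are placeholders, not anything this header prescribes.
// Options passed at construction become the defaults for every param group.
//
// #include <torch/torch.h>
//
// int main() {
//   torch::nn::Linear model(10, 1);
//   torch::optim::AdamW optimizer(
//       model->parameters(),
//       torch::optim::AdamWOptions(/*lr=*/1e-3)
//           .betas(std::make_tuple(0.9, 0.999))
//           .weight_decay(1e-2)
//           .amsgrad(false));
//
//   for (int iter = 0; iter < 3; ++iter) {
//     optimizer.zero_grad();  // clear gradients from the previous iteration
//     auto loss = torch::mse_loss(model(torch::randn({4, 10})),
//                                 torch::randn({4, 1}));
//     loss.backward();        // populate parameter gradients
//     optimizer.step();       // Adam update with decoupled weight decay
//   }
// }
// -----------------------------------------------------------------------------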