#pragma once

#include <torch/nn/module.h>
#include <torch/optim/optimizer.h>
#include <torch/optim/serialize.h>
#include <torch/serialize/archive.h>

#include <deque>
#include <functional>
#include <memory>
#include <vector>

namespace torch {
namespace optim {

struct TORCH_API LBFGSOptions : public OptimizerCloneableOptions<LBFGSOptions> {
  LBFGSOptions(double lr = 1);
  TORCH_ARG(double, lr) = 1;
  TORCH_ARG(int64_t, max_iter) = 20;
  TORCH_ARG(std::optional<int64_t>, max_eval) = std::nullopt;
  TORCH_ARG(double, tolerance_grad) = 1e-7;
  TORCH_ARG(double, tolerance_change) = 1e-9;
  TORCH_ARG(int64_t, history_size) = 100;
  TORCH_ARG(std::optional<std::string>, line_search_fn) = std::nullopt;

 public:
  void serialize(torch::serialize::InputArchive& archive) override;
  void serialize(torch::serialize::OutputArchive& archive) const override;
  TORCH_API friend bool operator==(
      const LBFGSOptions& lhs,
      const LBFGSOptions& rhs);
  double get_lr() const override;
  void set_lr(const double lr) override;
};

struct TORCH_API LBFGSParamState
    : public OptimizerCloneableParamState<LBFGSParamState> {
  TORCH_ARG(int64_t, func_evals) = 0;
  TORCH_ARG(int64_t, n_iter) = 0;
  TORCH_ARG(double, t) = 0;
  TORCH_ARG(double, prev_loss) = 0;
  TORCH_ARG(Tensor, d) = {};
  TORCH_ARG(Tensor, H_diag) = {};
  TORCH_ARG(Tensor, prev_flat_grad) = {};
  TORCH_ARG(std::deque<Tensor>, old_dirs);
  TORCH_ARG(std::deque<Tensor>, old_stps);
  TORCH_ARG(std::deque<Tensor>, ro);
  TORCH_ARG(std::optional<std::vector<Tensor>>, al) = std::nullopt;

 public:
  void serialize(torch::serialize::InputArchive& archive) override;
  void serialize(torch::serialize::OutputArchive& archive) const override;
  TORCH_API friend bool operator==(
      const LBFGSParamState& lhs,
      const LBFGSParamState& rhs);
};

class TORCH_API LBFGS : public Optimizer {
 public:
  explicit LBFGS(
      std::vector<OptimizerParamGroup> param_groups,
      LBFGSOptions defaults = {})
      : Optimizer(
            std::move(param_groups),
            std::make_unique<LBFGSOptions>(defaults)) {
    TORCH_CHECK(
        param_groups_.size() == 1,
        "LBFGS doesn't support per-parameter options (parameter groups)");
    if (defaults.max_eval() == std::nullopt) {
      auto max_eval_val = (defaults.max_iter() * 5) / 4;
      static_cast<LBFGSOptions&>(param_groups_[0].options())
          .max_eval(max_eval_val);
      static_cast<LBFGSOptions&>(*defaults_.get()).max_eval(max_eval_val);
    }
    _numel_cache = std::nullopt;
  }
  explicit LBFGS(std::vector<Tensor> params, LBFGSOptions defaults = {})
      : LBFGS({OptimizerParamGroup(std::move(params))}, defaults) {}

  Tensor step(LossClosure closure) override;
  void save(serialize::OutputArchive& archive) const override;
  void load(serialize::InputArchive& archive) override;

 private:
  std::optional<int64_t> _numel_cache;
  int64_t _numel();
  Tensor _gather_flat_grad();
  void _add_grad(const double step_size, const Tensor& update);
  std::tuple<double, Tensor> _directional_evaluate(
      const LossClosure& closure,
      const std::vector<Tensor>& x,
      double t,
      const Tensor& d);
  void _set_param(const std::vector<Tensor>& params_data);
  std::vector<Tensor> _clone_param();

  template <typename Self, typename Archive>
  static void serialize(Self& self, Archive& archive) {
    _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(LBFGS);
  }
};
} // namespace optim
} // namespace torch
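
// Illustrative usage sketch (not part of the original header): unlike most
// optimizers, LBFGS::step() takes a LossClosure (std::function<Tensor()>)
// that re-evaluates the model and returns the loss, and may invoke it several
// times per call. "model", "input", and "target" below are hypothetical
// placeholders for a user's module and data.
//
//   torch::optim::LBFGS optimizer(
//       model->parameters(), torch::optim::LBFGSOptions(0.1).max_iter(25));
//   auto loss = optimizer.step([&]() -> torch::Tensor {
//     optimizer.zero_grad();
//     auto out = torch::mse_loss(model->forward(input), target);
//     out.backward();
//     return out;
//   });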