xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/optim/lbfgs.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/nn/module.h>
4 #include <torch/optim/optimizer.h>
5 #include <torch/optim/serialize.h>
6 #include <torch/serialize/archive.h>
7 
8 #include <deque>
9 #include <functional>
10 #include <memory>
11 #include <vector>
12 
13 namespace torch {
14 namespace optim {
15 
16 struct TORCH_API LBFGSOptions : public OptimizerCloneableOptions<LBFGSOptions> {
17   LBFGSOptions(double lr = 1);
18   TORCH_ARG(double, lr) = 1;
19   TORCH_ARG(int64_t, max_iter) = 20;
20   TORCH_ARG(std::optional<int64_t>, max_eval) = std::nullopt;
21   TORCH_ARG(double, tolerance_grad) = 1e-7;
22   TORCH_ARG(double, tolerance_change) = 1e-9;
23   TORCH_ARG(int64_t, history_size) = 100;
24   TORCH_ARG(std::optional<std::string>, line_search_fn) = std::nullopt;
25 
26  public:
27   void serialize(torch::serialize::InputArchive& archive) override;
28   void serialize(torch::serialize::OutputArchive& archive) const override;
29   TORCH_API friend bool operator==(
30       const LBFGSOptions& lhs,
31       const LBFGSOptions& rhs);
32   double get_lr() const override;
33   void set_lr(const double lr) override;
34 };
35 
36 struct TORCH_API LBFGSParamState
37     : public OptimizerCloneableParamState<LBFGSParamState> {
38   TORCH_ARG(int64_t, func_evals) = 0;
39   TORCH_ARG(int64_t, n_iter) = 0;
40   TORCH_ARG(double, t) = 0;
41   TORCH_ARG(double, prev_loss) = 0;
42   TORCH_ARG(Tensor, d) = {};
43   TORCH_ARG(Tensor, H_diag) = {};
44   TORCH_ARG(Tensor, prev_flat_grad) = {};
45   TORCH_ARG(std::deque<Tensor>, old_dirs);
46   TORCH_ARG(std::deque<Tensor>, old_stps);
47   TORCH_ARG(std::deque<Tensor>, ro);
48   TORCH_ARG(std::optional<std::vector<Tensor>>, al) = std::nullopt;
49 
50  public:
51   void serialize(torch::serialize::InputArchive& archive) override;
52   void serialize(torch::serialize::OutputArchive& archive) const override;
53   TORCH_API friend bool operator==(
54       const LBFGSParamState& lhs,
55       const LBFGSParamState& rhs);
56 };
57 
58 class TORCH_API LBFGS : public Optimizer {
59  public:
60   explicit LBFGS(
61       std::vector<OptimizerParamGroup> param_groups,
62       LBFGSOptions defaults = {})
Optimizer(std::move (param_groups),std::make_unique<LBFGSOptions> (defaults))63       : Optimizer(
64             std::move(param_groups),
65             std::make_unique<LBFGSOptions>(defaults)) {
66     TORCH_CHECK(
67         param_groups_.size() == 1,
68         "LBFGS doesn't support per-parameter options (parameter groups)");
69     if (defaults.max_eval() == std::nullopt) {
70       auto max_eval_val = (defaults.max_iter() * 5) / 4;
71       static_cast<LBFGSOptions&>(param_groups_[0].options())
72           .max_eval(max_eval_val);
73       static_cast<LBFGSOptions&>(*defaults_.get()).max_eval(max_eval_val);
74     }
75     _numel_cache = std::nullopt;
76   }
77   explicit LBFGS(std::vector<Tensor> params, LBFGSOptions defaults = {})
78       : LBFGS({OptimizerParamGroup(std::move(params))}, defaults) {}
79 
80   Tensor step(LossClosure closure) override;
81   void save(serialize::OutputArchive& archive) const override;
82   void load(serialize::InputArchive& archive) override;
83 
84  private:
85   std::optional<int64_t> _numel_cache;
86   int64_t _numel();
87   Tensor _gather_flat_grad();
88   void _add_grad(const double step_size, const Tensor& update);
89   std::tuple<double, Tensor> _directional_evaluate(
90       const LossClosure& closure,
91       const std::vector<Tensor>& x,
92       double t,
93       const Tensor& d);
94   void _set_param(const std::vector<Tensor>& params_data);
95   std::vector<Tensor> _clone_param();
96 
97   template <typename Self, typename Archive>
serialize(Self & self,Archive & archive)98   static void serialize(Self& self, Archive& archive) {
99     _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(LBFGS);
100   }
101 };
102 } // namespace optim
103 } // namespace torch
104