#include <torch/optim/lbfgs.h>

#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <cmath>
#include <functional>
#include <vector>

namespace torch {
namespace optim {

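// A minimal usage sketch (illustrative only; `model` and `compute_loss` are
// hypothetical, not part of this file). step() takes a closure because the
// optimizer may re-evaluate the loss and gradient several times per step:
//
//   torch::optim::LBFGS optimizer(
//       model->parameters(),
//       torch::optim::LBFGSOptions(1.0).max_iter(20).line_search_fn(
//           "strong_wolfe"));
//   auto loss = optimizer.step([&]() -> torch::Tensor {
//     optimizer.zero_grad();
//     auto l = compute_loss(model);
//     l.backward();
//     return l;
//   });
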
LBFGSOptions::LBFGSOptions(double lr) : lr_(lr) {}

bool operator==(const LBFGSOptions& lhs, const LBFGSOptions& rhs) {
  return (lhs.lr() == rhs.lr()) && (lhs.max_iter() == rhs.max_iter()) &&
      (lhs.max_eval() == rhs.max_eval()) &&
      (lhs.tolerance_grad() == rhs.tolerance_grad()) &&
      (lhs.tolerance_change() == rhs.tolerance_change()) &&
      (lhs.history_size() == rhs.history_size()) &&
      (lhs.line_search_fn() == rhs.line_search_fn());
}

void LBFGSOptions::serialize(torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(lr);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(max_iter);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(max_eval);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(tolerance_grad);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(tolerance_change);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(history_size);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(line_search_fn);
}

void LBFGSOptions::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lr);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, max_iter);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(int64_t, max_eval);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, tolerance_grad);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, tolerance_change);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, history_size);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(std::string, line_search_fn);
}

double LBFGSOptions::get_lr() const {
  return lr();
}

void LBFGSOptions::set_lr(const double lr) {
  this->lr(lr);
}

template <typename T>
bool if_container_equal(const T& lhs, const T& rhs) {
  if (!(lhs.size() == rhs.size()))
    return false;
  for (const auto i : c10::irange(lhs.size())) {
    if (!torch::equal(lhs.at(i), rhs.at(i)))
      return false;
  }
  return true;
}

bool operator==(const LBFGSParamState& lhs, const LBFGSParamState& rhs) {
  auto isNull = [](const std::optional<std::vector<Tensor>>& val) {
    return val == std::nullopt;
  };
  return (lhs.func_evals() == rhs.func_evals()) &&
      (lhs.n_iter() == rhs.n_iter()) && (lhs.t() == rhs.t()) &&
      (lhs.prev_loss() == rhs.prev_loss()) &&
      torch::equal_if_defined(lhs.d(), rhs.d()) &&
      torch::equal_if_defined(lhs.H_diag(), rhs.H_diag()) &&
      torch::equal_if_defined(lhs.prev_flat_grad(), rhs.prev_flat_grad()) &&
      if_container_equal(lhs.old_dirs(), rhs.old_dirs()) &&
      if_container_equal(lhs.old_stps(), rhs.old_stps()) &&
      if_container_equal(lhs.ro(), rhs.ro()) &&
      ((isNull(lhs.al()) && isNull(rhs.al())) ||
       (!isNull(lhs.al()) && !isNull(rhs.al()) &&
        if_container_equal(*lhs.al(), *rhs.al())));
}

void LBFGSParamState::serialize(
    torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(func_evals);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(n_iter);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(t);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(prev_loss);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(d);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(H_diag);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(prev_flat_grad);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(old_dirs);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(old_stps);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(ro);
  // Python version only serializes state vars if explicitly defined
  if (al() != std::nullopt) {
    _TORCH_OPTIM_SERIALIZE_TORCH_ARG(al);
  }
}

void LBFGSParamState::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, func_evals);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, n_iter);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, t);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, prev_loss);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, d);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, H_diag);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, prev_flat_grad);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(std::deque<Tensor>, old_dirs);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(std::deque<Tensor>, old_stps);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(std::deque<Tensor>, ro);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(std::vector<Tensor>, al);
}

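// Flattens the gradients of all parameters in the single param group into one
// 1-D tensor: undefined grads contribute zeros and sparse grads are
// densified, so the result always has _numel() elements.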
Tensor LBFGS::_gather_flat_grad() {
  std::vector<Tensor> views;
  for (const auto& p : param_groups_.at(0).params()) {
    if (!p.grad().defined()) {
      views.emplace_back(p.new_empty({p.numel()}).zero_());
    } else if (p.grad().is_sparse()) {
      views.emplace_back(p.grad().to_dense().view(-1));
    } else {
      views.emplace_back(p.grad().view(-1));
    }
  }
  return torch::cat(views, 0);
}

int64_t LBFGS::_numel() {
  if (_numel_cache == std::nullopt) {
    int64_t res = 0;
    for (const auto& p : param_groups_.at(0).params()) {
      res += p.numel();
    }
    _numel_cache = res;
  }
  return *_numel_cache;
}

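// Applies params += step_size * update, where `update` is a flat tensor laid
// out in the same order as _gather_flat_grad(); each parameter reads its own
// slice of the update.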
void LBFGS::_add_grad(const double step_size, const Tensor& update) {
  int64_t offset = 0;
  for (auto& p : param_groups_.at(0).params()) {
    auto numel = p.numel();
    // view as to avoid deprecated pointwise semantics
    p.add_(
        update.index({at::indexing::Slice(offset, offset + numel)}).view_as(p),
        step_size);
    offset += numel;
  }
  TORCH_INTERNAL_ASSERT(offset == _numel());
}

void LBFGS::_set_param(const std::vector<Tensor>& params_data) {
  auto& _params = param_groups_.at(0).params();
  TORCH_INTERNAL_ASSERT(params_data.size() == _params.size());
  for (const auto i : c10::irange(_params.size())) {
    _params.at(i).copy_(params_data.at(i));
  }
}

std::vector<Tensor> LBFGS::_clone_param() {
  std::vector<Tensor> result;
  for (const auto& p : param_groups_.at(0).params()) {
    result.emplace_back(p.clone(at::MemoryFormat::Contiguous));
  }
  return result;
}

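// Evaluates the objective at x + t * d without permanently moving the
// parameters: take the step, re-run the closure for the loss and flattened
// gradient, then restore the cached parameters `x`.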
std::tuple<double, Tensor> LBFGS::_directional_evaluate(
    const LossClosure& closure,
    const std::vector<Tensor>& x,
    double t,
    const Tensor& d) {
  _add_grad(t, d);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  double loss;
  {
    torch::AutoGradMode enable_grad(true);
    loss = closure().item<double>();
  }
  auto flat_grad = _gather_flat_grad();
  _set_param(x);
  return std::make_tuple(loss, flat_grad);
}

static double _cubic_interpolate(
    double x1,
    double f1,
    double g1,
    double x2,
    double f2,
    double g2,
    std::optional<std::tuple<double, double>> bounds = std::nullopt) {
  // ported from https://github.com/torch/optim/blob/master/polyinterp.lua
  // Compute bounds of interpolation area
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  double xmin_bound, xmax_bound;
  if (bounds != std::nullopt) {
    std::tie(xmin_bound, xmax_bound) = *bounds;
  } else {
    std::tie(xmin_bound, xmax_bound) =
        (x1 <= x2) ? std::make_tuple(x1, x2) : std::make_tuple(x2, x1);
  }
  // Code for most common case: cubic interpolation of 2 points
  //   w/ function and derivative values for both
  // Solution in this case (where x2 is the farthest point):
  //   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
  //   d2 = sqrt(d1^2 - g1*g2);
  //   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
  //   t_new = min(max(min_pos,xmin_bound),xmax_bound);

  auto d1 = (g1 + g2) - (3 * (f1 - f2) / (x1 - x2));
  auto d2_square = std::pow(d1, 2) - g1 * g2;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  double d2;
  if (d2_square >= 0) {
    d2 = std::sqrt(d2_square);
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    double min_pos;
    if (x1 <= x2) {
      min_pos = x2 - ((x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2)));
    } else {
      min_pos = x1 - ((x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2)));
    }
    return std::min(std::max(min_pos, xmin_bound), xmax_bound);
  } else {
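    // d2_square < 0: the cubic has no real critical point in closed form,
    // so fall back to bisecting the interpolation bounds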
    return (xmin_bound + xmax_bound) / 2;
  }
}

using Function = std::function<std::tuple<double, Tensor>(
    const std::vector<Tensor>& x,
    double t,
    const Tensor& d)>;
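
// Line search satisfying the strong Wolfe conditions for phi(a) = f(x + a*d):
//   sufficient decrease: phi(a) <= phi(0) + c1 * a * phi'(0)
//   curvature:           |phi'(a)| <= -c2 * phi'(0)   (phi'(0) = gtd < 0)
// A bracketing phase first grows the step until an interval containing an
// acceptable point is found; a zoom phase then shrinks that interval via
// cubic interpolation. Returns {f_new, g_new, t, ls_func_evals}.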
static std::tuple<double, Tensor, double, int64_t> _strong_wolfe(
    const Function& obj_func,
    const std::vector<Tensor>& x,
    double t,
    const Tensor& d,
    double f,
    Tensor g,
    const Tensor& gtd,
    double c1 = 1e-4,
    double c2 = 0.9, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
    double tolerance_change = 1e-9,
    double max_ls = 25) { // NOLINT(cppcoreguidelines-avoid-magic-numbers)

  auto val = [](const Tensor& t) { return t.item<double>(); };

  auto d_norm = val(d.abs().max());
  g = g.clone(at::MemoryFormat::Contiguous);
  // evaluate objective and gradient using initial step
  auto [f_new, g_new] = obj_func(x, t, d);
  int64_t ls_func_evals = 1;
  auto gtd_new = g_new.dot(d);

  // bracket an interval containing a point satisfying the Wolfe criteria
  double t_prev = 0;
  auto f_prev = f;
  auto g_prev = g;
  auto gtd_prev = gtd;
  bool done = false;
  auto ls_iter = 0;
  std::vector<double> bracket, bracket_f;
  std::vector<Tensor> bracket_g, bracket_gtd;

  while (ls_iter < max_ls) {
    // check conditions
    if ((f_new > (f + c1 * t * val(gtd))) ||
        (ls_iter > 1 && (f_new >= f_prev))) {
      bracket = {t_prev, t};
      bracket_f = {f_prev, f_new};
      bracket_g = {g_prev, g_new.clone(at::MemoryFormat::Contiguous)};
      bracket_gtd = {gtd_prev, gtd_new};
      break;
    }
    if (std::abs(val(gtd_new)) <= (-c2 * val(gtd))) {
      bracket = {t, t};
      bracket_f = {f_new, f_new};
      bracket_g = {g_new, g_new};
      done = true;
      break;
    }
    if (val(gtd_new) >= 0) {
      bracket = {t_prev, t};
      bracket_f = {f_prev, f_new};
      bracket_g = {g_prev, g_new.clone(at::MemoryFormat::Contiguous)};
      bracket_gtd = {gtd_prev, gtd_new};
      break;
    }
    // interpolate
    auto min_step = t +
        0.01 * (t - t_prev); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
    auto max_step = t * 10; // NOLINT(cppcoreguidelines-avoid-magic-numbers)
    auto tmp = t;
    t = _cubic_interpolate(
        t_prev,
        f_prev,
        val(gtd_prev),
        t,
        f_new,
        val(gtd_new),
        std::make_tuple(min_step, max_step));
    // next step
    t_prev = tmp;
    f_prev = f_new;
    g_prev = g_new.clone(at::MemoryFormat::Contiguous);
    gtd_prev = gtd_new;
    std::tie(f_new, g_new) = obj_func(x, t, d);
    ls_func_evals += 1;
    gtd_new = g_new.dot(d);
    ls_iter += 1;
  }
  // reached max number of iterations?
  if (ls_iter == max_ls) {
    bracket = {0, t};
    bracket_f = {f, f_new};
    bracket_g = {g, g_new};
  }

  // zoom phase: we now have a point satisfying the criteria, or
  // a bracket around it. We refine the bracket until we find the
  // exact point satisfying the criteria
  bool insuf_progress = false;
  // find high and low points in bracket
  auto [low_pos, high_pos] = bracket_f[0] <= bracket_f[1]
      ? std::make_tuple(0, 1)
      : std::make_tuple(1, 0);
  while (!done && (ls_iter < max_ls)) {
    // compute new trial value
    t = _cubic_interpolate(
        bracket[0],
        bracket_f[0],
        val(bracket_gtd[0]),
        bracket[1],
        bracket_f[1],
        val(bracket_gtd[1]));

    // test that we are making sufficient progress:
    // if `t` is too close to a boundary we mark the progress as
    // insufficient, and if
    //   + we made insufficient progress in the last step, or
    //   + `t` is at one of the boundaries,
    // we move `t` to a position `0.1 * len(bracket)` away from the
    // nearest boundary point.
    double bracket_max = std::max(bracket[0], bracket[1]);
    auto bracket_min = std::min(bracket[0], bracket[1]);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    auto eps = 0.1 * (bracket_max - bracket_min);
    if (std::min(bracket_max - t, t - bracket_min) < eps) {
      // interpolation close to boundary
      if (insuf_progress || (t >= bracket_max) || (t <= bracket_min)) {
        // evaluate at 0.1 away from boundary
        t = (std::abs(t - bracket_max) < std::abs(t - bracket_min))
            ? bracket_max - eps
            : bracket_min + eps;
        insuf_progress = false;
      } else {
        insuf_progress = true;
      }
    } else {
      insuf_progress = false;
    }

    // Evaluate new point
    std::tie(f_new, g_new) = obj_func(x, t, d);
    ls_func_evals += 1;
    gtd_new = g_new.dot(d);
    ls_iter += 1;

    if ((f_new > (f + c1 * t * val(gtd))) || (f_new >= bracket_f[low_pos])) {
      // Armijo condition not satisfied or not lower than lowest point
      bracket[high_pos] = t;
      bracket_f[high_pos] = f_new;
      bracket_g[high_pos] = g_new.clone(at::MemoryFormat::Contiguous);
      bracket_gtd[high_pos] = gtd_new;
      std::tie(low_pos, high_pos) = bracket_f[0] <= bracket_f[1]
          ? std::make_tuple(0, 1)
          : std::make_tuple(1, 0);
    } else {
      if (val(at::abs(gtd_new)) <= (-c2 * val(gtd))) {
        // Wolfe conditions satisfied
        done = true;
      } else if ((val(gtd_new) * (bracket[high_pos] - bracket[low_pos])) >= 0) {
        // old high becomes new low
        bracket[high_pos] = bracket[low_pos];
        bracket_f[high_pos] = bracket_f[low_pos];
        bracket_g[high_pos] = bracket_g[low_pos];
        bracket_gtd[high_pos] = bracket_gtd[low_pos];
      }

      // new point becomes new low
      bracket[low_pos] = t;
      bracket_f[low_pos] = f_new;
      bracket_g[low_pos] = g_new.clone(at::MemoryFormat::Contiguous);
      bracket_gtd[low_pos] = gtd_new;
    }

    // line-search bracket has become too small
    if ((std::abs(bracket[1] - bracket[0]) * d_norm) < tolerance_change)
      break;
  }

  // return the best point found
  t = bracket[low_pos];
  f_new = bracket_f[low_pos];
  g_new = bracket_g[low_pos];
  return std::make_tuple(f_new, g_new, t, ls_func_evals);
}

Tensor LBFGS::step(LossClosure closure) {
  NoGradGuard no_grad;
  TORCH_CHECK(closure != nullptr, "LBFGS requires a closure function");
  TORCH_INTERNAL_ASSERT(param_groups_.size() == 1);
  auto val = [](const Tensor& t) { return t.item<double>(); };

  auto& group = param_groups_.at(0);
  auto& _params = group.params();
  const auto& options = static_cast<const LBFGSOptions&>(group.options());
  auto lr = options.lr();
  auto max_iter = options.max_iter();
  auto max_eval = options.max_eval();
  auto tolerance_grad = options.tolerance_grad();
  auto tolerance_change = options.tolerance_change();
  auto line_search_fn = options.line_search_fn();
  auto history_size = options.history_size();

  // NOTE: LBFGS has only global state, but we register it as state for
  // the first param, because this helps with casting in load_state_dict
  auto param_state = state_.find(_params.at(0).unsafeGetTensorImpl());
  if (param_state == state_.end()) {
    state_[_params.at(0).unsafeGetTensorImpl()] =
        std::make_unique<LBFGSParamState>();
  }
  auto& state = static_cast<LBFGSParamState&>(
      *state_[_params.at(0).unsafeGetTensorImpl()]);
  // evaluate initial f(x) and df/dx
  Tensor orig_loss;
  {
    torch::AutoGradMode enable_grad(true);
    orig_loss = closure();
  }

  auto loss = val(orig_loss);
  auto current_evals = 1;
  state.func_evals(state.func_evals() + 1);
  auto flat_grad = _gather_flat_grad();
  auto opt_cond = (val(flat_grad.abs().max()) <= tolerance_grad);

  // optimal condition
  if (opt_cond) {
    return orig_loss;
  }

  // tensors cached in state (for tracing)
  auto& d = state.d();
  auto& t = state.t();
  auto& old_dirs = state.old_dirs();
  auto& old_stps = state.old_stps();
  auto& ro = state.ro();
  auto& H_diag = state.H_diag();
  auto& prev_flat_grad = state.prev_flat_grad();
  auto& prev_loss = state.prev_loss();

  int n_iter = 0;

  // optimize for a max of max_iter iterations
  while (n_iter < max_iter) {
    // keep track of nb of iterations
    n_iter += 1;
    state.n_iter(state.n_iter() + 1);

    // compute gradient descent direction
    if (state.n_iter() == 1) {
      d = flat_grad.neg();
      H_diag = torch::tensor(1);
      old_dirs = {};
      old_stps = {};
      ro = {};
    } else {
      // do lbfgs update (update memory)
      auto y = flat_grad.sub(prev_flat_grad);
      auto s = d.mul(t);
      auto ys = y.dot(s); // y*s
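      // only update the memory when the curvature condition y^T s > 0 holds
      // (beyond a small threshold); this keeps the inverse Hessian
      // approximation positive definite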
      if (val(ys) > 1e-10) { // NOLINT(cppcoreguidelines-avoid-magic-numbers)
        // updating memory
        if (static_cast<int64_t>(old_dirs.size()) == history_size) {
          // shift history by one (limited-memory)
          old_dirs.pop_front();
          old_stps.pop_front();
          ro.pop_front();
        }
        // store new direction/step
        old_dirs.emplace_back(y);
        old_stps.emplace_back(s);
        ro.emplace_back(1. / ys);

        // update scale of initial Hessian approximation
        H_diag = ys / y.dot(y); // (y*y)
      }

      // compute the approximate (L-BFGS) inverse Hessian
      // multiplied by the gradient
      int64_t num_old = static_cast<int64_t>(old_dirs.size());

      if (state.al() == std::nullopt) {
        state.al(std::vector<Tensor>(history_size));
      }
      auto& al = state.al();

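      // L-BFGS two-loop recursion (cf. Nocedal & Wright, Algorithm 7.4),
      // computing d = -H * flat_grad for the implicit inverse-Hessian
      // approximation H, where old_dirs holds the y_i, old_stps the s_i,
      // and ro the scalars 1 / (y_i^T s_i):
      //   q = -g
      //   for i = k-1 .. 0:  al[i] = ro[i] * s_i^T q;  q -= al[i] * y_i
      //   r = H_diag * q
      //   for i = 0 .. k-1:  be_i = ro[i] * y_i^T r;   r += (al[i] - be_i) * s_i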
      // iteration in L-BFGS loop collapsed to use just one buffer
      auto q = flat_grad.neg();
      for (int64_t i = num_old - 1; i > -1; i--) {
        (*al).at(i) = old_stps.at(i).dot(q) * ro.at(i);
        q.add_(old_dirs.at(i), -val((*al).at(i)));
      }

      // multiply by initial Hessian
      // r/d is the final direction
      auto r = torch::mul(q, H_diag);
      d = r;
      for (const auto i : c10::irange(num_old)) {
        auto be_i = old_dirs.at(i).dot(r) * ro.at(i);
        r.add_(old_stps.at(i), val((*al).at(i) - be_i));
      }
    }

    if (!prev_flat_grad.defined()) {
      prev_flat_grad = flat_grad.clone(at::MemoryFormat::Contiguous);
    } else {
      prev_flat_grad.copy_(flat_grad);
    }
    prev_loss = loss;

    // ############################################################
    // # compute step length
    // ############################################################
    // reset initial guess for step size
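    // on the first iteration, scale lr by min(1, 1 / ||grad||_1) so the very
    // first step stays bounded regardless of the gradient's magnitude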
    if (state.n_iter() == 1) {
      t = std::min(1., 1. / val(flat_grad.abs().sum())) * lr;
    } else {
      t = lr;
    }

    // directional derivative
    auto gtd = flat_grad.dot(d); // g * d

    // directional derivative is below tolerance
    if (val(gtd) > -tolerance_change)
      break;

    // optional line search: user function
    auto ls_func_evals = 0;
    if (line_search_fn != std::nullopt) {
      TORCH_CHECK(
          *line_search_fn == "strong_wolfe",
          "only 'strong_wolfe' is supported");
      auto x_init = _clone_param();
      auto obj_func =
          [&](const std::vector<Tensor>& x, double t, const Tensor& d) {
            return _directional_evaluate(closure, x, t, d);
          };
      std::tie(loss, flat_grad, t, ls_func_evals) =
          _strong_wolfe(obj_func, x_init, t, d, loss, flat_grad, gtd);
      _add_grad(t, d);
      opt_cond = (val(flat_grad.abs().max()) <= tolerance_grad);
    } else {
      // no line search, simply move with fixed step
      _add_grad(t, d);
      if (n_iter != max_iter) {
        // re-evaluate the function only if we are not on the last iteration;
        // in a stochastic setting there is no point re-evaluating it here
        {
          torch::AutoGradMode enable_grad(true);
          loss = val(closure());
        }
        flat_grad = _gather_flat_grad();
        opt_cond = val(torch::max(flat_grad.abs())) <= tolerance_grad;
        ls_func_evals = 1;
      }
    }
    // update func eval
    current_evals += ls_func_evals;
    state.func_evals(state.func_evals() + ls_func_evals);

    // ############################################################
    // # check conditions
    // ############################################################
    if (n_iter == max_iter)
      break;

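    // max_eval is assumed non-empty here: when left unset, the LBFGS
    // constructor (see lbfgs.h) defaults it to max_iter * 5 / 4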
    if (current_evals >= *max_eval)
      break;

    // optimal condition
    if (opt_cond)
      break;

    // lack of progress
    if (val(d.mul(t).abs().max()) <= tolerance_change)
      break;

    if (std::abs(loss - prev_loss) < tolerance_change)
      break;
  }

  return orig_loss;
}

void LBFGS::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}

void LBFGS::load(serialize::InputArchive& archive) {
  IValue pytorch_version;
  if (archive.try_read("pytorch_version", pytorch_version)) {
    serialize(*this, archive);
  } else { // deserializing archives saved in old format (prior to
           // version 1.5.0)
    TORCH_WARN(
        "Your serialized LBFGS optimizer is still using the old serialization format. "
        "The func_evals and n_iter value in state will be set to 0, ro will be set to an empty deque "
        "and al will be set to std::nullopt because the old LBFGS optimizer didn't save these values. "
        "You should re-save your LBFGS optimizer to use the new serialization format.");
    Tensor d, t, H_diag, prev_flat_grad, prev_loss;
    std::deque<Tensor> old_dirs, old_stps;
    archive("d", d, /*is_buffer=*/true);
    archive("t", t, /*is_buffer=*/true);
    archive("H_diag", H_diag, /*is_buffer=*/true);
    archive("prev_flat_grad", prev_flat_grad, /*is_buffer=*/true);
    archive("prev_loss", prev_loss, /*is_buffer=*/true);
    torch::optim::serialize(archive, "old_dirs", old_dirs);
    torch::optim::serialize(archive, "old_stps", old_stps);

    // NOTE: LBFGS has only global state, but we register it as state for
    // the first param, because this helps with casting in load_state_dict
    auto state = std::make_unique<LBFGSParamState>();
    state->d(d);
    state->t(t.item<double>());
    state->H_diag(H_diag);
    state->prev_flat_grad(prev_flat_grad);
    state->prev_loss(prev_loss.item<double>());
    state->old_dirs(old_dirs);
    state->old_stps(old_stps);
    state_[param_groups_.at(0).params().at(0).unsafeGetTensorImpl()] =
        std::move(state);
  }
}
} // namespace optim
} // namespace torch