#include <torch/optim/lbfgs.h>

#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <cmath>
#include <functional>
#include <vector>

namespace torch {
namespace optim {

LBFGSOptions::LBFGSOptions(double lr) : lr_(lr) {}

bool operator==(const LBFGSOptions& lhs, const LBFGSOptions& rhs) {
  return (lhs.lr() == rhs.lr()) && (lhs.max_iter() == rhs.max_iter()) &&
      (lhs.max_eval() == rhs.max_eval()) &&
      (lhs.tolerance_grad() == rhs.tolerance_grad()) &&
      (lhs.tolerance_change() == rhs.tolerance_change()) &&
      (lhs.history_size() == rhs.history_size()) &&
      (lhs.line_search_fn() == rhs.line_search_fn());
}

void LBFGSOptions::serialize(torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(lr);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(max_iter);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(max_eval);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(tolerance_grad);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(tolerance_change);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(history_size);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(line_search_fn);
}

void LBFGSOptions::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lr);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, max_iter);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(int64_t, max_eval);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, tolerance_grad);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, tolerance_change);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, history_size);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(std::string, line_search_fn);
}

double LBFGSOptions::get_lr() const {
  return lr();
}

void LBFGSOptions::set_lr(const double lr) {
  this->lr(lr);
}

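// Element-wise tensor equality for two same-type containers (used below to
// compare the deque/vector members of LBFGSParamState).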
template <typename T>
bool if_container_equal(T lhs, T rhs) {
  if (!(lhs.size() == rhs.size()))
    return false;
  for (const auto i : c10::irange(lhs.size())) {
    if (!torch::equal(lhs.at(i), rhs.at(i)))
      return false;
  }
  return true;
}

bool operator==(const LBFGSParamState& lhs, const LBFGSParamState& rhs) {
  auto isNull = [](const std::optional<std::vector<Tensor>>& val) {
    return val == std::nullopt;
  };
  return (lhs.func_evals() == rhs.func_evals()) &&
      (lhs.n_iter() == rhs.n_iter()) && (lhs.t() == rhs.t()) &&
      (lhs.prev_loss() == rhs.prev_loss()) &&
      torch::equal_if_defined(lhs.d(), rhs.d()) &&
      torch::equal_if_defined(lhs.H_diag(), rhs.H_diag()) &&
      torch::equal_if_defined(lhs.prev_flat_grad(), rhs.prev_flat_grad()) &&
      if_container_equal(lhs.old_dirs(), rhs.old_dirs()) &&
      if_container_equal(lhs.old_stps(), rhs.old_stps()) &&
      if_container_equal(lhs.ro(), rhs.ro()) &&
      ((isNull(lhs.al()) && isNull(rhs.al())) ||
       (!isNull(lhs.al()) && !isNull(rhs.al()) &&
        if_container_equal(*lhs.al(), *rhs.al())));
}

void LBFGSParamState::serialize(
    torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(func_evals);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(n_iter);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(t);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(prev_loss);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(d);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(H_diag);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(prev_flat_grad);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(old_dirs);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(old_stps);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(ro);
  // Python version only serializes state vars if explicitly defined
  if (al() != std::nullopt) {
    _TORCH_OPTIM_SERIALIZE_TORCH_ARG(al);
  }
}

void LBFGSParamState::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, func_evals);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, n_iter);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, t);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, prev_loss);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, d);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, H_diag);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, prev_flat_grad);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(std::deque<Tensor>, old_dirs);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(std::deque<Tensor>, old_stps);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(std::deque<Tensor>, ro);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(std::vector<Tensor>, al);
}

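// Flattens the gradients of all parameters in the (single) param group into
// one 1-D tensor: undefined gradients contribute zeros, and sparse gradients
// are densified before being viewed as flat.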
Tensor LBFGS::_gather_flat_grad() {
  std::vector<Tensor> views;
  for (const auto& p : param_groups_.at(0).params()) {
    if (!p.grad().defined()) {
      views.emplace_back(p.new_empty({p.numel()}).zero_());
    } else if (p.grad().is_sparse()) {
      views.emplace_back(p.grad().to_dense().view(-1));
    } else {
      views.emplace_back(p.grad().view(-1));
    }
  }
  return torch::cat(views, 0);
}

int64_t LBFGS::_numel() {
  if (_numel_cache == std::nullopt) {
    int64_t res = 0;
    for (const auto& p : param_groups_.at(0).params()) {
      res += p.numel();
    }
    _numel_cache = res;
  }
  return *_numel_cache;
}

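// Performs p += step_size * update[offset : offset + p.numel()] for every
// parameter, consuming consecutive slices of a flat vector laid out the same
// way as _gather_flat_grad()'s output.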
void LBFGS::_add_grad(const double step_size, const Tensor& update) {
  int64_t offset = 0;
  for (auto& p : param_groups_.at(0).params()) {
    auto numel = p.numel();
    // view as to avoid deprecated pointwise semantics
    p.add_(
        update.index({at::indexing::Slice(offset, offset + numel)}).view_as(p),
        step_size);
    offset += numel;
  }
  TORCH_INTERNAL_ASSERT(offset == _numel());
}

void LBFGS::_set_param(const std::vector<Tensor>& params_data) {
  auto& _params = param_groups_.at(0).params();
  TORCH_INTERNAL_ASSERT(params_data.size() == _params.size());
  for (const auto i : c10::irange(_params.size())) {
    _params.at(i).copy_(params_data.at(i));
  }
}

std::vector<Tensor> LBFGS::_clone_param() {
  std::vector<Tensor> result;
  for (const auto& p : param_groups_.at(0).params()) {
    result.emplace_back(p.clone(at::MemoryFormat::Contiguous));
  }
  return result;
}

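// Evaluates the closure at (x + t * d): temporarily applies the step via
// _add_grad, records the loss and flattened gradient, then restores the
// original parameter values from x.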
std::tuple<double, Tensor> LBFGS::_directional_evaluate(
    const LossClosure& closure,
    const std::vector<Tensor>& x,
    double t,
    const Tensor& d) {
  _add_grad(t, d);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  double loss;
  {
    torch::AutoGradMode enable_grad(true);
    loss = closure().item<double>();
  }
  auto flat_grad = _gather_flat_grad();
  _set_param(x);
  return std::make_tuple(loss, flat_grad);
}

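// Returns the minimizer of the cubic Hermite interpolant through
// (x1, f1, g1) and (x2, f2, g2), clamped to the interpolation bounds; when
// the cubic has no real critical point (negative discriminant), it falls
// back to the midpoint of the bounds.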
static double _cubic_interpolate(
    double x1,
    double f1,
    double g1,
    double x2,
    double f2,
    double g2,
    std::optional<std::tuple<double, double>> bounds = std::nullopt) {
  // ported from https://github.com/torch/optim/blob/master/polyinterp.lua
  // Compute bounds of interpolation area
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  double xmin_bound, xmax_bound;
  if (bounds != std::nullopt) {
    std::tie(xmin_bound, xmax_bound) = *bounds;
  } else {
    std::tie(xmin_bound, xmax_bound) =
        (x1 <= x2) ? std::make_tuple(x1, x2) : std::make_tuple(x2, x1);
  }
  // Code for most common case: cubic interpolation of 2 points
  // w/ function and derivative values for both
  // Solution in this case (where x2 is the farthest point):
  //   d1 = g1 + g2 - 3*(f1 - f2)/(x1 - x2);
  //   d2 = sqrt(d1^2 - g1*g2);
  //   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
  //   t_new = min(max(min_pos, xmin_bound), xmax_bound);

  auto d1 = (g1 + g2) - (3 * (f1 - f2) / (x1 - x2));
  auto d2_square = std::pow(d1, 2) - g1 * g2;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  double d2;
  if (d2_square >= 0) {
    d2 = std::sqrt(d2_square);
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    double min_pos;
    if (x1 <= x2) {
      min_pos = x2 - ((x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2)));
    } else {
      min_pos = x1 - ((x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2)));
    }
    return std::min(std::max(min_pos, xmin_bound), xmax_bound);
  } else {
    return (xmin_bound + xmax_bound) / 2;
  }
}

using Function = std::function<std::tuple<double, Tensor>(
    const std::vector<Tensor>& x,
    double t,
    const Tensor& d)>;
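// Strong Wolfe line search (cf. Nocedal & Wright, "Numerical Optimization",
// Algorithms 3.5/3.6): a bracketing phase grows the step until it either
// satisfies the strong Wolfe conditions outright or brackets a minimizer,
// then a zoom phase shrinks the bracket with cubic interpolation until the
// sufficient-decrease (c1) and curvature (c2) conditions hold.
// Returns (f_new, g_new, t, ls_func_evals).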
static std::tuple<double, Tensor, double, int64_t> _strong_wolfe(
    const Function& obj_func,
    const std::vector<Tensor>& x,
    double t,
    const Tensor& d,
    double f,
    Tensor g,
    const Tensor& gtd,
    double c1 = 1e-4,
    double c2 = 0.9, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
    double tolerance_change = 1e-9,
    double max_ls = 25) { // NOLINT(cppcoreguidelines-avoid-magic-numbers)

  auto val = [](const Tensor& t) { return t.item<double>(); };

  auto d_norm = val(d.abs().max());
  g = g.clone(at::MemoryFormat::Contiguous);
  // evaluate objective and gradient using initial step
  auto [f_new, g_new] = obj_func(x, t, d);
  int64_t ls_func_evals = 1;
  auto gtd_new = g_new.dot(d);

  // bracket an interval containing a point satisfying the Wolfe criteria
  double t_prev = 0;
  auto f_prev = f;
  auto g_prev = g;
  auto gtd_prev = gtd;
  bool done = false;
  auto ls_iter = 0;
  std::vector<double> bracket, bracket_f;
  std::vector<Tensor> bracket_g, bracket_gtd;

  while (ls_iter < max_ls) {
    // check conditions
    if ((f_new > (f + c1 * t * val(gtd))) ||
        (ls_iter > 1 && (f_new >= f_prev))) {
      bracket = {t_prev, t};
      bracket_f = {f_prev, f_new};
      bracket_g = {g_prev, g_new.clone(at::MemoryFormat::Contiguous)};
      bracket_gtd = {gtd_prev, gtd_new};
      break;
    }
    if (std::abs(val(gtd_new)) <= (-c2 * val(gtd))) {
      bracket = {t, t};
      bracket_f = {f_new, f_new};
      bracket_g = {g_new, g_new};
      done = true;
      break;
    }
    if (val(gtd_new) >= 0) {
      bracket = {t_prev, t};
      bracket_f = {f_prev, f_new};
      bracket_g = {g_prev, g_new.clone(at::MemoryFormat::Contiguous)};
      bracket_gtd = {gtd_prev, gtd_new};
      break;
    }
    // interpolate
    auto min_step = t +
        0.01 * (t - t_prev); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
    auto max_step = t * 10; // NOLINT(cppcoreguidelines-avoid-magic-numbers)
    auto tmp = t;
    t = _cubic_interpolate(
        t_prev,
        f_prev,
        val(gtd_prev),
        t,
        f_new,
        val(gtd_new),
        std::make_tuple(min_step, max_step));
    // next step
    t_prev = tmp;
    f_prev = f_new;
    g_prev = g_new.clone(at::MemoryFormat::Contiguous);
    gtd_prev = gtd_new;
    std::tie(f_new, g_new) = obj_func(x, t, d);
    ls_func_evals += 1;
    gtd_new = g_new.dot(d);
    ls_iter += 1;
  }
  // reached max number of iterations?
  if (ls_iter == max_ls) {
    bracket = {0, t};
    bracket_f = {f, f_new};
    bracket_g = {g, g_new};
  }

  // zoom phase: we now have a point satisfying the criteria, or
  // a bracket around it. We refine the bracket until we find the
  // exact point satisfying the criteria
  bool insuf_progress = false;
  // find high and low points in bracket
  auto [low_pos, high_pos] = bracket_f[0] <= bracket_f[1]
      ? std::make_tuple(0, 1)
      : std::make_tuple(1, 0);
  while (!done && (ls_iter < max_ls)) {
    // compute new trial value
    t = _cubic_interpolate(
        bracket[0],
        bracket_f[0],
        val(bracket_gtd[0]),
        bracket[1],
        bracket_f[1],
        val(bracket_gtd[1]));

    // test that we are making sufficient progress:
    // if `t` is too close to a boundary, we mark that we are making
    // insufficient progress, and if
    //   + we made insufficient progress in the last step, or
    //   + `t` is at one of the boundaries,
    // we move `t` to a position `0.1 * len(bracket)` away from the
    // nearest boundary point.
    double bracket_max = std::max(bracket[0], bracket[1]);
    auto bracket_min = std::min(bracket[0], bracket[1]);
    auto eps = 0.1 * // NOLINT(cppcoreguidelines-avoid-magic-numbers)
        (bracket_max - bracket_min);
    if (std::min(bracket_max - t, t - bracket_min) < eps) {
      // interpolation close to boundary
      if (insuf_progress || (t >= bracket_max) || (t <= bracket_min)) {
        // evaluate at 0.1 away from boundary
        t = (std::abs(t - bracket_max) < std::abs(t - bracket_min))
            ? bracket_max - eps
            : bracket_min + eps;
        insuf_progress = false;
      } else {
        insuf_progress = true;
      }
    } else {
      insuf_progress = false;
    }

    // Evaluate new point
    std::tie(f_new, g_new) = obj_func(x, t, d);
    ls_func_evals += 1;
    gtd_new = g_new.dot(d);
    ls_iter += 1;

    if ((f_new > (f + c1 * t * val(gtd))) || (f_new >= bracket_f[low_pos])) {
      // Armijo condition not satisfied or not lower than lowest point
      bracket[high_pos] = t;
      bracket_f[high_pos] = f_new;
      bracket_g[high_pos] = g_new.clone(at::MemoryFormat::Contiguous);
      bracket_gtd[high_pos] = gtd_new;
      std::tie(low_pos, high_pos) = bracket_f[0] <= bracket_f[1]
          ? std::make_tuple(0, 1)
          : std::make_tuple(1, 0);
    } else {
      if (val(at::abs(gtd_new)) <= (-c2 * val(gtd))) {
        // Wolfe conditions satisfied
        done = true;
      } else if ((val(gtd_new) * (bracket[high_pos] - bracket[low_pos])) >= 0) {
        // old high becomes new low
        bracket[high_pos] = bracket[low_pos];
        bracket_f[high_pos] = bracket_f[low_pos];
        bracket_g[high_pos] = bracket_g[low_pos];
        bracket_gtd[high_pos] = bracket_gtd[low_pos];
      }

      // new point becomes new low
      bracket[low_pos] = t;
      bracket_f[low_pos] = f_new;
      bracket_g[low_pos] = g_new.clone(at::MemoryFormat::Contiguous);
      bracket_gtd[low_pos] = gtd_new;
    }

    // line-search bracket has become too small
    if ((std::abs(bracket[1] - bracket[0]) * d_norm) < tolerance_change)
      break;
  }

  // return the best point found
  t = bracket[low_pos];
  f_new = bracket_f[low_pos];
  g_new = bracket_g[low_pos];
  return std::make_tuple(f_new, g_new, t, ls_func_evals);
}

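// Example usage (a minimal sketch; `model`, `inputs`, and `targets` are
// assumed to be defined by the caller):
//
//   torch::optim::LBFGS optimizer(
//       model->parameters(),
//       torch::optim::LBFGSOptions(1.0).max_iter(20).line_search_fn(
//           "strong_wolfe"));
//   auto final_loss = optimizer.step([&]() -> torch::Tensor {
//     optimizer.zero_grad();
//     auto loss = torch::mse_loss(model->forward(inputs), targets);
//     loss.backward();
//     return loss;
//   });
//
// Unlike first-order optimizers, step() takes a closure that re-evaluates
// the model and returns the loss, since L-BFGS may evaluate it many times
// per call.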
Tensor LBFGS::step(LossClosure closure) {
  NoGradGuard no_grad;
  TORCH_CHECK(closure != nullptr, "LBFGS requires a closure function");
  TORCH_INTERNAL_ASSERT(param_groups_.size() == 1);
  auto val = [](const Tensor& t) { return t.item<double>(); };

  auto& group = param_groups_.at(0);
  auto& _params = group.params();
  const auto& options = static_cast<const LBFGSOptions&>(group.options());
  auto lr = options.lr();
  auto max_iter = options.max_iter();
  auto max_eval = options.max_eval();
  auto tolerance_grad = options.tolerance_grad();
  auto tolerance_change = options.tolerance_change();
  auto line_search_fn = options.line_search_fn();
  auto history_size = options.history_size();

  // NOTE: LBFGS has only global state, but we register it as state for
  // the first param, because this helps with casting in load_state_dict
  auto param_state = state_.find(_params.at(0).unsafeGetTensorImpl());
  if (param_state == state_.end()) {
    state_[_params.at(0).unsafeGetTensorImpl()] =
        std::make_unique<LBFGSParamState>();
  }
  auto& state = static_cast<LBFGSParamState&>(
      *state_[_params.at(0).unsafeGetTensorImpl()]);
  // evaluate initial f(x) and df/dx
  Tensor orig_loss;
  {
    torch::AutoGradMode enable_grad(true);
    orig_loss = closure();
  }

  auto loss = val(orig_loss);
  auto current_evals = 1;
  state.func_evals(state.func_evals() + 1);
  auto flat_grad = _gather_flat_grad();
  auto opt_cond = (val(flat_grad.abs().max()) <= tolerance_grad);

  // optimal condition
  if (opt_cond) {
    return orig_loss;
  }

  // tensors cached in state (for tracing)
  auto& d = state.d();
  auto& t = state.t();
  auto& old_dirs = state.old_dirs();
  auto& old_stps = state.old_stps();
  auto& ro = state.ro();
  auto& H_diag = state.H_diag();
  auto& prev_flat_grad = state.prev_flat_grad();
  auto& prev_loss = state.prev_loss();

  int n_iter = 0;

  // optimize for a max of max_iter iterations
  while (n_iter < max_iter) {
    // keep track of nb of iterations
    n_iter += 1;
    state.n_iter(state.n_iter() + 1);

    // compute gradient descent direction
    if (state.n_iter() == 1) {
      d = flat_grad.neg();
      H_diag = torch::tensor(1);
      old_dirs = {};
      old_stps = {};
      ro = {};
    } else {
      // do lbfgs update (update memory)
      auto y = flat_grad.sub(prev_flat_grad);
      auto s = d.mul(t);
      auto ys = y.dot(s); // y*s
      if (val(ys) > 1e-10) { // NOLINT(cppcoreguidelines-avoid-magic-numbers)
        // updating memory
        if (static_cast<int64_t>(old_dirs.size()) == history_size) {
          // shift history by one (limited-memory)
          old_dirs.pop_front();
          old_stps.pop_front();
          ro.pop_front();
        }
        // store new direction/step
        old_dirs.emplace_back(y);
        old_stps.emplace_back(s);
        ro.emplace_back(1. / ys);

        // update scale of initial Hessian approximation
        H_diag = ys / y.dot(y); // (y*y)
      }

      // compute the approximate (L-BFGS) inverse Hessian
      // multiplied by the gradient
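      // Two-loop recursion (Nocedal & Wright, Algorithm 7.4): using the
      // stored pairs (s_i in old_stps, y_i in old_dirs) and
      // ro_i = 1 / (y_i^T s_i), this computes d = -H * flat_grad without
      // ever materializing the approximate inverse Hessian H.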
      int64_t num_old = static_cast<int64_t>(old_dirs.size());

      if (state.al() == std::nullopt) {
        state.al(std::vector<Tensor>(history_size));
      }
      auto& al = state.al();

      // iteration in L-BFGS loop collapsed to use just one buffer
      auto q = flat_grad.neg();
      for (int64_t i = num_old - 1; i > -1; i--) {
        (*al).at(i) = old_stps.at(i).dot(q) * ro.at(i);
        q.add_(old_dirs.at(i), -val((*al).at(i)));
      }

      // multiply by initial Hessian
      // r/d is the final direction
      auto r = torch::mul(q, H_diag);
      d = r;
      for (const auto i : c10::irange(num_old)) {
        auto be_i = old_dirs.at(i).dot(r) * ro.at(i);
        r.add_(old_stps.at(i), val((*al).at(i) - be_i));
      }
    }

    if (!prev_flat_grad.defined()) {
      prev_flat_grad = flat_grad.clone(at::MemoryFormat::Contiguous);
    } else {
      prev_flat_grad.copy_(flat_grad);
    }
    prev_loss = loss;

    // ############################################################
    // # compute step length
    // ############################################################
    // reset initial guess for step size
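    // on the first iteration, scale lr by min(1, 1 / ||g||_1) so the step
    // stays conservative for large gradients; afterwards the two-loop
    // direction is already scaled by H_diag, so lr is used directly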
    if (state.n_iter() == 1) {
      t = std::min(1., 1. / val(flat_grad.abs().sum())) * lr;
    } else {
      t = lr;
    }

    // directional derivative
    auto gtd = flat_grad.dot(d); // g * d

    // directional derivative is below tolerance
    if (val(gtd) > -tolerance_change)
      break;

    // optional line search: user function
    auto ls_func_evals = 0;
    if (line_search_fn != std::nullopt) {
      TORCH_CHECK(
          *line_search_fn == "strong_wolfe",
          "only 'strong_wolfe' is supported");
      auto x_init = _clone_param();
      auto obj_func =
          [&](const std::vector<Tensor>& x, double t, const Tensor& d) {
            return _directional_evaluate(closure, x, t, d);
          };
      std::tie(loss, flat_grad, t, ls_func_evals) =
          _strong_wolfe(obj_func, x_init, t, d, loss, flat_grad, gtd);
      _add_grad(t, d);
      opt_cond = (val(flat_grad.abs().max()) <= tolerance_grad);
    } else {
      // no line search, simply move with a fixed step
      _add_grad(t, d);
      if (n_iter != max_iter) {
        // re-evaluate the function only if not in the last iteration;
        // the reason we do this: in a stochastic setting, there is no
        // use in re-evaluating that function here
        {
          torch::AutoGradMode enable_grad(true);
          loss = val(closure());
        }
        flat_grad = _gather_flat_grad();
        opt_cond = val(torch::max(flat_grad.abs())) <= tolerance_grad;
        ls_func_evals = 1;
      }
    }
    // update func eval
    current_evals += ls_func_evals;
    state.func_evals(state.func_evals() + ls_func_evals);

    // ############################################################
    // # check conditions
    // ############################################################
    if (n_iter == max_iter)
      break;

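    // max_eval is populated by the LBFGS constructor (as max_iter * 5 / 4)
    // when not user-specified, so the dereference below is expected to be
    // safe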
    if (current_evals >= *max_eval)
      break;

    // optimal condition
    if (opt_cond)
      break;

    // lack of progress
    if (val(d.mul(t).abs().max()) <= tolerance_change)
      break;

    if (std::abs(loss - prev_loss) < tolerance_change)
      break;
  }

  return orig_loss;
}

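// Persisting optimizer state (a sketch; `optimizer` is an LBFGS instance
// constructed elsewhere):
//
//   torch::save(optimizer, "lbfgs_state.pt");
//   torch::load(optimizer, "lbfgs_state.pt");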
void LBFGS::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}

void LBFGS::load(serialize::InputArchive& archive) {
  IValue pytorch_version;
  if (archive.try_read("pytorch_version", pytorch_version)) {
    serialize(*this, archive);
  } else { // deserializing archives saved in old format (prior to
           // version 1.5.0)
    TORCH_WARN(
        "Your serialized LBFGS optimizer is still using the old serialization format. "
        "The func_evals and n_iter value in state will be set to 0, ro will be set to an empty deque "
        "and al will be set to std::nullopt because the old LBFGS optimizer didn't save these values. "
        "You should re-save your LBFGS optimizer to use the new serialization format.");
    Tensor d, t, H_diag, prev_flat_grad, prev_loss;
    std::deque<Tensor> old_dirs, old_stps;
    archive("d", d, /*is_buffer=*/true);
    archive("t", t, /*is_buffer=*/true);
    archive("H_diag", H_diag, /*is_buffer=*/true);
    archive("prev_flat_grad", prev_flat_grad, /*is_buffer=*/true);
    archive("prev_loss", prev_loss, /*is_buffer=*/true);
    torch::optim::serialize(archive, "old_dirs", old_dirs);
    torch::optim::serialize(archive, "old_stps", old_stps);

    // NOTE: LBFGS has only global state, but we register it as state for
    // the first param, because this helps with casting in load_state_dict
    auto state = std::make_unique<LBFGSParamState>();
    state->d(d);
    state->t(t.item<double>());
    state->H_diag(H_diag);
    state->prev_flat_grad(prev_flat_grad);
    state->prev_loss(prev_loss.item<double>());
    state->old_dirs(old_dirs);
    state->old_stps(old_stps);
    state_[param_groups_.at(0).params().at(0).unsafeGetTensorImpl()] =
        std::move(state);
  }
}
} // namespace optim
} // namespace torch
