// torch/csrc/api/src/optim/schedulers/lr_scheduler.cpp
#include <c10/util/irange.h>
#include <torch/optim/schedulers/lr_scheduler.h>

namespace torch {
namespace optim {

// Construct a scheduler over the given optimizer. Only a reference is
// stored, so the optimizer must outlive the scheduler.
LRScheduler::LRScheduler(torch::optim::Optimizer& optimizer)
    : optimizer_(optimizer) {}

// Advance the schedule by one step: ask the concrete scheduler for the new
// learning rates via get_lrs(), apply them to the optimizer, and increment
// the step count.
void LRScheduler::step() {
  std::vector<double> learning_rates = get_lrs();
  set_optimizer_lrs(learning_rates);
  step_count_++;
}

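// Example: a concrete scheduler only needs to override get_lrs(). A minimal
// sketch (MyStepLR is a hypothetical name; it assumes the protected
// get_lrs()/get_current_lrs()/step_count_ interface declared in
// lr_scheduler.h) that multiplies every learning rate by `gamma` once per
// `step_size` steps, similar in spirit to torch's StepLR:
//
//   struct MyStepLR : public LRScheduler {
//     MyStepLR(torch::optim::Optimizer& optimizer,
//              unsigned step_size,
//              double gamma)
//         : LRScheduler(optimizer), step_size_(step_size), gamma_(gamma) {}
//
//    private:
//     std::vector<double> get_lrs() override {
//       std::vector<double> lrs = get_current_lrs();
//       // step_count_ is 0 on the first call to step(), so no decay happens
//       // until step_size_ steps have elapsed.
//       if (step_count_ != 0 && step_count_ % step_size_ == 0) {
//         for (double& lr : lrs) {
//           lr *= gamma_;  // decay every param group's learning rate
//         }
//       }
//       return lrs;
//     }
//     unsigned step_size_;
//     double gamma_;
//   };
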
// Apply one learning rate per parameter group, in group order.
void LRScheduler::set_optimizer_lrs(const std::vector<double>& learning_rates) {
  // Check that the number of learning rates is equal to the number of
  // parameter groups in the optimizer.
  TORCH_CHECK(
      learning_rates.size() == optimizer_.param_groups().size(),
      "Number of learning rates not equal to the number of param groups\n",
      "Number of learning rates given: ",
      learning_rates.size(),
      "\nNumber of param groups: ",
      optimizer_.param_groups().size());

  for (const auto i : c10::irange(optimizer_.param_groups().size())) {
    optimizer_.param_groups()[i].options().set_lr(learning_rates[i]);
  }
}

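// Note: rates are applied per parameter group, so an optimizer built with
// several groups can be scheduled independently. A sketch (hypothetical
// variable names, assuming the torch::optim::OptimizerParamGroup / SGD API):
//
//   std::vector<torch::optim::OptimizerParamGroup> groups = {
//       torch::optim::OptimizerParamGroup(backbone_params),
//       torch::optim::OptimizerParamGroup(head_params)};
//   torch::optim::SGD optimizer(groups, torch::optim::SGDOptions(0.1));
//   // A scheduler's get_lrs() must now return exactly two values, one per
//   // group, or the TORCH_CHECK above fires.
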
// Return the current learning rate of each parameter group, in group order.
std::vector<double> LRScheduler::get_current_lrs() const {
  std::vector<double> learning_rates(optimizer_.param_groups().size());
  for (const auto i : c10::irange(optimizer_.param_groups().size())) {
    learning_rates[i] = optimizer_.param_groups()[i].options().get_lr();
  }
  return learning_rates;
}

} // namespace optim
} // namespace torch
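
// Usage sketch (assumes torch::optim::SGD and the StepLR scheduler declared
// in <torch/optim/schedulers/step_lr.h>; `model`, `train_one_epoch`, and
// `num_epochs` are hypothetical):
//
//   torch::optim::SGD optimizer(model->parameters(),
//                               torch::optim::SGDOptions(/*lr=*/0.1));
//   torch::optim::StepLR scheduler(optimizer, /*step_size=*/30, /*gamma=*/0.1);
//   for (int epoch = 0; epoch < num_epochs; ++epoch) {
//     train_one_epoch(model, optimizer);
//     scheduler.step();  // lr *= 0.1 once every 30 epochs
//   }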