1 #include <c10/util/irange.h> 2 #include <torch/optim/schedulers/lr_scheduler.h> 3 4 namespace torch { 5 namespace optim { 6 LRScheduler(torch::optim::Optimizer & optimizer)7LRScheduler::LRScheduler(torch::optim::Optimizer& optimizer) 8 : optimizer_(optimizer) {} 9 step()10void LRScheduler::step() { 11 std::vector<double> learning_rates = get_lrs(); 12 set_optimizer_lrs(learning_rates); 13 step_count_++; 14 } 15 set_optimizer_lrs(const std::vector<double> & learning_rates)16void LRScheduler::set_optimizer_lrs(const std::vector<double>& learning_rates) { 17 // Check the number of learning rates is equal to the number of parameters 18 // groups in the optimizer 19 TORCH_CHECK( 20 learning_rates.size() == optimizer_.param_groups().size(), 21 "Number of learning rates not equal to the number of param groups\n", 22 "Number of learning rates given: ", 23 learning_rates.size(), 24 "\nNumber of param groups: ", 25 optimizer_.param_groups().size()); 26 27 for (const auto i : c10::irange(optimizer_.param_groups().size())) { 28 optimizer_.param_groups()[i].options().set_lr(learning_rates[i]); 29 } 30 } 31 get_current_lrs() const32std::vector<double> LRScheduler::get_current_lrs() const { 33 std::vector<double> learnings_rates(optimizer_.param_groups().size()); 34 if (!learnings_rates.empty()) { 35 for (const auto i : c10::irange(optimizer_.param_groups().size())) { 36 learnings_rates[i] = optimizer_.param_groups()[i].options().get_lr(); 37 } 38 } 39 return learnings_rates; 40 } 41 42 } // namespace optim 43 } // namespace torch 44