1 #pragma once 2 3 #include <torch/optim/optimizer.h> 4 5 #include <torch/csrc/Export.h> 6 7 namespace torch { 8 namespace optim { 9 10 class TORCH_API LRScheduler { 11 public: 12 // This class needs to take a reference of an optimizer from outside such that 13 // it can modify its learning rates; due to this the lifetime of said 14 // optimizer must be maintained 15 LRScheduler(torch::optim::Optimizer& optimizer); 16 17 virtual ~LRScheduler() = default; 18 19 void step(); 20 21 protected: 22 // A vector of learning rates is calculated and returned from the specific 23 // subclass. A vector is returned with each element being a separate learning 24 // rate for each param group - although the normal use case would be to return 25 // a vector of identical elements. 26 virtual std::vector<double> get_lrs() = 0; 27 28 // Get current learning rates from the optimizer 29 std::vector<double> get_current_lrs() const; 30 31 unsigned step_count_{}; 32 33 private: 34 void set_optimizer_lrs(const std::vector<double>& learning_rates); 35 36 torch::optim::Optimizer& optimizer_; 37 }; 38 } // namespace optim 39 } // namespace torch 40