#pragma once

#include <torch/optim/optimizer.h>

#include <torch/csrc/Export.h>

namespace torch {
namespace optim {

class TORCH_API LRScheduler {
 public:
  // This class holds a reference to an optimizer supplied by the caller so
  // that it can modify the optimizer's learning rates. The caller must
  // therefore keep the optimizer alive for the lifetime of the scheduler.
  LRScheduler(torch::optim::Optimizer& optimizer);

  virtual ~LRScheduler() = default;

  void step();

 protected:
  // Computes the new learning rates; implemented by each concrete subclass.
  // The returned vector holds one learning rate per param group, although in
  // the common case every element is identical.
  virtual std::vector<double> get_lrs() = 0;

  // Get current learning rates from the optimizer
  std::vector<double> get_current_lrs() const;

  unsigned step_count_{};

 private:
  void set_optimizer_lrs(const std::vector<double>& learning_rates);

  torch::optim::Optimizer& optimizer_;
};
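
// A minimal sketch of a concrete scheduler. `HalvingLR` is hypothetical and
// not part of PyTorch; it assumes that step() applies whatever get_lrs()
// returns. It halves every param group's learning rate on each step() call.
class HalvingLR : public LRScheduler {
 public:
  explicit HalvingLR(torch::optim::Optimizer& optimizer)
      : LRScheduler(optimizer) {}

 private:
  std::vector<double> get_lrs() override {
    // Start from the optimizer's current rates so that each param group
    // keeps its own (possibly distinct) value.
    std::vector<double> lrs = get_current_lrs();
    for (double& lr : lrs) {
      lr /= 2.0;
    }
    return lrs;
  }
};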
} // namespace optim
} // namespace torch
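
// Usage sketch (hypothetical): StepLR is a concrete scheduler declared in
// torch/optim/schedulers/step_lr.h; `model`, `num_epochs`, and
// `train_one_epoch` are assumed to exist. Because LRScheduler only holds a
// reference, the optimizer must outlive the scheduler:
//
//   torch::optim::SGD optimizer(model->parameters(),
//                               torch::optim::SGDOptions(/*lr=*/0.1));
//   torch::optim::StepLR scheduler(optimizer, /*step_size=*/30, /*gamma=*/0.1);
//   for (size_t epoch = 0; epoch < num_epochs; ++epoch) {
//     train_one_epoch(model, optimizer); // hypothetical training helper
//     scheduler.step(); // recompute and apply the learning rates
//   }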