xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/optim/schedulers/step_lr.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/optim/schedulers/step_lr.h>
2 
3 namespace torch {
4 namespace optim {
5 
StepLR(torch::optim::Optimizer & optimizer,const unsigned step_size,const double gamma)6 StepLR::StepLR(
7     torch::optim::Optimizer& optimizer,
8     const unsigned step_size,
9     const double gamma)
10     : LRScheduler(optimizer), step_size_(step_size), gamma_(gamma) {}
11 
get_lrs()12 std::vector<double> StepLR::get_lrs() {
13   if (step_count_ == 0 || step_count_ % step_size_ != 0)
14     return get_current_lrs();
15   else {
16     std::vector<double> lrs = get_current_lrs();
17     std::transform(
18         lrs.begin(), lrs.end(), lrs.begin(), [this](const double& v) {
19           return this->gamma_ * v;
20         });
21     return lrs;
22   }
23 }
24 
25 } // namespace optim
26 } // namespace torch
27