1 #include <torch/optim/schedulers/step_lr.h> 2 3 namespace torch { 4 namespace optim { 5 StepLR(torch::optim::Optimizer & optimizer,const unsigned step_size,const double gamma)6StepLR::StepLR( 7 torch::optim::Optimizer& optimizer, 8 const unsigned step_size, 9 const double gamma) 10 : LRScheduler(optimizer), step_size_(step_size), gamma_(gamma) {} 11 get_lrs()12std::vector<double> StepLR::get_lrs() { 13 if (step_count_ == 0 || step_count_ % step_size_ != 0) 14 return get_current_lrs(); 15 else { 16 std::vector<double> lrs = get_current_lrs(); 17 std::transform( 18 lrs.begin(), lrs.end(), lrs.begin(), [this](const double& v) { 19 return this->gamma_ * v; 20 }); 21 return lrs; 22 } 23 } 24 25 } // namespace optim 26 } // namespace torch 27