1# Owner(s): ["module: optimizer", "module: LrScheduler" ] 2import copy 3import math 4import pickle 5import tempfile 6import types 7import warnings 8from functools import partial 9 10import torch 11import torch.nn.functional as F 12from torch.nn import Parameter 13from torch.optim import Adam, Rprop, SGD 14from torch.optim.lr_scheduler import ( 15 ChainedScheduler, 16 ConstantLR, 17 CosineAnnealingLR, 18 CosineAnnealingWarmRestarts, 19 CyclicLR, 20 EPOCH_DEPRECATION_WARNING, 21 ExponentialLR, 22 LambdaLR, 23 LinearLR, 24 LRScheduler, 25 MultiplicativeLR, 26 MultiStepLR, 27 OneCycleLR, 28 PolynomialLR, 29 ReduceLROnPlateau, 30 SequentialLR, 31 StepLR, 32) 33from torch.optim.swa_utils import SWALR 34from torch.testing._internal.common_utils import ( 35 instantiate_parametrized_tests, 36 load_tests, 37 parametrize, 38 skipIfTorchDynamo, 39 TestCase, 40) 41 42 43# load_tests from common_utils is used to automatically filter tests for 44# sharding on sandcastle. This line silences flake warnings 45load_tests = load_tests 46 47 48class TestLRScheduler(TestCase): 49 class SchedulerTestNet(torch.nn.Module): 50 def __init__(self) -> None: 51 super().__init__() 52 self.conv1 = torch.nn.Conv2d(1, 1, 1) 53 self.conv2 = torch.nn.Conv2d(1, 1, 1) 54 55 def forward(self, x): 56 return self.conv2(F.relu(self.conv1(x))) 57 58 class LambdaLRTestObject: 59 def __init__(self, value): 60 self.value = value 61 62 def __call__(self, epoch): 63 return self.value * epoch 64 65 def __eq__(self, other): 66 if isinstance(other, self.__class__): 67 return self.__dict__ == other.__dict__ 68 else: 69 return False 70 71 exact_dtype = True 72 73 def setUp(self): 74 super().setUp() 75 self.net = self.SchedulerTestNet() 76 self.opt = SGD( 77 [ 78 {"params": self.net.conv1.parameters()}, 79 {"params": self.net.conv2.parameters(), "lr": 0.5}, 80 ], 81 lr=0.05, 82 ) 83 84 def _check_warning_is_epoch_deprecation_warning(self, w, *, num_warnings: int = 1): 85 """This function swallows the epoch deprecation warning which is produced when we 86 call `scheduler.step(epoch)` with some not `None` value of `epoch`. 87 this is deprecated, and this function will need to be removed/updated when 88 the schedulers no longer accept the parameter at all. 89 """ 90 self.assertEqual(len(w), num_warnings) 91 for warning in w: 92 self.assertEqual(len(warning.message.args), 1) 93 self.assertEqual(warning.message.args[0], EPOCH_DEPRECATION_WARNING) 94 95 def test_error_when_getlr_has_epoch(self): 96 class MultiStepLR(torch.optim.lr_scheduler.LRScheduler): 97 def __init__(self, optimizer, gamma, milestones, last_epoch=-1): 98 self.init_lr = [group["lr"] for group in optimizer.param_groups] 99 self.gamma = gamma 100 self.milestones = milestones 101 super().__init__(optimizer, last_epoch) 102 103 def get_lr(self, step): 104 global_step = self.last_epoch 105 gamma_power = ( 106 [0] 107 + [i + 1 for i, m in enumerate(self.milestones) if global_step >= m] 108 )[-1] 109 return [ 110 init_lr * (self.gamma**gamma_power) for init_lr in self.init_lr 111 ] 112 113 optimizer = SGD([torch.rand(1)], lr=1) 114 115 with self.assertRaises(TypeError): 116 scheduler = MultiStepLR(optimizer, gamma=1, milestones=[10, 20]) 117 118 @skipIfTorchDynamo( 119 "Torchdynamo keeps references to optim in the guards and the stack of the graph break frames" 120 ) 121 def test_no_cyclic_references(self): 122 import gc 123 124 param = Parameter(torch.empty(10)) 125 optim = SGD([param], lr=0.5) 126 scheduler = LambdaLR(optim, lambda epoch: 1.0) 127 del scheduler 128 129 self.assertTrue( 130 len(gc.get_referrers(optim)) == 0, 131 "Optimizer should contain no cyclic references", 132 ) 133 134 gc.collect() 135 del optim 136 self.assertEqual( 137 gc.collect(), 0, msg="Optimizer should be garbage-collected on __del__" 138 ) 139 140 @skipIfTorchDynamo( 141 "Torchdynamo keeps references to optim in the guards and the stack of the graph break frames" 142 ) 143 def test_no_cyclic_references_in_step(self): 144 import gc 145 import weakref 146 147 def run(): 148 param = torch.empty(10, requires_grad=True) 149 optim = SGD(params=[param], lr=0.5) 150 scheduler = LambdaLR(optim, lambda epoch: 1.0) 151 param.sum().backward() 152 optim.step() 153 scheduler.step() 154 155 return weakref.ref(scheduler) 156 157 # To ensure that there are no reference cycles in scheduler, 158 # we need to turn off the garbage collector. Since gc will 159 # automatically collect unreachable objects. 160 gc.disable() 161 ref = run() 162 163 assert ref() is None 164 gc.enable() # restore 165 166 def test_old_pattern_warning(self): 167 epochs = 35 168 with warnings.catch_warnings(record=True) as ws: 169 warnings.simplefilter("always") # allow any warning to be raised 170 scheduler = StepLR(self.opt, gamma=0.1, step_size=3) 171 self.assertTrue(len(ws) == 0, "No warning should be raised") 172 173 def old_pattern(): 174 for _ in range(epochs): 175 scheduler.step() 176 self.opt.step() 177 178 self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern) 179 180 def test_old_pattern_warning_with_arg(self): 181 epochs = 35 182 with warnings.catch_warnings(record=True) as ws: 183 warnings.simplefilter("always") # allow any warning to be raised 184 scheduler = StepLR(self.opt, gamma=0.1, step_size=3) 185 self.assertTrue(len(ws) == 0, "No warning should be raised") 186 187 def old_pattern2(): 188 for _ in range(epochs): 189 scheduler.step() 190 self.opt.step() 191 192 self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2) 193 194 def test_old_pattern_warning_resuming(self): 195 epochs = 35 196 for i, group in enumerate(self.opt.param_groups): 197 group["initial_lr"] = 0.01 198 199 with warnings.catch_warnings(record=True) as ws: 200 warnings.simplefilter("always") # allow any warning to be raised 201 scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10) 202 self.assertTrue(len(ws) == 0, "No warning should be raised") 203 204 def old_pattern(): 205 for _ in range(epochs): 206 scheduler.step() 207 self.opt.step() 208 209 self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern) 210 211 def test_old_pattern_warning_resuming_with_arg(self): 212 epochs = 35 213 for i, group in enumerate(self.opt.param_groups): 214 group["initial_lr"] = 0.01 215 216 with warnings.catch_warnings(record=True) as ws: 217 warnings.simplefilter("always") # allow any warning to be raised 218 scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10) 219 self.assertTrue(len(ws) == 0, "No warning should be raised") 220 221 def old_pattern2(): 222 for _ in range(epochs): 223 scheduler.step() 224 self.opt.step() 225 226 self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2) 227 228 def test_old_pattern_warning_with_overridden_optim_step(self): 229 epochs = 35 230 for i, group in enumerate(self.opt.param_groups): 231 group["initial_lr"] = 0.01 232 233 with warnings.catch_warnings(record=True) as ws: 234 warnings.simplefilter("always") # allow any warning to be raised 235 scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10) 236 self.assertTrue(len(ws) == 0, "No warning should be raised") 237 238 # emulate use-case with optimizer.step overridden 239 import types 240 241 old_step = self.opt.step 242 243 def new_step(o, *args, **kwargs): 244 retval = old_step(*args, **kwargs) 245 return retval 246 247 self.opt.step = types.MethodType(new_step, self.opt) 248 249 def old_pattern2(): 250 for _ in range(epochs): 251 scheduler.step() 252 self.opt.step() 253 254 self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2) 255 256 def test_new_pattern_no_warning(self): 257 epochs = 35 258 with warnings.catch_warnings(record=True) as ws: 259 warnings.simplefilter("always") # allow any warning to be raised 260 scheduler = StepLR(self.opt, gamma=0.1, step_size=3) 261 self.assertTrue(len(ws) == 0, "No warning should be raised") 262 263 with warnings.catch_warnings(record=True) as ws: 264 warnings.simplefilter("always") # allow any warning to be raised 265 for _ in range(epochs): 266 self.opt.step() 267 scheduler.step() 268 self.assertTrue(len(ws) == 0, "No warning should be raised") 269 270 def test_new_pattern_no_warning_with_arg(self): 271 epochs = 35 272 with warnings.catch_warnings(record=True) as ws: 273 warnings.simplefilter("always") # allow any warning to be raised 274 scheduler = StepLR(self.opt, gamma=0.1, step_size=3) 275 self.assertTrue(len(ws) == 0, "No warning should be raised") 276 277 with warnings.catch_warnings(record=True) as ws: 278 warnings.simplefilter("always") # allow any warning to be raised 279 for _ in range(epochs): 280 self.opt.step() 281 scheduler.step() 282 self.assertTrue(len(ws) == 0, "No warning should be raised") 283 284 def test_new_pattern_no_warning_with_overridden_optim_step(self): 285 epochs = 35 286 with warnings.catch_warnings(record=True) as ws: 287 warnings.simplefilter("always") # allow any warning to be raised 288 scheduler = StepLR(self.opt, gamma=0.1, step_size=3) 289 self.assertTrue(len(ws) == 0, "No warning should be raised") 290 291 # emulate use-case with optimizer.step overridden 292 import types 293 294 old_step = self.opt.step 295 296 def new_step(o, *args, **kwargs): 297 retval = old_step(*args, **kwargs) 298 return retval 299 300 self.opt.step = types.MethodType(new_step, self.opt) 301 302 def new_pattern(): 303 for e in range(epochs): 304 self.opt.step() 305 scheduler.step() 306 307 self.assertWarnsRegex( 308 UserWarning, r"`optimizer.step\(\)` has been overridden", new_pattern 309 ) 310 311 def _test_lr_is_constant_for_constant_epoch(self, scheduler): 312 l = [] 313 314 for _ in range(10): 315 scheduler.optimizer.step() 316 with warnings.catch_warnings(record=True) as w: 317 scheduler.step(2) 318 self._check_warning_is_epoch_deprecation_warning(w) 319 320 l.append(self.opt.param_groups[0]["lr"]) 321 self.assertEqual(min(l), max(l)) 322 323 def test_step_lr_is_constant_for_constant_epoch(self): 324 scheduler = StepLR(self.opt, 2) 325 self._test_lr_is_constant_for_constant_epoch(scheduler) 326 327 def test_exponential_lr_is_constant_for_constant_epoch(self): 328 scheduler = ExponentialLR(self.opt, gamma=0.9) 329 self._test_lr_is_constant_for_constant_epoch(scheduler) 330 331 def test_constantlr_is_constant_for_constant_epoch(self): 332 scheduler = ConstantLR(self.opt) 333 self._test_lr_is_constant_for_constant_epoch(scheduler) 334 335 def test_linear_linearlr_is_constant_for_constant_epoch(self): 336 scheduler = LinearLR(self.opt) 337 self._test_lr_is_constant_for_constant_epoch(scheduler) 338 339 def test_polynomial_lr_is_constant_for_constant_epoch(self): 340 scheduler = PolynomialLR(self.opt, power=0.9) 341 self._test_lr_is_constant_for_constant_epoch(scheduler) 342 343 def test_step_lr(self): 344 # lr = 0.05 if epoch < 3 345 # lr = 0.005 if 30 <= epoch < 6 346 # lr = 0.0005 if epoch >= 9 347 epochs = 10 348 single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3 349 targets = [single_targets, [x * epochs for x in single_targets]] 350 scheduler = StepLR(self.opt, gamma=0.1, step_size=3) 351 self._test(scheduler, targets, epochs) 352 353 def test_get_last_lr_step_lr(self): 354 from torch.nn import Parameter 355 356 epochs = 10 357 optimizer = SGD([Parameter(torch.randn(2, 2, requires_grad=True))], 0.1) 358 targets = [[0.1] * 3 + [0.01] * 3 + [0.001] * 3 + [0.0001]] 359 scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1) 360 self._test_get_last_lr(scheduler, targets, epochs) 361 362 def test_get_last_lr_multi_step_lr(self): 363 # lr = 0.05 if epoch < 2 364 # lr = 0.005 if 2 <= epoch < 5 365 # lr = 0.0005 if 5 <= epoch < 9 366 # lr = 0.00005 if 9 <= epoch 367 epochs = 10 368 single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 1 369 targets = [single_targets, [x * epochs for x in single_targets]] 370 scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) 371 self._test_get_last_lr(scheduler, targets, epochs) 372 373 def test_multi_step_lr(self): 374 # lr = 0.05 if epoch < 2 375 # lr = 0.005 if 2 <= epoch < 5 376 # lr = 0.0005 if epoch < 9 377 # lr = 0.00005 if epoch >= 9 378 epochs = 10 379 single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3 380 targets = [single_targets, [x * epochs for x in single_targets]] 381 scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) 382 self._test(scheduler, targets, epochs) 383 384 def test_multi_step_lr_with_epoch(self): 385 # lr = 0.05 if epoch < 2 386 # lr = 0.005 if 2 <= epoch < 5 387 # lr = 0.0005 if epoch < 9 388 # lr = 0.00005 if epoch >= 9 389 epochs = 10 390 single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3 391 targets = [single_targets, [x * epochs for x in single_targets]] 392 scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) 393 self._test_with_epoch(scheduler, targets, epochs) 394 395 def test_get_last_lr_constantlr(self): 396 # lr = 0.025 if epoch < 5 397 # lr = 0.005 if 5 <= epoch 398 epochs = 10 399 single_targets = [0.025] * 5 + [0.05] * 5 400 targets = [single_targets, [x * epochs for x in single_targets]] 401 scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5) 402 self._test_get_last_lr(scheduler, targets, epochs) 403 404 def test_get_last_lr_linearlr(self): 405 # lr = 0.025 if epoch == 0 406 # lr = 0.03125 if epoch == 1 407 # lr = 0.0375 if epoch == 2 408 # lr = 0.04375 if epoch == 3 409 # lr = 0.005 if 4 <= epoch 410 epochs = 10 411 start_factor = 1.0 / 4 412 end_factor = 3.0 / 5 413 iters = 4 414 interpolation = [ 415 start_factor + i * (end_factor - start_factor) / iters for i in range(iters) 416 ] 417 single_targets = [x * 0.05 for x in interpolation] + [0.05 * end_factor] * ( 418 epochs - iters 419 ) 420 targets = [single_targets, [x * epochs for x in single_targets]] 421 scheduler = LinearLR( 422 self.opt, 423 start_factor=start_factor, 424 end_factor=end_factor, 425 total_iters=iters, 426 ) 427 self._test_get_last_lr(scheduler, targets, epochs) 428 429 def test_constantlr(self): 430 # lr = 0.025 if epoch < 5 431 # lr = 0.005 if 5 <= epoch 432 epochs = 10 433 single_targets = [0.025] * 5 + [0.05] * 5 434 targets = [single_targets, [x * epochs for x in single_targets]] 435 scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5) 436 self._test(scheduler, targets, epochs) 437 438 def test_linearlr(self): 439 # lr = 0.025 if epoch == 0 440 # lr = 0.03125 if epoch == 1 441 # lr = 0.0375 if epoch == 2 442 # lr = 0.04375 if epoch == 3 443 # lr = 0.005 if 4 <= epoch 444 epochs = 10 445 start_factor = 1.0 / 2 446 iters = 4 447 interpolation = [ 448 start_factor + i * (1 - start_factor) / iters for i in range(iters) 449 ] 450 single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) 451 targets = [single_targets, [x * epochs for x in single_targets]] 452 scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) 453 self._test(scheduler, targets, epochs) 454 455 def test_linearlr_start_factor_limits1(self): 456 start_factor = 0.0 457 iters = 4 458 with self.assertRaises(ValueError): 459 LinearLR(self.opt, start_factor=start_factor, total_iters=iters) 460 461 def test_linearlr_start_factor_limits2(self): 462 start_factor = 1.1 463 iters = 4 464 with self.assertRaises(ValueError): 465 LinearLR(self.opt, start_factor=start_factor, total_iters=iters) 466 467 def test_constantlr_with_epoch(self): 468 # lr = 0.025 if epoch < 5 469 # lr = 0.005 if 5 <= epoch 470 epochs = 10 471 single_targets = [0.025] * 5 + [0.05] * 5 472 targets = [single_targets, [x * epochs for x in single_targets]] 473 scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5) 474 self._test_with_epoch(scheduler, targets, epochs) 475 476 def test_linearlr_with_epoch(self): 477 # lr = 0.025 if epoch == 0 478 # lr = 0.03125 if epoch == 1 479 # lr = 0.0375 if epoch == 2 480 # lr = 0.04375 if epoch == 3 481 # lr = 0.005 if 4 <= epoch 482 epochs = 10 483 start_factor = 1.0 / 2 484 end_factor = 1.0 485 iters = 4 486 interpolation = [ 487 start_factor + i * (end_factor - start_factor) / iters for i in range(iters) 488 ] 489 single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) 490 targets = [single_targets, [x * epochs for x in single_targets]] 491 scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) 492 self._test_with_epoch(scheduler, targets, epochs) 493 494 def test_exp_lr(self): 495 epochs = 10 496 single_targets = [0.05 * (0.9**x) for x in range(epochs)] 497 targets = [single_targets, [x * epochs for x in single_targets]] 498 scheduler = ExponentialLR(self.opt, gamma=0.9) 499 self._test(scheduler, targets, epochs) 500 501 def test_poly_lr(self): 502 epochs = 10 503 power = 0.9 504 total_iters = 5 505 single_targets = [ 506 (1.0 - x / total_iters) ** power * 0.05 for x in range(total_iters) 507 ] + [0.0] * (epochs - total_iters) 508 targets = [single_targets, [x * epochs for x in single_targets]] 509 scheduler = PolynomialLR(self.opt, power=power, total_iters=total_iters) 510 self._test(scheduler, targets, epochs) 511 512 def test_cos_anneal_lr(self): 513 epochs = 10 514 eta_min = 1e-10 515 single_targets = [ 516 eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 517 for x in range(epochs) 518 ] 519 targets = [single_targets, [x * epochs for x in single_targets]] 520 scheduler = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) 521 self._test(scheduler, targets, epochs) 522 523 def test_closed_form_step_lr(self): 524 scheduler = StepLR(self.opt, gamma=0.1, step_size=3) 525 closed_form_scheduler = StepLR(self.opt, gamma=0.1, step_size=3) 526 self._test_against_closed_form(scheduler, closed_form_scheduler, 20) 527 528 def test_closed_form_linearlr(self): 529 scheduler = LinearLR( 530 self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4 531 ) 532 closed_form_scheduler = LinearLR( 533 self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4 534 ) 535 self._test_against_closed_form(scheduler, closed_form_scheduler, 20) 536 537 def test_closed_form_constantlr(self): 538 scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4) 539 closed_form_scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4) 540 self._test_against_closed_form(scheduler, closed_form_scheduler, 20) 541 542 def test_closed_form_multi_step_lr(self): 543 scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) 544 closed_form_scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) 545 self._test_against_closed_form(scheduler, closed_form_scheduler, 20) 546 547 def test_closed_form_exp_lr(self): 548 scheduler = ExponentialLR(self.opt, gamma=0.9) 549 closed_form_scheduler = ExponentialLR(self.opt, gamma=0.9) 550 self._test_against_closed_form(scheduler, closed_form_scheduler, 20) 551 552 def test_closed_form_poly_lr(self): 553 scheduler = PolynomialLR(self.opt, power=0.9) 554 closed_form_scheduler = PolynomialLR(self.opt, power=0.9) 555 self._test_against_closed_form(scheduler, closed_form_scheduler, 20) 556 557 def test_closed_form_cos_anneal_lr(self): 558 eta_min = 1e-10 559 epochs = 20 560 T_max = 5 561 scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min) 562 closed_form_scheduler = CosineAnnealingLR( 563 self.opt, T_max=T_max, eta_min=eta_min 564 ) 565 self._test_against_closed_form(scheduler, closed_form_scheduler, epochs) 566 567 def test_cos_anneal_lr_continue(self): 568 eta_min = 0.1 569 T_max = 5 570 scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min) 571 self.opt.step() 572 scheduler.step() 573 original_lrs = scheduler._last_lr 574 new_scheduler = CosineAnnealingLR( 575 self.opt, T_max=T_max, eta_min=eta_min, last_epoch=0 576 ) 577 new_lrs = new_scheduler._last_lr 578 torch.testing.assert_close(original_lrs, new_lrs, rtol=1e-4, atol=1e-5) 579 580 def test_reduce_lr_on_plateau1(self): 581 epochs = 10 582 for param_group in self.opt.param_groups: 583 param_group["lr"] = 0.5 584 targets = [[0.5] * 20] 585 metrics = [10 - i * 0.0167 for i in range(20)] 586 scheduler = ReduceLROnPlateau( 587 self.opt, 588 threshold_mode="abs", 589 mode="min", 590 threshold=0.01, 591 patience=5, 592 cooldown=5, 593 ) 594 self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) 595 596 def test_reduce_lr_on_plateau2(self): 597 epochs = 22 598 for param_group in self.opt.param_groups: 599 param_group["lr"] = 0.5 600 targets = [[0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2] 601 metrics = [10 - i * 0.0165 for i in range(22)] 602 scheduler = ReduceLROnPlateau( 603 self.opt, 604 patience=5, 605 cooldown=0, 606 threshold_mode="abs", 607 mode="min", 608 threshold=0.1, 609 ) 610 self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) 611 612 def test_reduce_lr_on_plateau3(self): 613 epochs = 22 614 for param_group in self.opt.param_groups: 615 param_group["lr"] = 0.5 616 targets = [[0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4] 617 metrics = [-0.8] * 2 + [-0.234] * 20 618 scheduler = ReduceLROnPlateau( 619 self.opt, mode="max", patience=5, cooldown=5, threshold_mode="abs" 620 ) 621 self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) 622 623 def test_reduce_lr_on_plateau4(self): 624 epochs = 20 625 for param_group in self.opt.param_groups: 626 param_group["lr"] = 0.5 627 targets = [[0.5] * 20] 628 metrics = [1.5 * (1.025**i) for i in range(20)] # 1.025 > 1.1**0.25 629 scheduler = ReduceLROnPlateau( 630 self.opt, mode="max", patience=3, threshold_mode="rel", threshold=0.1 631 ) 632 self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) 633 634 def test_reduce_lr_on_plateau5(self): 635 epochs = 20 636 for param_group in self.opt.param_groups: 637 param_group["lr"] = 0.5 638 targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4] 639 metrics = [1.5 * (1.005**i) for i in range(20)] 640 scheduler = ReduceLROnPlateau( 641 self.opt, 642 mode="max", 643 threshold_mode="rel", 644 threshold=0.1, 645 patience=5, 646 cooldown=5, 647 ) 648 self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) 649 650 def test_reduce_lr_on_plateau6(self): 651 epochs = 20 652 for param_group in self.opt.param_groups: 653 param_group["lr"] = 0.5 654 targets = [[0.5] * 20] 655 metrics = [1.5 * (0.85**i) for i in range(20)] 656 scheduler = ReduceLROnPlateau( 657 self.opt, mode="min", threshold_mode="rel", threshold=0.1 658 ) 659 self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) 660 661 def test_reduce_lr_on_plateau7(self): 662 epochs = 20 663 for param_group in self.opt.param_groups: 664 param_group["lr"] = 0.5 665 targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4] 666 metrics = [1] * 7 + [0.6] + [0.5] * 12 667 scheduler = ReduceLROnPlateau( 668 self.opt, 669 mode="min", 670 threshold_mode="rel", 671 threshold=0.1, 672 patience=5, 673 cooldown=5, 674 ) 675 self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) 676 677 def test_reduce_lr_on_plateau8(self): 678 epochs = 20 679 for param_group in self.opt.param_groups: 680 param_group["lr"] = 0.5 681 targets = [[0.5] * 6 + [0.4] * 14, [0.5] * 6 + [0.3] * 14] 682 metrics = [1.5 * (1.005**i) for i in range(20)] 683 scheduler = ReduceLROnPlateau( 684 self.opt, 685 mode="max", 686 threshold_mode="rel", 687 min_lr=[0.4, 0.3], 688 threshold=0.1, 689 patience=5, 690 cooldown=5, 691 ) 692 self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) 693 694 def test_reduce_lr_on_plateau_get_last_lr_before_step(self): 695 for param_group in self.opt.param_groups: 696 param_group["lr"] = 0.5 697 scheduler = ReduceLROnPlateau( 698 self.opt, 699 ) 700 self.assertEqual( 701 scheduler.get_last_lr(), [0.5 for param_group in self.opt.param_groups] 702 ) 703 704 def test_sequentiallr1(self): 705 epochs = 19 706 schedulers = [None] * 2 707 targets = [ 708 [0.05, 0.04, 0.032] 709 + [0.05 for x in range(4)] 710 + [0.05 * 0.1 for x in range(4)] 711 + [0.05 * 0.01 for x in range(4)] 712 + [0.05 * 0.001 for x in range(4)] 713 ] 714 milestones = [3] 715 schedulers[0] = ExponentialLR(self.opt, gamma=0.8) 716 schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=4) 717 scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) 718 self._test(scheduler, targets, epochs) 719 720 def test_sequentiallr2(self): 721 epochs = 13 722 schedulers = [None] * 2 723 targets = [[0.005, 0.005, 0.005] + [0.05 * 0.9**x for x in range(10)]] 724 milestones = [3] 725 schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3) 726 schedulers[1] = ExponentialLR(self.opt, gamma=0.9) 727 scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) 728 self._test(scheduler, targets, epochs) 729 730 def test_sequentiallr3(self): 731 epochs = 12 732 schedulers = [None] * 3 733 targets = [ 734 [0.005, 0.005, 0.005] 735 + [0.05, 0.04, 0.032] 736 + [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005] 737 ] 738 milestones = [3, 6] 739 schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3) 740 schedulers[1] = ExponentialLR(self.opt, gamma=0.8) 741 schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2) 742 scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) 743 self._test(scheduler, targets, epochs) 744 745 def test_sequentiallr4(self): 746 optimizer = SGD([torch.tensor(0.5)], lr=0.1) 747 prev_lr = optimizer.param_groups[0]["lr"] 748 749 schedulers = [ 750 torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1), 751 torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.1), 752 ] 753 scheduler = torch.optim.lr_scheduler.SequentialLR( 754 optimizer, schedulers, milestones=[10] 755 ) 756 757 new_lr = optimizer.param_groups[0]["lr"] 758 759 # Ensure that multiple schedulers does not affect the initial learning rate 760 self.assertEqual(prev_lr, new_lr) 761 762 def test_get_last_lr_sequentiallr(self): 763 epochs = 12 764 milestones = [3, 6] 765 schedulers = [None] * 3 766 schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3) 767 schedulers[1] = ExponentialLR(self.opt, gamma=0.8) 768 schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2) 769 scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) 770 constant_lr_target = [0.005] * 3 771 exponential_lr_target = [0.05, 0.04, 0.032] 772 step_lr_target = [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005] 773 single_targets = constant_lr_target + exponential_lr_target + step_lr_target 774 targets = [single_targets, [x * 10 for x in single_targets]] 775 self._test_get_last_lr(scheduler, targets, epochs) 776 777 def test_chained_lr2_get_last_lr_before_step(self): 778 schedulers = [ 779 LinearLR(self.opt, start_factor=0.4, total_iters=3), 780 MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1), 781 ] 782 scheduler = ChainedScheduler(schedulers) 783 self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) 784 785 def test_chained_lr1(self): 786 epochs = 10 787 schedulers = [None] * 1 788 targets = [[0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3] 789 schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) 790 scheduler = ChainedScheduler(schedulers) 791 self._test([scheduler], targets, epochs) 792 self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) 793 794 def test_chained_lr2(self): 795 epochs = 10 796 schedulers = [None] * 1 797 targets = [[0.02, 0.03, 0.04] + [0.05] * 9] 798 schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3) 799 scheduler = ChainedScheduler(schedulers) 800 self._test([scheduler], targets, epochs) 801 self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) 802 803 def test_chained_lr3(self): 804 epochs = 10 805 schedulers = [None] * 2 806 targets = [ 807 [0.02, 0.03, 0.04, 0.05] + [0.005] * 4 + [0.0005] * 3 + [0.00005] * 3 808 ] 809 schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3) 810 schedulers[1] = MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1) 811 scheduler = ChainedScheduler(schedulers) 812 self._test([scheduler], targets, epochs) 813 self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) 814 815 def test_chained_lr4(self): 816 epochs = 9 817 schedulers = [None] * 3 818 targets = [ 819 [0.05 * 0.2 * 0.9**x for x in range(3)] 820 + [0.05 * 0.2 * 0.9**3 * 0.1] 821 + [0.05 * 0.9**x * 0.1 for x in range(4, 6)] 822 + [0.05 * 0.9**x * 0.01 for x in range(6, 9)] 823 ] 824 schedulers[0] = ExponentialLR(self.opt, gamma=0.9) 825 schedulers[1] = ConstantLR(self.opt, factor=0.2, total_iters=4) 826 schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=3) 827 scheduler = ChainedScheduler(schedulers) 828 self._test([scheduler], targets, epochs) 829 self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) 830 831 def test_chained_lr5(self): 832 def poly_lr(lr: float): 833 return [ 834 (lr * ((1.0 - x / total_iters) ** power)) for x in range(total_iters) 835 ] + [0.0] * (epochs - total_iters) 836 837 schedulers = [None] * 2 838 epochs = 10 839 power = 0.9 840 total_iters = 5 841 const_factor = 0.1 842 single_targets = [x * const_factor for x in poly_lr(lr=0.05)] 843 targets = [single_targets, [x * const_factor for x in poly_lr(0.5)]] 844 schedulers[0] = PolynomialLR(self.opt, power=power, total_iters=total_iters) 845 schedulers[1] = ConstantLR(self.opt, factor=const_factor) 846 scheduler = ChainedScheduler(schedulers) 847 self._test(scheduler, targets, epochs) 848 self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) 849 850 def test_compound_step_and_multistep_lr(self): 851 epochs = 10 852 schedulers = [None] * 2 853 schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) 854 schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) 855 targets = [[0.05] * 2 + [0.005] * 1 + [5e-4] * 2 + [5e-5] + [5e-6] * 3 + [5e-8]] 856 self._test(schedulers, targets, epochs) 857 858 def test_compound_step_and_exp_lr(self): 859 epochs = 10 860 schedulers = [None] * 2 861 single_targets = [0.05 * (0.9**x) for x in range(3)] 862 single_targets += [0.005 * (0.9**x) for x in range(3, 6)] 863 single_targets += [0.0005 * (0.9**x) for x in range(6, 9)] 864 single_targets += [0.00005 * (0.9**x) for x in range(9, 12)] 865 targets = [single_targets, [x * epochs for x in single_targets]] 866 schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) 867 schedulers[1] = ExponentialLR(self.opt, gamma=0.9) 868 self._test(schedulers, targets, epochs) 869 870 def test_compound_exp_and_multistep_lr(self): 871 epochs = 10 872 schedulers = [None] * 2 873 single_targets = [0.05 * (0.9**x) for x in range(2)] 874 single_targets += [0.005 * (0.9**x) for x in range(2, 5)] 875 single_targets += [0.0005 * (0.9**x) for x in range(5, 9)] 876 single_targets += [0.00005 * (0.9**x) for x in range(9, 11)] 877 targets = [single_targets, [x * epochs for x in single_targets]] 878 schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) 879 schedulers[1] = ExponentialLR(self.opt, gamma=0.9) 880 self._test(schedulers, targets, epochs) 881 882 def test_compound_exp_and_linearlr(self): 883 epochs = 10 884 iters = 4 885 start_factor = 0.4 886 end_factor = 0.9 887 schedulers = [None] * 2 888 single_targets = [0.05 * (0.9**x) for x in range(11)] 889 for i in range(iters): 890 single_targets[i] *= start_factor + i / iters * (end_factor - start_factor) 891 for i in range(iters, 11): 892 single_targets[i] *= end_factor 893 targets = [single_targets, [x * epochs for x in single_targets]] 894 schedulers[0] = LinearLR( 895 self.opt, 896 start_factor=start_factor, 897 end_factor=end_factor, 898 total_iters=iters, 899 ) 900 schedulers[1] = ExponentialLR(self.opt, gamma=0.9) 901 self._test(schedulers, targets, epochs) 902 903 def test_compound_step_and_constantlr(self): 904 epochs = 10 905 iters = 4 906 factor = 0.4 907 schedulers = [None] * 2 908 single_targets = ( 909 [0.05 * 0.4] * 3 910 + [0.005 * 0.4] 911 + [0.005] * 2 912 + [0.0005] * 3 913 + [0.00005] * 3 914 ) 915 targets = [single_targets, [x * epochs for x in single_targets]] 916 schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) 917 schedulers[1] = ConstantLR(self.opt, factor=0.4, total_iters=4) 918 self._test(schedulers, targets, epochs) 919 920 def test_compound_linearlr_and_multistep_lr(self): 921 epochs = 10 922 iters = 4 923 start_factor = 0.4 924 schedulers = [None] * 2 925 single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 2 926 for i in range(iters): 927 single_targets[i] *= start_factor + i / iters * (1 - start_factor) 928 targets = [single_targets, [x * epochs for x in single_targets]] 929 schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) 930 schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) 931 self._test(schedulers, targets, epochs) 932 933 def test_compound_cosanneal_and_step_lr(self): 934 epochs = 10 935 eta_min = 1e-10 936 single_targets = [ 937 eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 938 for x in range(epochs) 939 ] 940 single_targets = [x * 0.1 ** (i // 3) for i, x in enumerate(single_targets)] 941 targets = [single_targets, [x * epochs for x in single_targets]] 942 schedulers = [None] * 2 943 schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) 944 schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3) 945 self._test(schedulers, targets, epochs) 946 947 def test_compound_cosanneal_and_multistep_lr(self): 948 epochs = 10 949 eta_min = 1e-10 950 single_targets = [ 951 eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 952 for x in range(epochs) 953 ] 954 multipliers = [1] * 2 + [0.1] * 3 + [0.01] * 4 + [0.001] 955 single_targets = [x * y for x, y in zip(single_targets, multipliers)] 956 targets = [single_targets, [x * epochs for x in single_targets]] 957 schedulers = [None] * 2 958 schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) 959 schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) 960 self._test(schedulers, targets, epochs) 961 962 def test_compound_cosanneal_and_linearlr(self): 963 epochs = 10 964 iters = 4 965 start_factor = 0.4 966 eta_min = 1e-10 967 schedulers = [None] * 2 968 single_targets = [ 969 eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 970 for x in range(epochs) 971 ] 972 for i in range(iters): 973 single_targets[i] *= start_factor + i / iters * (1 - start_factor) 974 targets = [single_targets, [x * epochs for x in single_targets]] 975 schedulers[0] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) 976 schedulers[1] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) 977 self._test(schedulers, targets, epochs) 978 979 def test_compound_cosanneal_and_exp_lr(self): 980 epochs = 10 981 eta_min = 1e-10 982 single_targets = [ 983 eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 984 for x in range(epochs) 985 ] 986 multipliers = [0.1**i for i in range(epochs)] 987 single_targets = [x * y for x, y in zip(single_targets, multipliers)] 988 targets = [single_targets, [x * epochs for x in single_targets]] 989 schedulers = [None] * 2 990 schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) 991 schedulers[1] = ExponentialLR(self.opt, gamma=0.1) 992 self._test(schedulers, targets, epochs) 993 994 def test_compound_reduce_lr_on_plateau1(self): 995 epochs = 10 996 for param_group in self.opt.param_groups: 997 param_group["lr"] = 0.5 998 single_targets = [0.5] * 20 999 multipliers = [0.1 ** (i // 3) for i in range(20)] 1000 single_targets = [x * y for x, y in zip(multipliers, single_targets)] 1001 targets = [single_targets] 1002 targets = targets[1:] # test runs step before checking lr 1003 metrics = [10 - i * 0.0167 for i in range(20)] 1004 schedulers = [None, None] 1005 schedulers[0] = ReduceLROnPlateau( 1006 self.opt, 1007 threshold_mode="abs", 1008 mode="min", 1009 threshold=0.01, 1010 patience=5, 1011 cooldown=5, 1012 ) 1013 schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3) 1014 self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) 1015 1016 def test_compound_reduce_lr_on_plateau2(self): 1017 epochs = 22 1018 for param_group in self.opt.param_groups: 1019 param_group["lr"] = 0.5 1020 single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2 1021 multipliers = [1] * 3 + [0.1] * 5 + [0.01] * 4 + [0.001] * 10 1022 single_targets = [x * y for x, y in zip(single_targets, multipliers)] 1023 targets = [single_targets] 1024 targets = targets[1:] # test runs step before checking lr 1025 metrics = [10 - i * 0.0165 for i in range(22)] 1026 schedulers = [None] * 2 1027 schedulers[0] = ReduceLROnPlateau( 1028 self.opt, 1029 patience=5, 1030 cooldown=0, 1031 threshold_mode="abs", 1032 mode="min", 1033 threshold=0.1, 1034 ) 1035 schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[3, 8, 12]) 1036 self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) 1037 1038 def test_compound_reduce_lr_on_plateau3(self): 1039 epochs = 22 1040 for param_group in self.opt.param_groups: 1041 param_group["lr"] = 0.5 1042 single_targets = [0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4 1043 multipliers = [0.1**i for i in range(epochs)] 1044 single_targets = [x * y for x, y in zip(multipliers, single_targets)] 1045 targets = [single_targets] 1046 targets = targets[1:] # test runs step before checking lr 1047 metrics = [-0.8] * 2 + [-0.234] * 20 1048 schedulers = [None, None] 1049 schedulers[0] = ReduceLROnPlateau( 1050 self.opt, mode="max", patience=5, cooldown=5, threshold_mode="abs" 1051 ) 1052 schedulers[1] = ExponentialLR(self.opt, gamma=0.1) 1053 self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) 1054 1055 def test_compound_reduce_lr_on_plateau4(self): 1056 epochs = 20 1057 for param_group in self.opt.param_groups: 1058 param_group["lr"] = 0.05 1059 epochs = 10 1060 eta_min = 1e-10 1061 single_targets = [ 1062 eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 1063 for x in range(epochs) 1064 ] 1065 targets = [single_targets] 1066 targets = targets[1:] # test runs step before checking lr 1067 metrics = [1.5 * (1.025**i) for i in range(20)] # 1.025 > 1.1**0.25 1068 schedulers = [None, None] 1069 schedulers[0] = ReduceLROnPlateau( 1070 self.opt, mode="max", patience=3, threshold_mode="rel", threshold=0.1 1071 ) 1072 schedulers[1] = CosineAnnealingLR(self.opt, epochs, eta_min) 1073 self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) 1074 1075 def test_compound_reduce_lr_on_plateau5(self): 1076 iters = 4 1077 start_factor = 0.4 1078 epochs = 22 1079 for param_group in self.opt.param_groups: 1080 param_group["lr"] = 0.5 1081 single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2 1082 multipliers = [1] * 22 1083 for i in range(iters): 1084 multipliers[i] *= start_factor + i / iters * (1 - start_factor) 1085 single_targets = [x * y for x, y in zip(single_targets, multipliers)] 1086 targets = [single_targets] 1087 targets = targets[1:] # test runs step before checking lr 1088 metrics = [10 - i * 0.0165 for i in range(22)] 1089 schedulers = [None] * 2 1090 schedulers[0] = ReduceLROnPlateau( 1091 self.opt, 1092 patience=5, 1093 cooldown=0, 1094 threshold_mode="abs", 1095 mode="min", 1096 threshold=0.1, 1097 ) 1098 schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) 1099 self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) 1100 1101 def test_cycle_lr_invalid_mode(self): 1102 with self.assertRaises(ValueError): 1103 scheduler = CyclicLR(self.opt, base_lr=0, max_lr=0, mode="CATS") 1104 1105 def test_cycle_lr_triangular_mode_one_lr(self): 1106 lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] 1107 momentum_target = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3] 1108 lr_targets = [lr_target, lr_target] 1109 momentum_targets = [momentum_target, momentum_target] 1110 scheduler = CyclicLR( 1111 self.opt, 1112 base_lr=1, 1113 max_lr=5, 1114 step_size_up=4, 1115 cycle_momentum=True, 1116 base_momentum=1, 1117 max_momentum=5, 1118 mode="triangular", 1119 ) 1120 self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) 1121 1122 def test_cycle_lr_triangular_mode_one_lr_no_momentum(self): 1123 lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] 1124 lr_targets = [lr_target, lr_target] 1125 momentum_target = [self.opt.defaults["momentum"]] * len(lr_target) 1126 momentum_targets = [momentum_target, momentum_target] 1127 scheduler = CyclicLR( 1128 self.opt, 1129 base_lr=1, 1130 max_lr=5, 1131 step_size_up=4, 1132 cycle_momentum=False, 1133 mode="triangular", 1134 ) 1135 self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) 1136 1137 def test_cycle_lr_triangular2_mode_one_lr(self): 1138 lr_target = [ 1139 1, 1140 2, 1141 3, 1142 4, 1143 5, 1144 4, 1145 3, 1146 2, 1147 1, 1148 1.5, 1149 2.0, 1150 2.5, 1151 3.0, 1152 2.5, 1153 2.0, 1154 1.5, 1155 1, 1156 1.25, 1157 1.50, 1158 1.75, 1159 2.00, 1160 1.75, 1161 ] 1162 momentum_target = [ 1163 5.0, 1164 4.0, 1165 3.0, 1166 2.0, 1167 1.0, 1168 2.0, 1169 3.0, 1170 4.0, 1171 5.0, 1172 4.5, 1173 4.0, 1174 3.5, 1175 3.0, 1176 3.5, 1177 4.0, 1178 4.5, 1179 5.0, 1180 4.75, 1181 4.5, 1182 4.25, 1183 4.0, 1184 4.25, 1185 ] 1186 lr_targets = [lr_target, lr_target] 1187 momentum_targets = [momentum_target, momentum_target] 1188 scheduler = CyclicLR( 1189 self.opt, 1190 base_lr=1, 1191 max_lr=5, 1192 step_size_up=4, 1193 cycle_momentum=True, 1194 base_momentum=1, 1195 max_momentum=5, 1196 mode="triangular2", 1197 ) 1198 self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) 1199 1200 def test_cycle_lr_exp_range_mode_one_lr(self): 1201 base_lr, max_lr = 1, 5 1202 diff_lr = max_lr - base_lr 1203 gamma = 0.9 1204 xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1] 1205 lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)] 1206 momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)] 1207 lr_targets = [lr_target, lr_target] 1208 momentum_targets = [momentum_target, momentum_target] 1209 scheduler = CyclicLR( 1210 self.opt, 1211 base_lr=base_lr, 1212 max_lr=max_lr, 1213 step_size_up=4, 1214 cycle_momentum=True, 1215 base_momentum=base_lr, 1216 max_momentum=max_lr, 1217 mode="exp_range", 1218 gamma=gamma, 1219 ) 1220 self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) 1221 1222 def test_cycle_lr_triangular_mode(self): 1223 lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] 1224 lr_target_2 = [x + 1 for x in lr_target_1] 1225 lr_targets = [lr_target_1, lr_target_2] 1226 momentum_target_1 = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3] 1227 momentum_target_2 = [x + 1 for x in momentum_target_1] 1228 momentum_targets = [momentum_target_1, momentum_target_2] 1229 scheduler = CyclicLR( 1230 self.opt, 1231 base_lr=[1, 2], 1232 max_lr=[5, 6], 1233 step_size_up=4, 1234 cycle_momentum=True, 1235 base_momentum=[1, 2], 1236 max_momentum=[5, 6], 1237 mode="triangular", 1238 ) 1239 self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1)) 1240 1241 def test_cycle_lr_triangular2_mode(self): 1242 lr_target_1 = [ 1243 1, 1244 2, 1245 3, 1246 4, 1247 5, 1248 4, 1249 3, 1250 2, 1251 1, 1252 1.5, 1253 2.0, 1254 2.5, 1255 3.0, 1256 2.5, 1257 2.0, 1258 1.5, 1259 1, 1260 1.25, 1261 1.50, 1262 1.75, 1263 2.00, 1264 1.75, 1265 ] 1266 lr_target_2 = [x + 2 for x in lr_target_1] 1267 lr_targets = [lr_target_1, lr_target_2] 1268 momentum_target_1 = [ 1269 5.0, 1270 4.0, 1271 3.0, 1272 2.0, 1273 1.0, 1274 2.0, 1275 3.0, 1276 4.0, 1277 5.0, 1278 4.5, 1279 4.0, 1280 3.5, 1281 3.0, 1282 3.5, 1283 4.0, 1284 4.5, 1285 5.0, 1286 4.75, 1287 4.5, 1288 4.25, 1289 4.0, 1290 4.25, 1291 ] 1292 momentum_target_2 = [x + 2 for x in momentum_target_1] 1293 momentum_targets = [momentum_target_1, momentum_target_2] 1294 scheduler = CyclicLR( 1295 self.opt, 1296 base_lr=[1, 3], 1297 max_lr=[5, 7], 1298 step_size_up=4, 1299 cycle_momentum=True, 1300 base_momentum=[1, 3], 1301 max_momentum=[5, 7], 1302 mode="triangular2", 1303 ) 1304 self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1)) 1305 1306 def test_cycle_lr_exp_range_mode(self): 1307 base_lr_1, max_lr_1 = 1, 5 1308 base_lr_2, max_lr_2 = 5, 12 1309 1310 diff_lr_1 = max_lr_1 - base_lr_1 1311 diff_lr_2 = max_lr_2 - base_lr_2 1312 1313 gamma = 0.9 1314 xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1] 1315 lr_target_1 = [base_lr_1 + x * diff_lr_1 * gamma**i for i, x in enumerate(xs)] 1316 lr_target_2 = [base_lr_2 + x * diff_lr_2 * gamma**i for i, x in enumerate(xs)] 1317 lr_targets = [lr_target_1, lr_target_2] 1318 momentum_target_1 = [ 1319 max_lr_1 - x * diff_lr_1 * gamma**i for i, x in enumerate(xs) 1320 ] 1321 momentum_target_2 = [ 1322 max_lr_2 - x * diff_lr_2 * gamma**i for i, x in enumerate(xs) 1323 ] 1324 momentum_targets = [momentum_target_1, momentum_target_2] 1325 scheduler = CyclicLR( 1326 self.opt, 1327 base_lr=[base_lr_1, base_lr_2], 1328 max_lr=[max_lr_1, max_lr_2], 1329 step_size_up=4, 1330 cycle_momentum=True, 1331 base_momentum=[base_lr_1, base_lr_2], 1332 max_momentum=[max_lr_1, max_lr_2], 1333 mode="exp_range", 1334 gamma=gamma, 1335 ) 1336 self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1)) 1337 1338 def test_cycle_lr_triangular_mode_step_size_up_down(self): 1339 lr_target = [ 1340 1.0, 1341 2.0, 1342 3.0, 1343 4.0, 1344 5.0, 1345 13.0 / 3, 1346 11.0 / 3, 1347 9.0 / 3, 1348 7.0 / 3, 1349 5.0 / 3, 1350 1.0, 1351 ] 1352 lr_targets = [lr_target, lr_target] 1353 momentum_target = [ 1354 5.0, 1355 4.0, 1356 3.0, 1357 2.0, 1358 1.0, 1359 5.0 / 3, 1360 7.0 / 3, 1361 3.0, 1362 11.0 / 3, 1363 13.0 / 3, 1364 5.0, 1365 ] 1366 momentum_targets = [momentum_target, momentum_target] 1367 1368 scheduler = CyclicLR( 1369 self.opt, 1370 base_lr=1, 1371 max_lr=5, 1372 step_size_up=4, 1373 step_size_down=6, 1374 cycle_momentum=True, 1375 base_momentum=1, 1376 max_momentum=5, 1377 mode="triangular", 1378 ) 1379 self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) 1380 1381 def test_cycle_lr_triangular2_mode_step_size_up_down(self): 1382 lr_base_target = [ 1383 1.0, 1384 3.0, 1385 5.0, 1386 13.0 / 3, 1387 11.0 / 3, 1388 9.0 / 3, 1389 7.0 / 3, 1390 5.0 / 3, 1391 1.0, 1392 2.0, 1393 3.0, 1394 8.0 / 3, 1395 7.0 / 3, 1396 6.0 / 3, 1397 5.0 / 3, 1398 4.0 / 3, 1399 1.0, 1400 3.0 / 2, 1401 2.0, 1402 11.0 / 6, 1403 10.0 / 6, 1404 9.0 / 6, 1405 8.0 / 6, 1406 7.0 / 6, 1407 ] 1408 momentum_base_target = [ 1409 5.0, 1410 3.0, 1411 1.0, 1412 5.0 / 3, 1413 7.0 / 3, 1414 3.0, 1415 11.0 / 3, 1416 13.0 / 3, 1417 5.0, 1418 4.0, 1419 3.0, 1420 10.0 / 3, 1421 11.0 / 3, 1422 4.0, 1423 13.0 / 3, 1424 14.0 / 3, 1425 5.0, 1426 4.5, 1427 4.0, 1428 25.0 / 6, 1429 13.0 / 3, 1430 4.5, 1431 14.0 / 3, 1432 29.0 / 6, 1433 ] 1434 deltas = [2 * i for i in range(0, 2)] 1435 base_lrs = [1 + delta for delta in deltas] 1436 max_lrs = [5 + delta for delta in deltas] 1437 lr_targets = [[x + delta for x in lr_base_target] for delta in deltas] 1438 momentum_targets = [ 1439 [x + delta for x in momentum_base_target] for delta in deltas 1440 ] 1441 scheduler = CyclicLR( 1442 self.opt, 1443 base_lr=base_lrs, 1444 max_lr=max_lrs, 1445 step_size_up=2, 1446 step_size_down=6, 1447 cycle_momentum=True, 1448 base_momentum=base_lrs, 1449 max_momentum=max_lrs, 1450 mode="triangular2", 1451 ) 1452 self._test_cycle_lr( 1453 scheduler, lr_targets, momentum_targets, len(lr_base_target) 1454 ) 1455 1456 def test_cycle_lr_exp_range_mode_step_size_up_down(self): 1457 base_lr, max_lr = 1, 5 1458 diff_lr = max_lr - base_lr 1459 gamma = 0.9 1460 xs = [ 1461 0.0, 1462 0.5, 1463 1.0, 1464 5.0 / 6, 1465 4.0 / 6, 1466 3.0 / 6, 1467 2.0 / 6, 1468 1.0 / 6, 1469 0.0, 1470 0.5, 1471 1.0, 1472 5.0 / 6, 1473 4.0 / 6, 1474 ] 1475 lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)] 1476 lr_targets = [lr_target, lr_target] 1477 momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)] 1478 momentum_targets = [momentum_target, momentum_target] 1479 scheduler = CyclicLR( 1480 self.opt, 1481 base_lr=base_lr, 1482 max_lr=max_lr, 1483 step_size_up=2, 1484 step_size_down=6, 1485 cycle_momentum=True, 1486 base_momentum=base_lr, 1487 max_momentum=max_lr, 1488 mode="exp_range", 1489 gamma=gamma, 1490 ) 1491 self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) 1492 1493 def test_cycle_lr_with_momentumless_optimizer(self): 1494 # Note [Temporarily set optimizer to Adam] 1495 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 1496 # The TestLRScheduler object carries around an SGD optimizer to avoid having to 1497 # instantiate one for every test. This gets in the way for our very specific case 1498 # in which we need to use Adam (or really any optimizer that doesn't use momentum) 1499 # in order to test that the momentum bug in CyclicLR is fixed (the bug is described 1500 # in more detail in https://github.com/pytorch/pytorch/issues/19003 ). 1501 old_opt = self.opt 1502 self.opt = Adam( 1503 [ 1504 {"params": self.net.conv1.parameters()}, 1505 {"params": self.net.conv2.parameters(), "lr": 0.5}, 1506 ], 1507 lr=0.05, 1508 ) 1509 1510 lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] 1511 lr_targets = [lr_target, lr_target] 1512 momentum_target = [None] * len(lr_target) 1513 momentum_targets = [momentum_target, momentum_target] 1514 scheduler = CyclicLR( 1515 self.opt, 1516 base_lr=1, 1517 max_lr=5, 1518 step_size_up=4, 1519 cycle_momentum=False, 1520 mode="triangular", 1521 ) 1522 self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) 1523 1524 self.opt = old_opt # set optimizer back to SGD 1525 1526 def test_cycle_lr_cycle_momentum_fail_with_momentumless_optimizer(self): 1527 with self.assertRaises(ValueError): 1528 rprop_opt = Rprop(self.net.parameters()) 1529 scheduler = CyclicLR(rprop_opt, base_lr=1, max_lr=5, cycle_momentum=True) 1530 1531 def test_cycle_lr_cycle_momentum_with_beta1_optimizer(self): 1532 adam_opt = Adam(self.net.parameters()) 1533 scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True) 1534 1535 def test_cycle_lr_removed_after_out_of_scope(self): 1536 import gc 1537 import weakref 1538 1539 gc.disable() 1540 1541 def test(): 1542 adam_opt = Adam(self.net.parameters()) 1543 scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False) 1544 return weakref.ref(scheduler) 1545 1546 ref = test() 1547 assert ref() is None 1548 gc.enable() 1549 1550 def test_cycle_lr_state_dict_picklable(self): 1551 adam_opt = Adam(self.net.parameters()) 1552 1553 # Case 1: Built-in mode 1554 scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False) 1555 self.assertIsInstance(scheduler._scale_fn_ref, types.FunctionType) 1556 state = scheduler.state_dict() 1557 self.assertNotIn("_scale_fn_ref", state) 1558 self.assertIs(state["_scale_fn_custom"], None) 1559 pickle.dumps(state) 1560 1561 # Case 2: Custom `scale_fn`, a function object 1562 def scale_fn(_): 1563 return 0.5 1564 1565 scheduler = CyclicLR( 1566 adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn 1567 ) 1568 state = scheduler.state_dict() 1569 self.assertNotIn("_scale_fn_ref", state) 1570 self.assertIs(state["_scale_fn_custom"], None) 1571 pickle.dumps(state) 1572 1573 # Case 3: Custom `scale_fn`, a callable class 1574 class ScaleFn: 1575 def __init__(self) -> None: 1576 self.x = 0.5 1577 1578 def __call__(self, _): 1579 return self.x 1580 1581 scale_fn = ScaleFn() 1582 1583 scheduler = CyclicLR( 1584 adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn 1585 ) 1586 state = scheduler.state_dict() 1587 self.assertNotIn("_scale_fn_ref", state) 1588 self.assertEqual(state["_scale_fn_custom"], scale_fn.__dict__) 1589 pickle.dumps(state) 1590 1591 def test_cycle_lr_scale_fn_restored_from_state_dict(self): 1592 adam_opt = Adam(self.net.parameters()) 1593 1594 # Case 1: Built-in mode 1595 scheduler = CyclicLR( 1596 adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, mode="triangular2" 1597 ) 1598 restored_scheduler = CyclicLR( 1599 adam_opt, base_lr=1, max_lr=5, cycle_momentum=False 1600 ) 1601 restored_scheduler.load_state_dict(scheduler.state_dict()) 1602 self.assertTrue(restored_scheduler.mode == scheduler.mode == "triangular2") 1603 self.assertIsNotNone(restored_scheduler._scale_fn_ref) and self.assertIsNotNone( 1604 scheduler._scale_fn_ref 1605 ) 1606 self.assertIs(restored_scheduler._scale_fn_custom, None) 1607 self.assertIs(scheduler._scale_fn_custom, None) 1608 1609 # Case 2: Custom `scale_fn` 1610 def scale_fn(_): 1611 return 0.5 1612 1613 scheduler = CyclicLR( 1614 adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn 1615 ) 1616 restored_scheduler = CyclicLR( 1617 adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn 1618 ) 1619 restored_scheduler.load_state_dict(scheduler.state_dict()) 1620 self.assertIs(scheduler._scale_fn_custom, scale_fn) 1621 self.assertIs(restored_scheduler._scale_fn_custom, scale_fn) 1622 1623 def test_onecycle_lr_invalid_anneal_strategy(self): 1624 with self.assertRaises(ValueError): 1625 scheduler = OneCycleLR( 1626 self.opt, max_lr=1e-3, total_steps=10, anneal_strategy="CATS" 1627 ) 1628 1629 def test_onecycle_lr_invalid_pct_start(self): 1630 with self.assertRaises(ValueError): 1631 scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, pct_start=1.1) 1632 1633 def test_onecycle_lr_cannot_calculate_total_steps(self): 1634 with self.assertRaises(ValueError): 1635 scheduler = OneCycleLR(self.opt, max_lr=1e-3) 1636 1637 def test_onecycle_lr_linear_annealing(self): 1638 lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5] 1639 momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22] 1640 lr_targets = [lr_target, lr_target] 1641 momentum_targets = [momentum_target, momentum_target] 1642 scheduler = OneCycleLR( 1643 self.opt, 1644 max_lr=25, 1645 final_div_factor=2, 1646 base_momentum=1, 1647 max_momentum=22, 1648 total_steps=10, 1649 anneal_strategy="linear", 1650 ) 1651 self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) 1652 1653 def test_onecycle_lr_linear_annealing_three_phases(self): 1654 lr_target = [1, 9, 17, 25, 17, 9, 1, 0.75, 0.5, 0.25] 1655 momentum_target = [22, 15, 8, 1, 8, 15, 22, 22, 22, 22] 1656 lr_targets = [lr_target, lr_target] 1657 momentum_targets = [momentum_target, momentum_target] 1658 scheduler = OneCycleLR( 1659 self.opt, 1660 max_lr=25, 1661 div_factor=25, 1662 base_momentum=1, 1663 max_momentum=22, 1664 total_steps=10, 1665 anneal_strategy="linear", 1666 pct_start=0.4, 1667 final_div_factor=4, 1668 three_phase=True, 1669 ) 1670 self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) 1671 1672 def test_onecycle_lr_cosine_annealing(self): 1673 def annealing_cos(start, end, pct): 1674 cos_out = math.cos(math.pi * pct) + 1 1675 return end + (start - end) / 2.0 * cos_out 1676 1677 lr_target = [ 1678 1, 1679 13, 1680 25, 1681 annealing_cos(25, 0.5, 1 / 7.0), 1682 annealing_cos(25, 0.5, 2 / 7.0), 1683 annealing_cos(25, 0.5, 3 / 7.0), 1684 annealing_cos(25, 0.5, 4 / 7.0), 1685 annealing_cos(25, 0.5, 5 / 7.0), 1686 annealing_cos(25, 0.5, 6 / 7.0), 1687 0.5, 1688 ] 1689 momentum_target = [ 1690 22, 1691 11.5, 1692 1, 1693 annealing_cos(1, 22, 1 / 7.0), 1694 annealing_cos(1, 22, 2 / 7.0), 1695 annealing_cos(1, 22, 3 / 7.0), 1696 annealing_cos(1, 22, 4 / 7.0), 1697 annealing_cos(1, 22, 5 / 7.0), 1698 annealing_cos(1, 22, 6 / 7.0), 1699 22, 1700 ] 1701 lr_targets = [lr_target, lr_target] 1702 momentum_targets = [momentum_target, momentum_target] 1703 scheduler = OneCycleLR( 1704 self.opt, 1705 max_lr=25, 1706 final_div_factor=2, 1707 base_momentum=1, 1708 max_momentum=22, 1709 total_steps=10, 1710 ) 1711 self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) 1712 1713 def test_onecycle_lr_legacy_state_dict(self): 1714 scheduler = OneCycleLR( 1715 self.opt, 1716 max_lr=25, 1717 final_div_factor=2, 1718 base_momentum=1, 1719 max_momentum=22, 1720 total_steps=10, 1721 anneal_strategy="cos", 1722 ) 1723 delattr(scheduler, "_anneal_func_type") 1724 state_dict = scheduler.state_dict() 1725 self.assertNotIn("anneal_func_type", state_dict) 1726 state_dict["anneal_func"] = OneCycleLR._annealing_cos 1727 scheduler.load_state_dict(state_dict) 1728 1729 def annealing_cos(start, end, pct): 1730 cos_out = math.cos(math.pi * pct) + 1 1731 return end + (start - end) / 2.0 * cos_out 1732 1733 lr_target = [ 1734 1, 1735 13, 1736 25, 1737 annealing_cos(25, 0.5, 1 / 7.0), 1738 annealing_cos(25, 0.5, 2 / 7.0), 1739 annealing_cos(25, 0.5, 3 / 7.0), 1740 annealing_cos(25, 0.5, 4 / 7.0), 1741 annealing_cos(25, 0.5, 5 / 7.0), 1742 annealing_cos(25, 0.5, 6 / 7.0), 1743 0.5, 1744 ] 1745 momentum_target = [ 1746 22, 1747 11.5, 1748 1, 1749 annealing_cos(1, 22, 1 / 7.0), 1750 annealing_cos(1, 22, 2 / 7.0), 1751 annealing_cos(1, 22, 3 / 7.0), 1752 annealing_cos(1, 22, 4 / 7.0), 1753 annealing_cos(1, 22, 5 / 7.0), 1754 annealing_cos(1, 22, 6 / 7.0), 1755 22, 1756 ] 1757 lr_targets = [lr_target, lr_target] 1758 momentum_targets = [momentum_target, momentum_target] 1759 self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) 1760 1761 def test_cycle_lr_with_adam(self): 1762 old_opt = self.opt 1763 self.opt = Adam( 1764 [ 1765 {"params": self.net.conv1.parameters()}, 1766 {"params": self.net.conv2.parameters(), "lr": 0.5}, 1767 ], 1768 lr=0.05, 1769 ) 1770 1771 lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5] 1772 momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22] 1773 lr_targets = [lr_target, lr_target] 1774 momentum_targets = [momentum_target, momentum_target] 1775 scheduler = OneCycleLR( 1776 self.opt, 1777 max_lr=25, 1778 final_div_factor=2, 1779 base_momentum=1, 1780 max_momentum=22, 1781 total_steps=10, 1782 anneal_strategy="linear", 1783 ) 1784 self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10, use_beta1=True) 1785 self.opt = old_opt # set optimizer back to SGD 1786 1787 def test_lambda_lr(self): 1788 epochs = 10 1789 self.opt.param_groups[0]["lr"] = 0.05 1790 self.opt.param_groups[1]["lr"] = 0.4 1791 targets = [ 1792 [0.05 * (0.9**x) for x in range(epochs)], 1793 [0.4 * (0.8**x) for x in range(epochs)], 1794 ] 1795 scheduler = LambdaLR( 1796 self.opt, lr_lambda=[lambda x1: 0.9**x1, lambda x2: 0.8**x2] 1797 ) 1798 self._test(scheduler, targets, epochs) 1799 1800 def test_multiplicative_lr(self): 1801 epochs = 10 1802 self.opt.param_groups[0]["lr"] = 0.05 1803 self.opt.param_groups[1]["lr"] = 0.4 1804 targets = [ 1805 [0.05 * (0.9**x) for x in range(epochs)], 1806 [0.4 * (0.8**x) for x in range(epochs)], 1807 ] 1808 scheduler = MultiplicativeLR( 1809 self.opt, lr_lambda=[lambda x1: 0.9, lambda x2: 0.8] 1810 ) 1811 self._test(scheduler, targets, epochs) 1812 1813 @parametrize("T_mult", [1, 2, 4]) 1814 def test_CosineAnnealingWarmRestarts_lr1(self, T_mult): 1815 iters = 100 1816 eta_min = 1e-10 1817 T_i = 10 1818 T_cur = 0 1819 targets = [[0.05], [0.5]] 1820 scheduler = CosineAnnealingWarmRestarts( 1821 self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min 1822 ) 1823 for _ in range(1, iters, 1): 1824 T_cur += 1 1825 if T_cur >= T_i: 1826 T_cur = T_cur - T_i 1827 T_i = int(T_mult) * T_i 1828 targets[0] += [ 1829 eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 1830 ] 1831 targets[1] += [ 1832 eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 1833 ] 1834 self._test(scheduler, targets, iters) 1835 1836 def test_CosineAnnealingWarmRestarts_lr2(self): 1837 iters = 30 1838 eta_min = 1e-10 1839 T_mults = [1, 2, 4] 1840 for T_mult in T_mults: 1841 T_i = 10 1842 T_cur = 0 1843 targets = [[0.05], [0.5]] 1844 scheduler = CosineAnnealingWarmRestarts( 1845 self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min 1846 ) 1847 for _ in torch.arange(0.1, iters, 0.1): 1848 T_cur = round(T_cur + 0.1, 1) 1849 if T_cur >= T_i: 1850 T_cur = T_cur - T_i 1851 T_i = int(T_mult) * T_i 1852 targets[0] += [ 1853 eta_min 1854 + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 1855 ] 1856 targets[1] += [ 1857 eta_min 1858 + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 1859 ] 1860 self._test_CosineAnnealingWarmRestarts(scheduler, targets, iters) 1861 1862 def test_CosineAnnealingWarmRestarts_lr3(self): 1863 epochs_for_T_mults = [ 1864 [0, 1, 2, 3, 4, 5, 12, 27, 3, 4, 5, 6, 13], 1865 [0, 1, 2, 3, 4, 5, 25, 32, 33, 34, 80, 81, 3], 1866 [0, 0.1, 0.2, 0.3, 1.3, 2.3, 17.5, 18.5, 19.5, 29.5, 30.5, 31.5, 50], 1867 ] 1868 T_curs_for_T_mults = [ 1869 [1, 2, 3, 4, 5, 2, 7, 3, 4, 5, 6, 3], 1870 [1, 2, 3, 4, 5, 15, 2, 3, 4, 10, 11, 3], 1871 [0.1, 0.2, 0.3, 1.3, 2.3, 7.5, 8.5, 9.5, 19.5, 20.5, 21.5, 10], 1872 ] 1873 T_is_for_T_mults = [ 1874 [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10], 1875 [10, 10, 10, 10, 10, 20, 40, 40, 40, 80, 80, 10], 1876 [10, 10, 10, 10, 10, 30, 30, 30, 30, 30, 30, 90], 1877 ] 1878 eta_min = 1e-10 1879 T_mults = [1, 2, 3] 1880 for epochs, T_mult, T_curs, T_is in zip( 1881 epochs_for_T_mults, T_mults, T_curs_for_T_mults, T_is_for_T_mults 1882 ): 1883 targets = [[0.05], [0.5]] 1884 scheduler = CosineAnnealingWarmRestarts( 1885 self.opt, T_0=10, T_mult=T_mult, eta_min=eta_min 1886 ) 1887 for T_cur, T_i in zip(T_curs, T_is): 1888 targets[0] += [ 1889 eta_min 1890 + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 1891 ] 1892 targets[1] += [ 1893 eta_min 1894 + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 1895 ] 1896 self._test_interleaved_CosineAnnealingWarmRestarts( 1897 scheduler, targets, epochs 1898 ) 1899 1900 def test_swalr_no_anneal(self): 1901 epochs, swa_start, swa_lr = 10, 5, 0.01 1902 initial_lrs = [group["lr"] for group in self.opt.param_groups] 1903 targets = [ 1904 [lr] * (swa_start + 1) + [swa_lr] * (epochs - swa_start - 1) 1905 for lr in initial_lrs 1906 ] 1907 swa_scheduler = SWALR(self.opt, anneal_epochs=1, swa_lr=swa_lr) 1908 self._test_swalr(swa_scheduler, None, targets, swa_start, epochs) 1909 1910 def test_swalr_cosine_anneal_after_multiplicative(self): 1911 # same swa_lr for different param_groups 1912 epochs, swa_start, swa_lr, anneal_epochs = 15, 5, 0.01, 5 1913 mult_factor = 0.9 1914 scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor) 1915 swa_scheduler = SWALR(self.opt, anneal_epochs=anneal_epochs, swa_lr=swa_lr) 1916 1917 def anneal_coef(t): 1918 if t + 1 >= anneal_epochs: 1919 return 0.0 1920 return (1 + math.cos(math.pi * (t + 1) / anneal_epochs)) / 2 1921 1922 initial_lrs = [group["lr"] for group in self.opt.param_groups] 1923 targets_before_swa = [ 1924 [lr * mult_factor**i for i in range(swa_start + 1)] for lr in initial_lrs 1925 ] 1926 swa_epochs = epochs - swa_start - 1 1927 targets = [ 1928 lrs 1929 + [ 1930 lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t)) 1931 for t in range(swa_epochs) 1932 ] 1933 for lrs in targets_before_swa 1934 ] 1935 1936 self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs) 1937 1938 def test_swalr_linear_anneal_after_multiplicative(self): 1939 # separate swa_lr for different param_groups 1940 epochs, swa_start, swa_lrs, anneal_epochs = 15, 5, [0.01, 0.02], 4 1941 mult_factor = 0.9 1942 scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor) 1943 swa_scheduler = SWALR( 1944 self.opt, 1945 anneal_epochs=anneal_epochs, 1946 anneal_strategy="linear", 1947 swa_lr=swa_lrs, 1948 ) 1949 1950 def anneal_coef(t): 1951 if t + 1 >= anneal_epochs: 1952 return 0.0 1953 return 1 - (t + 1) / anneal_epochs 1954 1955 initial_lrs = [group["lr"] for group in self.opt.param_groups] 1956 targets_before_swa = [ 1957 [lr * mult_factor**i for i in range(swa_start + 1)] for lr in initial_lrs 1958 ] 1959 swa_epochs = epochs - swa_start - 1 1960 targets = [ 1961 lrs 1962 + [ 1963 lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t)) 1964 for t in range(swa_epochs) 1965 ] 1966 for lrs, swa_lr in zip(targets_before_swa, swa_lrs) 1967 ] 1968 1969 self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs) 1970 1971 def _test_swalr(self, swa_scheduler, scheduler, targets, swa_start, epochs): 1972 for epoch in range(epochs): 1973 for param_group, target in zip(self.opt.param_groups, targets): 1974 self.assertEqual( 1975 target[epoch], 1976 param_group["lr"], 1977 msg="LR is wrong in epoch {}: expected {}, got {}".format( 1978 epoch, target[epoch], param_group["lr"] 1979 ), 1980 atol=1e-5, 1981 rtol=0, 1982 ) 1983 if epoch >= swa_start: 1984 self.opt.step() 1985 swa_scheduler.step() 1986 elif scheduler is not None: 1987 self.opt.step() 1988 scheduler.step() 1989 1990 def test_swalr_hypers(self): 1991 # Test that SWALR raises errors for incorrect hyper-parameters 1992 with self.assertRaisesRegex(ValueError, "anneal_strategy must"): 1993 swa_scheduler = SWALR(self.opt, anneal_strategy="exponential", swa_lr=1.0) 1994 1995 with self.assertRaisesRegex(ValueError, "anneal_epochs must"): 1996 swa_scheduler = SWALR(self.opt, anneal_epochs=-1, swa_lr=1.0) 1997 with self.assertRaisesRegex(ValueError, "anneal_epochs must"): 1998 swa_scheduler = SWALR(self.opt, anneal_epochs=1.7, swa_lr=1.0) 1999 with self.assertRaisesRegex(ValueError, "swa_lr must"): 2000 swa_scheduler = SWALR(self.opt, swa_lr=[1.0, 0.1, 0.01]) 2001 2002 def test_step_lr_state_dict(self): 2003 self._check_scheduler_state_dict( 2004 lambda: StepLR(self.opt, gamma=0.1, step_size=3), 2005 lambda: StepLR(self.opt, gamma=0.01 / 2, step_size=1), 2006 ) 2007 2008 def test_multi_step_lr_state_dict(self): 2009 self._check_scheduler_state_dict( 2010 lambda: MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]), 2011 lambda: MultiStepLR(self.opt, gamma=0.01, milestones=[1, 4, 6]), 2012 ) 2013 2014 def test_exp_step_lr_state_dict(self): 2015 self._check_scheduler_state_dict( 2016 lambda: ExponentialLR(self.opt, gamma=0.1), 2017 lambda: ExponentialLR(self.opt, gamma=0.01), 2018 ) 2019 2020 def test_cosine_lr_state_dict(self): 2021 epochs = 10 2022 eta_min = 1e-10 2023 self._check_scheduler_state_dict( 2024 lambda: CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min), 2025 lambda: CosineAnnealingLR(self.opt, T_max=epochs // 2, eta_min=eta_min / 2), 2026 epochs=epochs, 2027 ) 2028 2029 def test_reduce_lr_on_plateau_state_dict(self): 2030 scheduler = ReduceLROnPlateau(self.opt, mode="min", factor=0.1, patience=2) 2031 for score in [1.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 3.0, 2.0, 1.0]: 2032 scheduler.step(score) 2033 scheduler_copy = ReduceLROnPlateau( 2034 self.opt, mode="max", factor=0.5, patience=10 2035 ) 2036 scheduler_copy.load_state_dict(scheduler.state_dict()) 2037 for key in scheduler.__dict__.keys(): 2038 if key not in {"optimizer", "is_better"}: 2039 self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) 2040 2041 def test_lambda_lr_state_dict_fn(self): 2042 scheduler = LambdaLR(self.opt, lr_lambda=lambda x: x) 2043 state = scheduler.state_dict() 2044 self.assertIsNone(state["lr_lambdas"][0]) 2045 2046 scheduler_copy = LambdaLR(self.opt, lr_lambda=lambda x: x) 2047 scheduler_copy.load_state_dict(state) 2048 for key in scheduler.__dict__.keys(): 2049 if key not in {"optimizer", "lr_lambdas"}: 2050 self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) 2051 2052 def test_lambda_lr_state_dict_obj(self): 2053 scheduler = LambdaLR(self.opt, lr_lambda=self.LambdaLRTestObject(10)) 2054 state = scheduler.state_dict() 2055 self.assertIsNotNone(state["lr_lambdas"][0]) 2056 2057 scheduler_copy = LambdaLR(self.opt, lr_lambda=self.LambdaLRTestObject(-1)) 2058 scheduler_copy.load_state_dict(state) 2059 for key in scheduler.__dict__.keys(): 2060 if key not in {"optimizer"}: 2061 self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) 2062 2063 def test_CosineAnnealingWarmRestarts_lr_state_dict(self): 2064 self._check_scheduler_state_dict( 2065 lambda: CosineAnnealingWarmRestarts(self.opt, T_0=10, T_mult=2), 2066 lambda: CosineAnnealingWarmRestarts(self.opt, T_0=100), 2067 ) 2068 2069 def test_swa_lr_state_dict(self): 2070 self._check_scheduler_state_dict( 2071 lambda: SWALR(self.opt, anneal_epochs=3, swa_lr=0.5), 2072 lambda: SWALR( 2073 self.opt, anneal_epochs=10, anneal_strategy="linear", swa_lr=5.0 2074 ), 2075 ) 2076 2077 def _check_scheduler_state_dict(self, constr, constr2, epochs=10): 2078 scheduler = constr() 2079 for _ in range(epochs): 2080 scheduler.optimizer.step() 2081 scheduler.step() 2082 scheduler_copy = constr2() 2083 scheduler_copy.load_state_dict(scheduler.state_dict()) 2084 for key in scheduler.__dict__.keys(): 2085 if key != "optimizer": 2086 self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) 2087 self.assertEqual(scheduler.get_last_lr(), scheduler_copy.get_last_lr()) 2088 2089 def _test_get_last_lr(self, schedulers, targets, epochs=10): 2090 if isinstance(schedulers, LRScheduler): 2091 schedulers = [schedulers] 2092 optimizers = {scheduler.optimizer for scheduler in schedulers} 2093 for epoch in range(epochs): 2094 result = [scheduler.get_last_lr() for scheduler in schedulers] 2095 [optimizer.step() for optimizer in optimizers] 2096 [scheduler.step() for scheduler in schedulers] 2097 target = [[t[epoch] for t in targets]] * len(schedulers) 2098 for t, r in zip(target, result): 2099 self.assertEqual( 2100 t, 2101 r, 2102 msg=f"LR is wrong in epoch {epoch}: expected {t}, got {r}", 2103 atol=1e-5, 2104 rtol=0, 2105 ) 2106 2107 def _test_with_epoch(self, schedulers, targets, epochs=10): 2108 if isinstance(schedulers, LRScheduler): 2109 schedulers = [schedulers] 2110 optimizers = {scheduler.optimizer for scheduler in schedulers} 2111 for epoch in range(epochs): 2112 [optimizer.step() for optimizer in optimizers] 2113 with warnings.catch_warnings(record=True) as w: 2114 [ 2115 scheduler.step(epoch) for scheduler in schedulers 2116 ] # step before assert: skip initial lr 2117 self._check_warning_is_epoch_deprecation_warning( 2118 w, num_warnings=len(schedulers) 2119 ) 2120 for param_group, target in zip(self.opt.param_groups, targets): 2121 self.assertEqual( 2122 target[epoch], 2123 param_group["lr"], 2124 msg="LR is wrong in epoch {}: expected {}, got {}".format( 2125 epoch, target[epoch], param_group["lr"] 2126 ), 2127 atol=1e-5, 2128 rtol=0, 2129 ) 2130 2131 def _test(self, schedulers, targets, epochs=10): 2132 if isinstance(schedulers, LRScheduler): 2133 schedulers = [schedulers] 2134 for epoch in range(epochs): 2135 for param_group, target in zip(self.opt.param_groups, targets): 2136 self.assertEqual( 2137 target[epoch], 2138 param_group["lr"], 2139 msg="LR is wrong in epoch {}: expected {}, got {}".format( 2140 epoch, target[epoch], param_group["lr"] 2141 ), 2142 atol=1e-5, 2143 rtol=0, 2144 ) 2145 [scheduler.step() for scheduler in schedulers] 2146 2147 def _test_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs=10): 2148 for index, epoch in enumerate(torch.arange(0, epochs, 0.1)): 2149 epoch = round(epoch.item(), 1) 2150 scheduler.step(epoch) 2151 for param_group, target in zip(self.opt.param_groups, targets): 2152 self.assertEqual( 2153 target[index], 2154 param_group["lr"], 2155 msg="LR is wrong in epoch {}: expected {}, got {}".format( 2156 epoch, target[index], param_group["lr"] 2157 ), 2158 atol=1e-5, 2159 rtol=0, 2160 ) 2161 2162 def _test_interleaved_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs): 2163 for index, epoch in enumerate(epochs): 2164 scheduler.step(epoch) 2165 for param_group, target in zip(self.opt.param_groups, targets): 2166 self.assertEqual( 2167 target[index], 2168 param_group["lr"], 2169 msg="LR is wrong in epoch {}: expected {}, got {}".format( 2170 epoch, target[index], param_group["lr"] 2171 ), 2172 atol=1e-5, 2173 rtol=0, 2174 ) 2175 2176 def _test_against_closed_form(self, scheduler, closed_form_scheduler, epochs=10): 2177 self.setUp() 2178 targets = [] 2179 for epoch in range(epochs): 2180 closed_form_scheduler.optimizer.step() 2181 with warnings.catch_warnings(record=True) as w: 2182 closed_form_scheduler.step(epoch) 2183 self._check_warning_is_epoch_deprecation_warning(w) 2184 targets.append([group["lr"] for group in self.opt.param_groups]) 2185 self.setUp() 2186 for epoch in range(epochs): 2187 self.opt.step() 2188 scheduler.step() 2189 for i, param_group in enumerate(self.opt.param_groups): 2190 self.assertEqual( 2191 targets[epoch][i], 2192 param_group["lr"], 2193 msg="LR is wrong in epoch {}: expected {}, got {}".format( 2194 epoch, targets[epoch][i], param_group["lr"] 2195 ), 2196 atol=1e-5, 2197 rtol=0, 2198 ) 2199 2200 def _test_reduce_lr_on_plateau( 2201 self, schedulers, targets, metrics, epochs=10, verbose=False 2202 ): 2203 if isinstance(schedulers, (LRScheduler, ReduceLROnPlateau)): 2204 schedulers = [schedulers] 2205 for epoch in range(epochs): 2206 self.opt.step() 2207 for scheduler in schedulers: 2208 if isinstance(scheduler, ReduceLROnPlateau): 2209 scheduler.step(metrics[epoch]) 2210 else: 2211 scheduler.step() 2212 if verbose: 2213 print("epoch{}:\tlr={}".format(epoch, self.opt.param_groups[0]["lr"])) 2214 for param_group, target in zip(self.opt.param_groups, targets): 2215 self.assertEqual( 2216 target[epoch], 2217 param_group["lr"], 2218 msg="LR is wrong in epoch {}: expected {}, got {}".format( 2219 epoch, target[epoch], param_group["lr"] 2220 ), 2221 atol=1e-5, 2222 rtol=0, 2223 ) 2224 2225 def _test_cycle_lr( 2226 self, 2227 scheduler, 2228 lr_targets, 2229 momentum_targets, 2230 batch_iterations, 2231 verbose=False, 2232 use_beta1=False, 2233 ): 2234 for batch_num in range(batch_iterations): 2235 if verbose: 2236 if "momentum" in self.opt.param_groups[0].keys(): 2237 print( 2238 "batch{}:\tlr={},momentum={}".format( 2239 batch_num, 2240 self.opt.param_groups[0]["lr"], 2241 self.opt.param_groups[0]["momentum"], 2242 ) 2243 ) 2244 elif use_beta1 and "betas" in self.opt.param_groups[0].keys(): 2245 print( 2246 "batch{}:\tlr={},beta1={}".format( 2247 batch_num, 2248 self.opt.param_groups[0]["lr"], 2249 self.opt.param_groups[0]["betas"][0], 2250 ) 2251 ) 2252 else: 2253 print( 2254 "batch{}:\tlr={}".format( 2255 batch_num, self.opt.param_groups[0]["lr"] 2256 ) 2257 ) 2258 2259 for param_group, lr_target, momentum_target in zip( 2260 self.opt.param_groups, lr_targets, momentum_targets 2261 ): 2262 self.assertEqual( 2263 lr_target[batch_num], 2264 param_group["lr"], 2265 msg="LR is wrong in batch_num {}: expected {}, got {}".format( 2266 batch_num, lr_target[batch_num], param_group["lr"] 2267 ), 2268 atol=1e-5, 2269 rtol=0, 2270 ) 2271 2272 if use_beta1 and "betas" in param_group.keys(): 2273 self.assertEqual( 2274 momentum_target[batch_num], 2275 param_group["betas"][0], 2276 msg="Beta1 is wrong in batch_num {}: expected {}, got {}".format( 2277 batch_num, 2278 momentum_target[batch_num], 2279 param_group["betas"][0], 2280 ), 2281 atol=1e-5, 2282 rtol=0, 2283 ) 2284 elif "momentum" in param_group.keys(): 2285 self.assertEqual( 2286 momentum_target[batch_num], 2287 param_group["momentum"], 2288 msg="Momentum is wrong in batch_num {}: expected {}, got {}".format( 2289 batch_num, 2290 momentum_target[batch_num], 2291 param_group["momentum"], 2292 ), 2293 atol=1e-5, 2294 rtol=0, 2295 ) 2296 self.opt.step() 2297 scheduler.step() 2298 2299 def test_cosine_then_cyclic(self): 2300 # https://github.com/pytorch/pytorch/issues/21965 2301 2302 max_lr = 0.3 2303 base_lr = 0.1 2304 optim_lr = 0.5 2305 2306 model = torch.nn.Linear(2, 1) 2307 optimizer = SGD(model.parameters(), lr=optim_lr) 2308 lr_scheduler_1 = torch.optim.lr_scheduler.CosineAnnealingLR( 2309 optimizer, T_max=20, eta_min=0.1 2310 ) 2311 lr_scheduler_2 = torch.optim.lr_scheduler.CyclicLR( 2312 optimizer, base_lr=base_lr, max_lr=max_lr, step_size_up=1, step_size_down=3 2313 ) 2314 2315 for i in range(40): 2316 optimizer.step() 2317 if i <= lr_scheduler_1.T_max: 2318 lr_scheduler_1.step() 2319 else: 2320 lr_scheduler_2.step() 2321 last_lr = optimizer.param_groups[0]["lr"] 2322 2323 self.assertLessEqual(last_lr, max_lr) 2324 2325 @parametrize( 2326 "LRClass", 2327 [ 2328 partial(LambdaLR, lr_lambda=lambda e: e // 10), 2329 partial(MultiplicativeLR, lr_lambda=lambda: 0.95), 2330 partial(StepLR, step_size=30), 2331 partial(MultiStepLR, milestones=[30, 80]), 2332 ConstantLR, 2333 LinearLR, 2334 partial(ExponentialLR, gamma=0.9), 2335 lambda opt, **kwargs: SequentialLR( 2336 opt, 2337 schedulers=[ConstantLR(opt), ConstantLR(opt)], 2338 milestones=[2], 2339 **kwargs, 2340 ), 2341 PolynomialLR, 2342 partial(CosineAnnealingLR, T_max=10), 2343 ReduceLROnPlateau, 2344 partial(CyclicLR, base_lr=0.01, max_lr=0.1), 2345 partial(CosineAnnealingWarmRestarts, T_0=20), 2346 partial(OneCycleLR, max_lr=0.01, total_steps=10), 2347 ], 2348 ) 2349 def test_lr_scheduler_verbose_deprecation_warning(self, LRClass): 2350 """Check that a deprecating warning with verbose parameter.""" 2351 with self.assertWarnsOnceRegex( 2352 UserWarning, "The verbose parameter is deprecated" 2353 ): 2354 LRClass(self.opt, verbose=True) 2355 2356 with self.assertWarnsOnceRegex( 2357 UserWarning, "The verbose parameter is deprecated" 2358 ): 2359 LRClass(self.opt, verbose=False) 2360 2361 # No warning is raised when verbose is the default value. 2362 with warnings.catch_warnings(): 2363 warnings.simplefilter("error", UserWarning) 2364 LRClass(self.opt) 2365 2366 @parametrize( 2367 "LRClass", 2368 [ 2369 partial(LambdaLR, lr_lambda=lambda e: e // 10), 2370 partial(MultiplicativeLR, lr_lambda=lambda: 0.95), 2371 partial(StepLR, step_size=30), 2372 partial(MultiStepLR, milestones=[30, 80]), 2373 ConstantLR, 2374 LinearLR, 2375 partial(ExponentialLR, gamma=0.9), 2376 PolynomialLR, 2377 partial(CosineAnnealingLR, T_max=10), 2378 lambda opt, **kwargs: ChainedScheduler( 2379 schedulers=[ConstantLR(opt), ConstantLR(opt)], **kwargs 2380 ), 2381 lambda opt, **kwargs: SequentialLR( 2382 opt, 2383 schedulers=[ConstantLR(opt), ConstantLR(opt)], 2384 milestones=[2], 2385 **kwargs, 2386 ), 2387 ReduceLROnPlateau, 2388 partial(CyclicLR, base_lr=0.01, max_lr=0.1), 2389 partial(OneCycleLR, max_lr=0.01, total_steps=10, anneal_strategy="linear"), 2390 partial(CosineAnnealingWarmRestarts, T_0=20), 2391 ], 2392 ) 2393 @parametrize("weights_only", [True, False]) 2394 def test_lr_scheduler_state_dict_load(self, LRClass, weights_only): 2395 scheduler = LRClass(self.opt) 2396 state_dict = scheduler.state_dict() 2397 2398 with tempfile.TemporaryFile() as f: 2399 torch.save(state_dict, f) 2400 f.seek(0) 2401 state_dict_loaded = torch.load(f, weights_only=weights_only) 2402 self.assertEqual(state_dict, state_dict_loaded) 2403 # Make sure state_dict can be loaded 2404 scheduler2 = LRClass(self.opt) 2405 scheduler2.load_state_dict(state_dict_loaded) 2406 self.assertEqual(scheduler2.state_dict(), state_dict) 2407 2408 @parametrize( 2409 "LRClass", 2410 [ 2411 partial(LambdaLR, lr_lambda=lambda e: e // 10), 2412 partial(MultiplicativeLR, lr_lambda=lambda e: 0.95), 2413 partial(StepLR, step_size=30), 2414 partial(MultiStepLR, milestones=[30, 80]), 2415 ConstantLR, 2416 LinearLR, 2417 partial(ExponentialLR, gamma=0.9), 2418 PolynomialLR, 2419 partial(CosineAnnealingLR, T_max=10), 2420 partial(CosineAnnealingWarmRestarts, T_0=20), 2421 ], 2422 ) 2423 def test_constant_initial_lr(self, LRClass): 2424 # Test that the initial learning rate is constant 2425 lr = torch.as_tensor(0.1) 2426 opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr) 2427 sch = LRClass(opt) 2428 2429 ori_param_groups = copy.deepcopy(opt.param_groups) 2430 2431 for i in range(2): 2432 opt.step() 2433 sch.step(i) 2434 lr.multiply_(0.1) 2435 for group, ori_group in zip(opt.param_groups, ori_param_groups): 2436 self.assertEqual(group["initial_lr"], ori_group["initial_lr"]) 2437 self.assertEqual(sch.base_lrs, [0.1]) 2438 2439 def test_constant_initial_params_cyclelr(self): 2440 # Test that the initial learning rate is constant 2441 lr = torch.as_tensor(0.1) 2442 max_lr = torch.as_tensor(0.2) 2443 base_momentum = torch.as_tensor(0.8) 2444 max_momentum = torch.as_tensor(0.9) 2445 opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr) 2446 sch = CyclicLR( 2447 opt, 2448 base_lr=lr, 2449 max_lr=max_lr, 2450 base_momentum=base_momentum, 2451 max_momentum=max_momentum, 2452 ) 2453 ori_param_groups = copy.deepcopy(opt.param_groups) 2454 2455 for i in range(2): 2456 lr.multiply_(0.5) 2457 max_lr.multiply_(0.5) 2458 base_momentum.multiply_(0.5) 2459 max_momentum.multiply_(0.5) 2460 opt.step() 2461 sch.step(i) 2462 for group, ori_group in zip(opt.param_groups, ori_param_groups): 2463 self.assertEqual(group["initial_lr"], ori_group["initial_lr"]) 2464 self.assertEqual(group["max_momentum"], ori_group["max_momentum"]) 2465 self.assertEqual(group["base_momentum"], ori_group["base_momentum"]) 2466 self.assertEqual(sch.base_lrs, [0.1]) 2467 self.assertEqual(sch.max_lrs, [0.2]) 2468 self.assertEqual(group["max_momentum"], 0.9) 2469 self.assertEqual(group["base_momentum"], 0.8) 2470 2471 def test_constant_initial_params_onecyclelr(self): 2472 # Test that the initial learning rate is constant 2473 lr = torch.as_tensor(0.1) 2474 base_momentum = torch.as_tensor(0.85) 2475 max_momentum = torch.as_tensor(0.95) 2476 opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr) 2477 sch = OneCycleLR( 2478 opt, 2479 max_lr=lr, 2480 total_steps=10, 2481 base_momentum=base_momentum, 2482 max_momentum=max_momentum, 2483 ) 2484 ori_param_groups = copy.deepcopy(opt.param_groups) 2485 2486 for i in range(2): 2487 lr.multiply_(0.5) 2488 base_momentum.multiply_(0.5) 2489 max_momentum.multiply_(0.5) 2490 opt.step() 2491 sch.step(i) 2492 2493 for group, ori_group in zip(opt.param_groups, ori_param_groups): 2494 self.assertEqual(group["initial_lr"], ori_group["initial_lr"]) 2495 self.assertEqual(group["max_lr"], ori_group["max_lr"]) 2496 self.assertEqual(group["min_lr"], ori_group["min_lr"]) 2497 self.assertEqual(group["max_momentum"], ori_group["max_momentum"]) 2498 self.assertEqual(group["base_momentum"], ori_group["base_momentum"]) 2499 self.assertEqual(group["max_momentum"], 0.95) 2500 self.assertEqual(group["base_momentum"], 0.85) 2501 2502 def test_constant_initial_params_swalr(self): 2503 # Test that the initial learning rate is constant 2504 lr = torch.as_tensor(0.1) 2505 swa_lr = torch.as_tensor(0.05) 2506 opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr) 2507 sch = SWALR(opt, swa_lr=swa_lr) 2508 ori_param_groups = copy.deepcopy(opt.param_groups) 2509 2510 for i in range(2): 2511 lr.multiply_(0.5) 2512 swa_lr.multiply_(0.5) 2513 opt.step() 2514 sch.step() 2515 for group, ori_group in zip(opt.param_groups, ori_param_groups): 2516 self.assertEqual(group["initial_lr"], ori_group["initial_lr"]) 2517 self.assertEqual(group["swa_lr"], ori_group["swa_lr"]) 2518 self.assertEqual(group["swa_lr"], 0.05) 2519 self.assertEqual(sch.base_lrs, [0.1]) 2520 2521 2522instantiate_parametrized_tests(TestLRScheduler) 2523 2524 2525if __name__ == "__main__": 2526 print("These tests should be run through test/test_optim.py instead") 2527