xref: /aosp_15_r20/external/pytorch/test/optim/test_lrscheduler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: optimizer", "module: LrScheduler" ]
2import copy
3import math
4import pickle
5import tempfile
6import types
7import warnings
8from functools import partial
9
10import torch
11import torch.nn.functional as F
12from torch.nn import Parameter
13from torch.optim import Adam, Rprop, SGD
14from torch.optim.lr_scheduler import (
15    ChainedScheduler,
16    ConstantLR,
17    CosineAnnealingLR,
18    CosineAnnealingWarmRestarts,
19    CyclicLR,
20    EPOCH_DEPRECATION_WARNING,
21    ExponentialLR,
22    LambdaLR,
23    LinearLR,
24    LRScheduler,
25    MultiplicativeLR,
26    MultiStepLR,
27    OneCycleLR,
28    PolynomialLR,
29    ReduceLROnPlateau,
30    SequentialLR,
31    StepLR,
32)
33from torch.optim.swa_utils import SWALR
34from torch.testing._internal.common_utils import (
35    instantiate_parametrized_tests,
36    load_tests,
37    parametrize,
38    skipIfTorchDynamo,
39    TestCase,
40)
41
42
43# load_tests from common_utils is used to automatically filter tests for
44# sharding on sandcastle. This line silences flake warnings
45load_tests = load_tests
46
47
48class TestLRScheduler(TestCase):
49    class SchedulerTestNet(torch.nn.Module):
50        def __init__(self) -> None:
51            super().__init__()
52            self.conv1 = torch.nn.Conv2d(1, 1, 1)
53            self.conv2 = torch.nn.Conv2d(1, 1, 1)
54
55        def forward(self, x):
56            return self.conv2(F.relu(self.conv1(x)))
57
58    class LambdaLRTestObject:
59        def __init__(self, value):
60            self.value = value
61
62        def __call__(self, epoch):
63            return self.value * epoch
64
65        def __eq__(self, other):
66            if isinstance(other, self.__class__):
67                return self.__dict__ == other.__dict__
68            else:
69                return False
70
71    exact_dtype = True
72
73    def setUp(self):
74        super().setUp()
75        self.net = self.SchedulerTestNet()
76        self.opt = SGD(
77            [
78                {"params": self.net.conv1.parameters()},
79                {"params": self.net.conv2.parameters(), "lr": 0.5},
80            ],
81            lr=0.05,
82        )
83
84    def _check_warning_is_epoch_deprecation_warning(self, w, *, num_warnings: int = 1):
85        """This function swallows the epoch deprecation warning which is produced when we
86        call `scheduler.step(epoch)` with some not `None` value of `epoch`.
87        this is deprecated, and this function will need to be removed/updated when
88        the schedulers no longer accept the parameter at all.
89        """
90        self.assertEqual(len(w), num_warnings)
91        for warning in w:
92            self.assertEqual(len(warning.message.args), 1)
93            self.assertEqual(warning.message.args[0], EPOCH_DEPRECATION_WARNING)
94
95    def test_error_when_getlr_has_epoch(self):
96        class MultiStepLR(torch.optim.lr_scheduler.LRScheduler):
97            def __init__(self, optimizer, gamma, milestones, last_epoch=-1):
98                self.init_lr = [group["lr"] for group in optimizer.param_groups]
99                self.gamma = gamma
100                self.milestones = milestones
101                super().__init__(optimizer, last_epoch)
102
103            def get_lr(self, step):
104                global_step = self.last_epoch
105                gamma_power = (
106                    [0]
107                    + [i + 1 for i, m in enumerate(self.milestones) if global_step >= m]
108                )[-1]
109                return [
110                    init_lr * (self.gamma**gamma_power) for init_lr in self.init_lr
111                ]
112
113        optimizer = SGD([torch.rand(1)], lr=1)
114
115        with self.assertRaises(TypeError):
116            scheduler = MultiStepLR(optimizer, gamma=1, milestones=[10, 20])
117
118    @skipIfTorchDynamo(
119        "Torchdynamo keeps references to optim in the guards and the stack of the graph break frames"
120    )
121    def test_no_cyclic_references(self):
122        import gc
123
124        param = Parameter(torch.empty(10))
125        optim = SGD([param], lr=0.5)
126        scheduler = LambdaLR(optim, lambda epoch: 1.0)
127        del scheduler
128
129        self.assertTrue(
130            len(gc.get_referrers(optim)) == 0,
131            "Optimizer should contain no cyclic references",
132        )
133
134        gc.collect()
135        del optim
136        self.assertEqual(
137            gc.collect(), 0, msg="Optimizer should be garbage-collected on __del__"
138        )
139
140    @skipIfTorchDynamo(
141        "Torchdynamo keeps references to optim in the guards and the stack of the graph break frames"
142    )
143    def test_no_cyclic_references_in_step(self):
144        import gc
145        import weakref
146
147        def run():
148            param = torch.empty(10, requires_grad=True)
149            optim = SGD(params=[param], lr=0.5)
150            scheduler = LambdaLR(optim, lambda epoch: 1.0)
151            param.sum().backward()
152            optim.step()
153            scheduler.step()
154
155            return weakref.ref(scheduler)
156
157        # To ensure that there are no reference cycles in the scheduler,
158        # we need to turn off the garbage collector, since gc would
159        # otherwise collect unreachable objects automatically and hide a cycle.
160        gc.disable()
161        ref = run()
162
163        assert ref() is None
164        gc.enable()  # restore
165
166    def test_old_pattern_warning(self):
167        epochs = 35
168        with warnings.catch_warnings(record=True) as ws:
169            warnings.simplefilter("always")  # allow any warning to be raised
170            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
171            self.assertTrue(len(ws) == 0, "No warning should be raised")
172
173        def old_pattern():
174            for _ in range(epochs):
175                scheduler.step()
176                self.opt.step()
177
178        self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern)
179
180    def test_old_pattern_warning_with_arg(self):
181        epochs = 35
182        with warnings.catch_warnings(record=True) as ws:
183            warnings.simplefilter("always")  # allow any warning to be raised
184            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
185            self.assertTrue(len(ws) == 0, "No warning should be raised")
186
187        def old_pattern2():
188            for _ in range(epochs):
189                scheduler.step()
190                self.opt.step()
191
192        self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2)
193
194    def test_old_pattern_warning_resuming(self):
195        epochs = 35
196        for i, group in enumerate(self.opt.param_groups):
197            group["initial_lr"] = 0.01
198
199        with warnings.catch_warnings(record=True) as ws:
200            warnings.simplefilter("always")  # allow any warning to be raised
201            scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
202            self.assertTrue(len(ws) == 0, "No warning should be raised")
203
204        def old_pattern():
205            for _ in range(epochs):
206                scheduler.step()
207                self.opt.step()
208
209        self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern)
210
211    def test_old_pattern_warning_resuming_with_arg(self):
212        epochs = 35
213        for i, group in enumerate(self.opt.param_groups):
214            group["initial_lr"] = 0.01
215
216        with warnings.catch_warnings(record=True) as ws:
217            warnings.simplefilter("always")  # allow any warning to be raised
218            scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
219            self.assertTrue(len(ws) == 0, "No warning should be raised")
220
221        def old_pattern2():
222            for _ in range(epochs):
223                scheduler.step()
224                self.opt.step()
225
226        self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2)
227
228    def test_old_pattern_warning_with_overridden_optim_step(self):
229        epochs = 35
230        for i, group in enumerate(self.opt.param_groups):
231            group["initial_lr"] = 0.01
232
233        with warnings.catch_warnings(record=True) as ws:
234            warnings.simplefilter("always")  # allow any warning to be raised
235            scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
236            self.assertTrue(len(ws) == 0, "No warning should be raised")
237
238        # emulate use-case with optimizer.step overridden
239        import types
240
241        old_step = self.opt.step
242
243        def new_step(o, *args, **kwargs):
244            retval = old_step(*args, **kwargs)
245            return retval
246
247        self.opt.step = types.MethodType(new_step, self.opt)
248
249        def old_pattern2():
250            for _ in range(epochs):
251                scheduler.step()
252                self.opt.step()
253
254        self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2)
255
256    def test_new_pattern_no_warning(self):
257        epochs = 35
258        with warnings.catch_warnings(record=True) as ws:
259            warnings.simplefilter("always")  # allow any warning to be raised
260            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
261            self.assertTrue(len(ws) == 0, "No warning should be raised")
262
263        with warnings.catch_warnings(record=True) as ws:
264            warnings.simplefilter("always")  # allow any warning to be raised
265            for _ in range(epochs):
266                self.opt.step()
267                scheduler.step()
268            self.assertTrue(len(ws) == 0, "No warning should be raised")
269
270    def test_new_pattern_no_warning_with_arg(self):
271        epochs = 35
272        with warnings.catch_warnings(record=True) as ws:
273            warnings.simplefilter("always")  # allow any warning to be raised
274            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
275            self.assertTrue(len(ws) == 0, "No warning should be raised")
276
277        with warnings.catch_warnings(record=True) as ws:
278            warnings.simplefilter("always")  # allow any warning to be raised
279            for _ in range(epochs):
280                self.opt.step()
281                scheduler.step()
282            self.assertTrue(len(ws) == 0, "No warning should be raised")
283
284    def test_new_pattern_no_warning_with_overridden_optim_step(self):
285        epochs = 35
286        with warnings.catch_warnings(record=True) as ws:
287            warnings.simplefilter("always")  # allow any warning to be raised
288            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
289            self.assertTrue(len(ws) == 0, "No warning should be raised")
290
291        # emulate use-case with optimizer.step overridden
292        import types
293
294        old_step = self.opt.step
295
296        def new_step(o, *args, **kwargs):
297            retval = old_step(*args, **kwargs)
298            return retval
299
300        self.opt.step = types.MethodType(new_step, self.opt)
301
302        def new_pattern():
303            for e in range(epochs):
304                self.opt.step()
305                scheduler.step()
306
307        self.assertWarnsRegex(
308            UserWarning, r"`optimizer.step\(\)` has been overridden", new_pattern
309        )
310
311    def _test_lr_is_constant_for_constant_epoch(self, scheduler):
312        l = []
313
314        for _ in range(10):
315            scheduler.optimizer.step()
316            with warnings.catch_warnings(record=True) as w:
317                scheduler.step(2)
318                self._check_warning_is_epoch_deprecation_warning(w)
319
320            l.append(self.opt.param_groups[0]["lr"])
321        self.assertEqual(min(l), max(l))
322
323    def test_step_lr_is_constant_for_constant_epoch(self):
324        scheduler = StepLR(self.opt, 2)
325        self._test_lr_is_constant_for_constant_epoch(scheduler)
326
327    def test_exponential_lr_is_constant_for_constant_epoch(self):
328        scheduler = ExponentialLR(self.opt, gamma=0.9)
329        self._test_lr_is_constant_for_constant_epoch(scheduler)
330
331    def test_constantlr_is_constant_for_constant_epoch(self):
332        scheduler = ConstantLR(self.opt)
333        self._test_lr_is_constant_for_constant_epoch(scheduler)
334
335    def test_linear_linearlr_is_constant_for_constant_epoch(self):
336        scheduler = LinearLR(self.opt)
337        self._test_lr_is_constant_for_constant_epoch(scheduler)
338
339    def test_polynomial_lr_is_constant_for_constant_epoch(self):
340        scheduler = PolynomialLR(self.opt, power=0.9)
341        self._test_lr_is_constant_for_constant_epoch(scheduler)
342
343    def test_step_lr(self):
344        # lr = 0.05     if epoch < 3
345        # lr = 0.005    if 3 <= epoch < 6
346        # lr = 0.0005   if 6 <= epoch < 9, and lr = 0.00005 if epoch >= 9
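        # For reference, StepLR can be sketched in closed form (roughly):
        #   lr = base_lr * gamma ** (epoch // step_size)
        # e.g. 0.05 * 0.1 ** (4 // 3) gives 0.005 at epoch 4 with the settings below.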
347        epochs = 10
348        single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3
349        targets = [single_targets, [x * epochs for x in single_targets]]
350        scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
351        self._test(scheduler, targets, epochs)
352
353    def test_get_last_lr_step_lr(self):
354        from torch.nn import Parameter
355
356        epochs = 10
357        optimizer = SGD([Parameter(torch.randn(2, 2, requires_grad=True))], 0.1)
358        targets = [[0.1] * 3 + [0.01] * 3 + [0.001] * 3 + [0.0001]]
359        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)
360        self._test_get_last_lr(scheduler, targets, epochs)
361
362    def test_get_last_lr_multi_step_lr(self):
363        # lr = 0.05     if epoch < 2
364        # lr = 0.005    if 2 <= epoch < 5
365        # lr = 0.0005   if 5 <= epoch < 9
366        # lr = 0.00005   if 9 <= epoch
367        epochs = 10
368        single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 1
369        targets = [single_targets, [x * epochs for x in single_targets]]
370        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
371        self._test_get_last_lr(scheduler, targets, epochs)
372
373    def test_multi_step_lr(self):
374        # lr = 0.05     if epoch < 2
375        # lr = 0.005    if 2 <= epoch < 5
376        # lr = 0.0005   if 5 <= epoch < 9
377        # lr = 0.00005   if epoch >= 9
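        # For reference, MultiStepLR roughly follows the closed form
        #   lr = base_lr * gamma ** bisect_right(milestones, epoch),
        # e.g. 0.05 * 0.1 ** 2 gives 0.0005 at epoch 5 with milestones [2, 5, 9].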
378        epochs = 10
379        single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3
380        targets = [single_targets, [x * epochs for x in single_targets]]
381        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
382        self._test(scheduler, targets, epochs)
383
384    def test_multi_step_lr_with_epoch(self):
385        # lr = 0.05     if epoch < 2
386        # lr = 0.005    if 2 <= epoch < 5
387        # lr = 0.0005   if 5 <= epoch < 9
388        # lr = 0.00005   if epoch >= 9
389        epochs = 10
390        single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3
391        targets = [single_targets, [x * epochs for x in single_targets]]
392        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
393        self._test_with_epoch(scheduler, targets, epochs)
394
395    def test_get_last_lr_constantlr(self):
396        # lr = 0.025     if epoch < 5
397        # lr = 0.05     if 5 <= epoch
398        epochs = 10
399        single_targets = [0.025] * 5 + [0.05] * 5
400        targets = [single_targets, [x * epochs for x in single_targets]]
401        scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
402        self._test_get_last_lr(scheduler, targets, epochs)
403
404    def test_get_last_lr_linearlr(self):
405        # lr = 0.0125    if epoch == 0
406        # lr = 0.016875  if epoch == 1
407        # lr = 0.02125   if epoch == 2
408        # lr = 0.025625  if epoch == 3
409        # lr = 0.03      if 4 <= epoch
410        epochs = 10
411        start_factor = 1.0 / 4
412        end_factor = 3.0 / 5
413        iters = 4
414        interpolation = [
415            start_factor + i * (end_factor - start_factor) / iters for i in range(iters)
416        ]
417        single_targets = [x * 0.05 for x in interpolation] + [0.05 * end_factor] * (
418            epochs - iters
419        )
420        targets = [single_targets, [x * epochs for x in single_targets]]
421        scheduler = LinearLR(
422            self.opt,
423            start_factor=start_factor,
424            end_factor=end_factor,
425            total_iters=iters,
426        )
427        self._test_get_last_lr(scheduler, targets, epochs)
428
429    def test_constantlr(self):
430        # lr = 0.025     if epoch < 5
431        # lr = 0.05     if 5 <= epoch
432        epochs = 10
433        single_targets = [0.025] * 5 + [0.05] * 5
434        targets = [single_targets, [x * epochs for x in single_targets]]
435        scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
436        self._test(scheduler, targets, epochs)
437
438    def test_linearlr(self):
439        # lr = 0.025     if epoch == 0
440        # lr = 0.03125   if epoch == 1
441        # lr = 0.0375    if epoch == 2
442        # lr = 0.04375   if epoch == 3
443        # lr = 0.05      if 4 <= epoch
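        # For reference, LinearLR roughly scales the base lr by
        #   start_factor + (end_factor - start_factor) * min(epoch, total_iters) / total_iters
        # where end_factor defaults to 1.0, which reproduces the targets above.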
444        epochs = 10
445        start_factor = 1.0 / 2
446        iters = 4
447        interpolation = [
448            start_factor + i * (1 - start_factor) / iters for i in range(iters)
449        ]
450        single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters)
451        targets = [single_targets, [x * epochs for x in single_targets]]
452        scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
453        self._test(scheduler, targets, epochs)
454
455    def test_linearlr_start_factor_limits1(self):
456        start_factor = 0.0
457        iters = 4
458        with self.assertRaises(ValueError):
459            LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
460
461    def test_linearlr_start_factor_limits2(self):
462        start_factor = 1.1
463        iters = 4
464        with self.assertRaises(ValueError):
465            LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
466
467    def test_constantlr_with_epoch(self):
468        # lr = 0.025     if epoch < 5
469        # lr = 0.05     if 5 <= epoch
470        epochs = 10
471        single_targets = [0.025] * 5 + [0.05] * 5
472        targets = [single_targets, [x * epochs for x in single_targets]]
473        scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
474        self._test_with_epoch(scheduler, targets, epochs)
475
476    def test_linearlr_with_epoch(self):
477        # lr = 0.025     if epoch == 0
478        # lr = 0.03125   if epoch == 1
479        # lr = 0.0375    if epoch == 2
480        # lr = 0.04375   if epoch == 3
481        # lr = 0.05      if 4 <= epoch
482        epochs = 10
483        start_factor = 1.0 / 2
484        end_factor = 1.0
485        iters = 4
486        interpolation = [
487            start_factor + i * (end_factor - start_factor) / iters for i in range(iters)
488        ]
489        single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters)
490        targets = [single_targets, [x * epochs for x in single_targets]]
491        scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
492        self._test_with_epoch(scheduler, targets, epochs)
493
494    def test_exp_lr(self):
495        epochs = 10
496        single_targets = [0.05 * (0.9**x) for x in range(epochs)]
497        targets = [single_targets, [x * epochs for x in single_targets]]
498        scheduler = ExponentialLR(self.opt, gamma=0.9)
499        self._test(scheduler, targets, epochs)
500
501    def test_poly_lr(self):
502        epochs = 10
503        power = 0.9
504        total_iters = 5
505        single_targets = [
506            (1.0 - x / total_iters) ** power * 0.05 for x in range(total_iters)
507        ] + [0.0] * (epochs - total_iters)
508        targets = [single_targets, [x * epochs for x in single_targets]]
509        scheduler = PolynomialLR(self.opt, power=power, total_iters=total_iters)
510        self._test(scheduler, targets, epochs)
511
512    def test_cos_anneal_lr(self):
513        epochs = 10
514        eta_min = 1e-10
515        single_targets = [
516            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
517            for x in range(epochs)
518        ]
519        targets = [single_targets, [x * epochs for x in single_targets]]
520        scheduler = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
521        self._test(scheduler, targets, epochs)
522
523    def test_closed_form_step_lr(self):
524        scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
525        closed_form_scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
526        self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
527
528    def test_closed_form_linearlr(self):
529        scheduler = LinearLR(
530            self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4
531        )
532        closed_form_scheduler = LinearLR(
533            self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4
534        )
535        self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
536
537    def test_closed_form_constantlr(self):
538        scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4)
539        closed_form_scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4)
540        self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
541
542    def test_closed_form_multi_step_lr(self):
543        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
544        closed_form_scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
545        self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
546
547    def test_closed_form_exp_lr(self):
548        scheduler = ExponentialLR(self.opt, gamma=0.9)
549        closed_form_scheduler = ExponentialLR(self.opt, gamma=0.9)
550        self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
551
552    def test_closed_form_poly_lr(self):
553        scheduler = PolynomialLR(self.opt, power=0.9)
554        closed_form_scheduler = PolynomialLR(self.opt, power=0.9)
555        self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
556
557    def test_closed_form_cos_anneal_lr(self):
558        eta_min = 1e-10
559        epochs = 20
560        T_max = 5
561        scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min)
562        closed_form_scheduler = CosineAnnealingLR(
563            self.opt, T_max=T_max, eta_min=eta_min
564        )
565        self._test_against_closed_form(scheduler, closed_form_scheduler, epochs)
566
567    def test_cos_anneal_lr_continue(self):
568        eta_min = 0.1
569        T_max = 5
570        scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min)
571        self.opt.step()
572        scheduler.step()
573        original_lrs = scheduler._last_lr
574        new_scheduler = CosineAnnealingLR(
575            self.opt, T_max=T_max, eta_min=eta_min, last_epoch=0
576        )
577        new_lrs = new_scheduler._last_lr
578        torch.testing.assert_close(original_lrs, new_lrs, rtol=1e-4, atol=1e-5)
579
580    def test_reduce_lr_on_plateau1(self):
581        epochs = 10
582        for param_group in self.opt.param_groups:
583            param_group["lr"] = 0.5
584        targets = [[0.5] * 20]
585        metrics = [10 - i * 0.0167 for i in range(20)]
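        # With mode="min", threshold_mode="abs" and threshold=0.01, each step improves
        # the metric by 0.0167 > 0.01, so the patience counter never fills up and the
        # lr is expected to stay at 0.5 for the whole run (hence the flat targets).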
586        scheduler = ReduceLROnPlateau(
587            self.opt,
588            threshold_mode="abs",
589            mode="min",
590            threshold=0.01,
591            patience=5,
592            cooldown=5,
593        )
594        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
595
596    def test_reduce_lr_on_plateau2(self):
597        epochs = 22
598        for param_group in self.opt.param_groups:
599            param_group["lr"] = 0.5
600        targets = [[0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2]
601        metrics = [10 - i * 0.0165 for i in range(22)]
602        scheduler = ReduceLROnPlateau(
603            self.opt,
604            patience=5,
605            cooldown=0,
606            threshold_mode="abs",
607            mode="min",
608            threshold=0.1,
609        )
610        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
611
612    def test_reduce_lr_on_plateau3(self):
613        epochs = 22
614        for param_group in self.opt.param_groups:
615            param_group["lr"] = 0.5
616        targets = [[0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4]
617        metrics = [-0.8] * 2 + [-0.234] * 20
618        scheduler = ReduceLROnPlateau(
619            self.opt, mode="max", patience=5, cooldown=5, threshold_mode="abs"
620        )
621        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
622
623    def test_reduce_lr_on_plateau4(self):
624        epochs = 20
625        for param_group in self.opt.param_groups:
626            param_group["lr"] = 0.5
627        targets = [[0.5] * 20]
628        metrics = [1.5 * (1.025**i) for i in range(20)]  # 1.025 > 1.1**0.25
629        scheduler = ReduceLROnPlateau(
630            self.opt, mode="max", patience=3, threshold_mode="rel", threshold=0.1
631        )
632        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
633
634    def test_reduce_lr_on_plateau5(self):
635        epochs = 20
636        for param_group in self.opt.param_groups:
637            param_group["lr"] = 0.5
638        targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
639        metrics = [1.5 * (1.005**i) for i in range(20)]
640        scheduler = ReduceLROnPlateau(
641            self.opt,
642            mode="max",
643            threshold_mode="rel",
644            threshold=0.1,
645            patience=5,
646            cooldown=5,
647        )
648        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
649
650    def test_reduce_lr_on_plateau6(self):
651        epochs = 20
652        for param_group in self.opt.param_groups:
653            param_group["lr"] = 0.5
654        targets = [[0.5] * 20]
655        metrics = [1.5 * (0.85**i) for i in range(20)]
656        scheduler = ReduceLROnPlateau(
657            self.opt, mode="min", threshold_mode="rel", threshold=0.1
658        )
659        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
660
661    def test_reduce_lr_on_plateau7(self):
662        epochs = 20
663        for param_group in self.opt.param_groups:
664            param_group["lr"] = 0.5
665        targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
666        metrics = [1] * 7 + [0.6] + [0.5] * 12
667        scheduler = ReduceLROnPlateau(
668            self.opt,
669            mode="min",
670            threshold_mode="rel",
671            threshold=0.1,
672            patience=5,
673            cooldown=5,
674        )
675        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
676
677    def test_reduce_lr_on_plateau8(self):
678        epochs = 20
679        for param_group in self.opt.param_groups:
680            param_group["lr"] = 0.5
681        targets = [[0.5] * 6 + [0.4] * 14, [0.5] * 6 + [0.3] * 14]
682        metrics = [1.5 * (1.005**i) for i in range(20)]
683        scheduler = ReduceLROnPlateau(
684            self.opt,
685            mode="max",
686            threshold_mode="rel",
687            min_lr=[0.4, 0.3],
688            threshold=0.1,
689            patience=5,
690            cooldown=5,
691        )
692        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
693
694    def test_reduce_lr_on_plateau_get_last_lr_before_step(self):
695        for param_group in self.opt.param_groups:
696            param_group["lr"] = 0.5
697        scheduler = ReduceLROnPlateau(
698            self.opt,
699        )
700        self.assertEqual(
701            scheduler.get_last_lr(), [0.5 for param_group in self.opt.param_groups]
702        )
703
704    def test_sequentiallr1(self):
705        epochs = 19
706        schedulers = [None] * 2
707        targets = [
708            [0.05, 0.04, 0.032]
709            + [0.05 for x in range(4)]
710            + [0.05 * 0.1 for x in range(4)]
711            + [0.05 * 0.01 for x in range(4)]
712            + [0.05 * 0.001 for x in range(4)]
713        ]
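        # SequentialLR is expected to run ExponentialLR(gamma=0.8) for the first 3
        # epochs and then hand over to StepLR at milestone 3, which restarts from the
        # base lr of 0.05 and decays by 0.1 every 4 epochs, matching the targets above.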
714        milestones = [3]
715        schedulers[0] = ExponentialLR(self.opt, gamma=0.8)
716        schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=4)
717        scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
718        self._test(scheduler, targets, epochs)
719
720    def test_sequentiallr2(self):
721        epochs = 13
722        schedulers = [None] * 2
723        targets = [[0.005, 0.005, 0.005] + [0.05 * 0.9**x for x in range(10)]]
724        milestones = [3]
725        schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
726        schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
727        scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
728        self._test(scheduler, targets, epochs)
729
730    def test_sequentiallr3(self):
731        epochs = 12
732        schedulers = [None] * 3
733        targets = [
734            [0.005, 0.005, 0.005]
735            + [0.05, 0.04, 0.032]
736            + [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005]
737        ]
738        milestones = [3, 6]
739        schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
740        schedulers[1] = ExponentialLR(self.opt, gamma=0.8)
741        schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2)
742        scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
743        self._test(scheduler, targets, epochs)
744
745    def test_sequentiallr4(self):
746        optimizer = SGD([torch.tensor(0.5)], lr=0.1)
747        prev_lr = optimizer.param_groups[0]["lr"]
748
749        schedulers = [
750            torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1),
751            torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.1),
752        ]
753        scheduler = torch.optim.lr_scheduler.SequentialLR(
754            optimizer, schedulers, milestones=[10]
755        )
756
757        new_lr = optimizer.param_groups[0]["lr"]
758
759        # Ensure that constructing multiple schedulers does not affect the initial learning rate
760        self.assertEqual(prev_lr, new_lr)
761
762    def test_get_last_lr_sequentiallr(self):
763        epochs = 12
764        milestones = [3, 6]
765        schedulers = [None] * 3
766        schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
767        schedulers[1] = ExponentialLR(self.opt, gamma=0.8)
768        schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2)
769        scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
770        constant_lr_target = [0.005] * 3
771        exponential_lr_target = [0.05, 0.04, 0.032]
772        step_lr_target = [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005]
773        single_targets = constant_lr_target + exponential_lr_target + step_lr_target
774        targets = [single_targets, [x * 10 for x in single_targets]]
775        self._test_get_last_lr(scheduler, targets, epochs)
776
777    def test_chained_lr2_get_last_lr_before_step(self):
778        schedulers = [
779            LinearLR(self.opt, start_factor=0.4, total_iters=3),
780            MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1),
781        ]
782        scheduler = ChainedScheduler(schedulers)
783        self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())
784
785    def test_chained_lr1(self):
786        epochs = 10
787        schedulers = [None] * 1
788        targets = [[0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3]
789        schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
790        scheduler = ChainedScheduler(schedulers)
791        self._test([scheduler], targets, epochs)
792        self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())
793
794    def test_chained_lr2(self):
795        epochs = 10
796        schedulers = [None] * 1
797        targets = [[0.02, 0.03, 0.04] + [0.05] * 9]
798        schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3)
799        scheduler = ChainedScheduler(schedulers)
800        self._test([scheduler], targets, epochs)
801        self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())
802
803    def test_chained_lr3(self):
804        epochs = 10
805        schedulers = [None] * 2
806        targets = [
807            [0.02, 0.03, 0.04, 0.05] + [0.005] * 4 + [0.0005] * 3 + [0.00005] * 3
808        ]
809        schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3)
810        schedulers[1] = MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1)
811        scheduler = ChainedScheduler(schedulers)
812        self._test([scheduler], targets, epochs)
813        self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())
814
815    def test_chained_lr4(self):
816        epochs = 9
817        schedulers = [None] * 3
818        targets = [
819            [0.05 * 0.2 * 0.9**x for x in range(3)]
820            + [0.05 * 0.2 * 0.9**3 * 0.1]
821            + [0.05 * 0.9**x * 0.1 for x in range(4, 6)]
822            + [0.05 * 0.9**x * 0.01 for x in range(6, 9)]
823        ]
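        # ChainedScheduler steps every wrapped scheduler each epoch, so the lr is
        # roughly the product of the individual multiplicative factors
        # (exponential decay * constant factor * step decay in this test).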
824        schedulers[0] = ExponentialLR(self.opt, gamma=0.9)
825        schedulers[1] = ConstantLR(self.opt, factor=0.2, total_iters=4)
826        schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=3)
827        scheduler = ChainedScheduler(schedulers)
828        self._test([scheduler], targets, epochs)
829        self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())
830
831    def test_chained_lr5(self):
832        def poly_lr(lr: float):
833            return [
834                (lr * ((1.0 - x / total_iters) ** power)) for x in range(total_iters)
835            ] + [0.0] * (epochs - total_iters)
836
837        schedulers = [None] * 2
838        epochs = 10
839        power = 0.9
840        total_iters = 5
841        const_factor = 0.1
842        single_targets = [x * const_factor for x in poly_lr(lr=0.05)]
843        targets = [single_targets, [x * const_factor for x in poly_lr(0.5)]]
844        schedulers[0] = PolynomialLR(self.opt, power=power, total_iters=total_iters)
845        schedulers[1] = ConstantLR(self.opt, factor=const_factor)
846        scheduler = ChainedScheduler(schedulers)
847        self._test(scheduler, targets, epochs)
848        self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())
849
850    def test_compound_step_and_multistep_lr(self):
851        epochs = 10
852        schedulers = [None] * 2
853        schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
854        schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
855        targets = [[0.05] * 2 + [0.005] * 1 + [5e-4] * 2 + [5e-5] + [5e-6] * 3 + [5e-8]]
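        # Stepping two schedulers on the same optimizer composes their effects: each
        # epoch the lr is scaled by StepLR's factor and then by MultiStepLR's, which
        # is what the hand-computed targets above encode.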
856        self._test(schedulers, targets, epochs)
857
858    def test_compound_step_and_exp_lr(self):
859        epochs = 10
860        schedulers = [None] * 2
861        single_targets = [0.05 * (0.9**x) for x in range(3)]
862        single_targets += [0.005 * (0.9**x) for x in range(3, 6)]
863        single_targets += [0.0005 * (0.9**x) for x in range(6, 9)]
864        single_targets += [0.00005 * (0.9**x) for x in range(9, 12)]
865        targets = [single_targets, [x * epochs for x in single_targets]]
866        schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
867        schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
868        self._test(schedulers, targets, epochs)
869
870    def test_compound_exp_and_multistep_lr(self):
871        epochs = 10
872        schedulers = [None] * 2
873        single_targets = [0.05 * (0.9**x) for x in range(2)]
874        single_targets += [0.005 * (0.9**x) for x in range(2, 5)]
875        single_targets += [0.0005 * (0.9**x) for x in range(5, 9)]
876        single_targets += [0.00005 * (0.9**x) for x in range(9, 11)]
877        targets = [single_targets, [x * epochs for x in single_targets]]
878        schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
879        schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
880        self._test(schedulers, targets, epochs)
881
882    def test_compound_exp_and_linearlr(self):
883        epochs = 10
884        iters = 4
885        start_factor = 0.4
886        end_factor = 0.9
887        schedulers = [None] * 2
888        single_targets = [0.05 * (0.9**x) for x in range(11)]
889        for i in range(iters):
890            single_targets[i] *= start_factor + i / iters * (end_factor - start_factor)
891        for i in range(iters, 11):
892            single_targets[i] *= end_factor
893        targets = [single_targets, [x * epochs for x in single_targets]]
894        schedulers[0] = LinearLR(
895            self.opt,
896            start_factor=start_factor,
897            end_factor=end_factor,
898            total_iters=iters,
899        )
900        schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
901        self._test(schedulers, targets, epochs)
902
903    def test_compound_step_and_constantlr(self):
904        epochs = 10
905        iters = 4
906        factor = 0.4
907        schedulers = [None] * 2
908        single_targets = (
909            [0.05 * 0.4] * 3
910            + [0.005 * 0.4]
911            + [0.005] * 2
912            + [0.0005] * 3
913            + [0.00005] * 3
914        )
915        targets = [single_targets, [x * epochs for x in single_targets]]
916        schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
917        schedulers[1] = ConstantLR(self.opt, factor=0.4, total_iters=4)
918        self._test(schedulers, targets, epochs)
919
920    def test_compound_linearlr_and_multistep_lr(self):
921        epochs = 10
922        iters = 4
923        start_factor = 0.4
924        schedulers = [None] * 2
925        single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 2
926        for i in range(iters):
927            single_targets[i] *= start_factor + i / iters * (1 - start_factor)
928        targets = [single_targets, [x * epochs for x in single_targets]]
929        schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
930        schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
931        self._test(schedulers, targets, epochs)
932
933    def test_compound_cosanneal_and_step_lr(self):
934        epochs = 10
935        eta_min = 1e-10
936        single_targets = [
937            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
938            for x in range(epochs)
939        ]
940        single_targets = [x * 0.1 ** (i // 3) for i, x in enumerate(single_targets)]
941        targets = [single_targets, [x * epochs for x in single_targets]]
942        schedulers = [None] * 2
943        schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
944        schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3)
945        self._test(schedulers, targets, epochs)
946
947    def test_compound_cosanneal_and_multistep_lr(self):
948        epochs = 10
949        eta_min = 1e-10
950        single_targets = [
951            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
952            for x in range(epochs)
953        ]
954        multipliers = [1] * 2 + [0.1] * 3 + [0.01] * 4 + [0.001]
955        single_targets = [x * y for x, y in zip(single_targets, multipliers)]
956        targets = [single_targets, [x * epochs for x in single_targets]]
957        schedulers = [None] * 2
958        schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
959        schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
960        self._test(schedulers, targets, epochs)
961
962    def test_compound_cosanneal_and_linearlr(self):
963        epochs = 10
964        iters = 4
965        start_factor = 0.4
966        eta_min = 1e-10
967        schedulers = [None] * 2
968        single_targets = [
969            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
970            for x in range(epochs)
971        ]
972        for i in range(iters):
973            single_targets[i] *= start_factor + i / iters * (1 - start_factor)
974        targets = [single_targets, [x * epochs for x in single_targets]]
975        schedulers[0] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
976        schedulers[1] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
977        self._test(schedulers, targets, epochs)
978
979    def test_compound_cosanneal_and_exp_lr(self):
980        epochs = 10
981        eta_min = 1e-10
982        single_targets = [
983            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
984            for x in range(epochs)
985        ]
986        multipliers = [0.1**i for i in range(epochs)]
987        single_targets = [x * y for x, y in zip(single_targets, multipliers)]
988        targets = [single_targets, [x * epochs for x in single_targets]]
989        schedulers = [None] * 2
990        schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
991        schedulers[1] = ExponentialLR(self.opt, gamma=0.1)
992        self._test(schedulers, targets, epochs)
993
994    def test_compound_reduce_lr_on_plateau1(self):
995        epochs = 10
996        for param_group in self.opt.param_groups:
997            param_group["lr"] = 0.5
998        single_targets = [0.5] * 20
999        multipliers = [0.1 ** (i // 3) for i in range(20)]
1000        single_targets = [x * y for x, y in zip(multipliers, single_targets)]
1001        targets = [single_targets]
1002        targets = targets[1:]  # test runs step before checking lr
1003        metrics = [10 - i * 0.0167 for i in range(20)]
1004        schedulers = [None, None]
1005        schedulers[0] = ReduceLROnPlateau(
1006            self.opt,
1007            threshold_mode="abs",
1008            mode="min",
1009            threshold=0.01,
1010            patience=5,
1011            cooldown=5,
1012        )
1013        schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3)
1014        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
1015
1016    def test_compound_reduce_lr_on_plateau2(self):
1017        epochs = 22
1018        for param_group in self.opt.param_groups:
1019            param_group["lr"] = 0.5
1020        single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2
1021        multipliers = [1] * 3 + [0.1] * 5 + [0.01] * 4 + [0.001] * 10
1022        single_targets = [x * y for x, y in zip(single_targets, multipliers)]
1023        targets = [single_targets]
1024        targets = targets[1:]  # test runs step before checking lr
1025        metrics = [10 - i * 0.0165 for i in range(22)]
1026        schedulers = [None] * 2
1027        schedulers[0] = ReduceLROnPlateau(
1028            self.opt,
1029            patience=5,
1030            cooldown=0,
1031            threshold_mode="abs",
1032            mode="min",
1033            threshold=0.1,
1034        )
1035        schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[3, 8, 12])
1036        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
1037
1038    def test_compound_reduce_lr_on_plateau3(self):
1039        epochs = 22
1040        for param_group in self.opt.param_groups:
1041            param_group["lr"] = 0.5
1042        single_targets = [0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4
1043        multipliers = [0.1**i for i in range(epochs)]
1044        single_targets = [x * y for x, y in zip(multipliers, single_targets)]
1045        targets = [single_targets]
1046        targets = targets[1:]  # test runs step before checking lr
1047        metrics = [-0.8] * 2 + [-0.234] * 20
1048        schedulers = [None, None]
1049        schedulers[0] = ReduceLROnPlateau(
1050            self.opt, mode="max", patience=5, cooldown=5, threshold_mode="abs"
1051        )
1052        schedulers[1] = ExponentialLR(self.opt, gamma=0.1)
1053        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
1054
1055    def test_compound_reduce_lr_on_plateau4(self):
1056        epochs = 20
1057        for param_group in self.opt.param_groups:
1058            param_group["lr"] = 0.05
1059        epochs = 10
1060        eta_min = 1e-10
1061        single_targets = [
1062            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
1063            for x in range(epochs)
1064        ]
1065        targets = [single_targets]
1066        targets = targets[1:]  # test runs step before checking lr
1067        metrics = [1.5 * (1.025**i) for i in range(20)]  # 1.025 > 1.1**0.25
1068        schedulers = [None, None]
1069        schedulers[0] = ReduceLROnPlateau(
1070            self.opt, mode="max", patience=3, threshold_mode="rel", threshold=0.1
1071        )
1072        schedulers[1] = CosineAnnealingLR(self.opt, epochs, eta_min)
1073        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
1074
1075    def test_compound_reduce_lr_on_plateau5(self):
1076        iters = 4
1077        start_factor = 0.4
1078        epochs = 22
1079        for param_group in self.opt.param_groups:
1080            param_group["lr"] = 0.5
1081        single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2
1082        multipliers = [1] * 22
1083        for i in range(iters):
1084            multipliers[i] *= start_factor + i / iters * (1 - start_factor)
1085        single_targets = [x * y for x, y in zip(single_targets, multipliers)]
1086        targets = [single_targets]
1087        targets = targets[1:]  # test runs step before checking lr
1088        metrics = [10 - i * 0.0165 for i in range(22)]
1089        schedulers = [None] * 2
1090        schedulers[0] = ReduceLROnPlateau(
1091            self.opt,
1092            patience=5,
1093            cooldown=0,
1094            threshold_mode="abs",
1095            mode="min",
1096            threshold=0.1,
1097        )
1098        schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
1099        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
1100
1101    def test_cycle_lr_invalid_mode(self):
1102        with self.assertRaises(ValueError):
1103            scheduler = CyclicLR(self.opt, base_lr=0, max_lr=0, mode="CATS")
1104
1105    def test_cycle_lr_triangular_mode_one_lr(self):
1106        lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
1107        momentum_target = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3]
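        # In "triangular" mode CyclicLR is expected to ramp the lr linearly from
        # base_lr to max_lr over step_size_up steps and back down symmetrically,
        # while the momentum cycles in the opposite direction (max -> base -> max).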
1108        lr_targets = [lr_target, lr_target]
1109        momentum_targets = [momentum_target, momentum_target]
1110        scheduler = CyclicLR(
1111            self.opt,
1112            base_lr=1,
1113            max_lr=5,
1114            step_size_up=4,
1115            cycle_momentum=True,
1116            base_momentum=1,
1117            max_momentum=5,
1118            mode="triangular",
1119        )
1120        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
1121
1122    def test_cycle_lr_triangular_mode_one_lr_no_momentum(self):
1123        lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
1124        lr_targets = [lr_target, lr_target]
1125        momentum_target = [self.opt.defaults["momentum"]] * len(lr_target)
1126        momentum_targets = [momentum_target, momentum_target]
1127        scheduler = CyclicLR(
1128            self.opt,
1129            base_lr=1,
1130            max_lr=5,
1131            step_size_up=4,
1132            cycle_momentum=False,
1133            mode="triangular",
1134        )
1135        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
1136
1137    def test_cycle_lr_triangular2_mode_one_lr(self):
1138        lr_target = [
1139            1,
1140            2,
1141            3,
1142            4,
1143            5,
1144            4,
1145            3,
1146            2,
1147            1,
1148            1.5,
1149            2.0,
1150            2.5,
1151            3.0,
1152            2.5,
1153            2.0,
1154            1.5,
1155            1,
1156            1.25,
1157            1.50,
1158            1.75,
1159            2.00,
1160            1.75,
1161        ]
1162        momentum_target = [
1163            5.0,
1164            4.0,
1165            3.0,
1166            2.0,
1167            1.0,
1168            2.0,
1169            3.0,
1170            4.0,
1171            5.0,
1172            4.5,
1173            4.0,
1174            3.5,
1175            3.0,
1176            3.5,
1177            4.0,
1178            4.5,
1179            5.0,
1180            4.75,
1181            4.5,
1182            4.25,
1183            4.0,
1184            4.25,
1185        ]
1186        lr_targets = [lr_target, lr_target]
1187        momentum_targets = [momentum_target, momentum_target]
1188        scheduler = CyclicLR(
1189            self.opt,
1190            base_lr=1,
1191            max_lr=5,
1192            step_size_up=4,
1193            cycle_momentum=True,
1194            base_momentum=1,
1195            max_momentum=5,
1196            mode="triangular2",
1197        )
1198        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
1199
1200    def test_cycle_lr_exp_range_mode_one_lr(self):
1201        base_lr, max_lr = 1, 5
1202        diff_lr = max_lr - base_lr
1203        gamma = 0.9
1204        xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1]
1205        lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)]
1206        momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)]
1207        lr_targets = [lr_target, lr_target]
1208        momentum_targets = [momentum_target, momentum_target]
1209        scheduler = CyclicLR(
1210            self.opt,
1211            base_lr=base_lr,
1212            max_lr=max_lr,
1213            step_size_up=4,
1214            cycle_momentum=True,
1215            base_momentum=base_lr,
1216            max_momentum=max_lr,
1217            mode="exp_range",
1218            gamma=gamma,
1219        )
1220        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
1221
1222    def test_cycle_lr_triangular_mode(self):
1223        lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
1224        lr_target_2 = [x + 1 for x in lr_target_1]
1225        lr_targets = [lr_target_1, lr_target_2]
1226        momentum_target_1 = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3]
1227        momentum_target_2 = [x + 1 for x in momentum_target_1]
1228        momentum_targets = [momentum_target_1, momentum_target_2]
1229        scheduler = CyclicLR(
1230            self.opt,
1231            base_lr=[1, 2],
1232            max_lr=[5, 6],
1233            step_size_up=4,
1234            cycle_momentum=True,
1235            base_momentum=[1, 2],
1236            max_momentum=[5, 6],
1237            mode="triangular",
1238        )
1239        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))
1240
1241    def test_cycle_lr_triangular2_mode(self):
1242        lr_target_1 = [
1243            1,
1244            2,
1245            3,
1246            4,
1247            5,
1248            4,
1249            3,
1250            2,
1251            1,
1252            1.5,
1253            2.0,
1254            2.5,
1255            3.0,
1256            2.5,
1257            2.0,
1258            1.5,
1259            1,
1260            1.25,
1261            1.50,
1262            1.75,
1263            2.00,
1264            1.75,
1265        ]
1266        lr_target_2 = [x + 2 for x in lr_target_1]
1267        lr_targets = [lr_target_1, lr_target_2]
1268        momentum_target_1 = [
1269            5.0,
1270            4.0,
1271            3.0,
1272            2.0,
1273            1.0,
1274            2.0,
1275            3.0,
1276            4.0,
1277            5.0,
1278            4.5,
1279            4.0,
1280            3.5,
1281            3.0,
1282            3.5,
1283            4.0,
1284            4.5,
1285            5.0,
1286            4.75,
1287            4.5,
1288            4.25,
1289            4.0,
1290            4.25,
1291        ]
1292        momentum_target_2 = [x + 2 for x in momentum_target_1]
1293        momentum_targets = [momentum_target_1, momentum_target_2]
1294        scheduler = CyclicLR(
1295            self.opt,
1296            base_lr=[1, 3],
1297            max_lr=[5, 7],
1298            step_size_up=4,
1299            cycle_momentum=True,
1300            base_momentum=[1, 3],
1301            max_momentum=[5, 7],
1302            mode="triangular2",
1303        )
1304        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))
1305
1306    def test_cycle_lr_exp_range_mode(self):
1307        base_lr_1, max_lr_1 = 1, 5
1308        base_lr_2, max_lr_2 = 5, 12
1309
1310        diff_lr_1 = max_lr_1 - base_lr_1
1311        diff_lr_2 = max_lr_2 - base_lr_2
1312
1313        gamma = 0.9
1314        xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1]
1315        lr_target_1 = [base_lr_1 + x * diff_lr_1 * gamma**i for i, x in enumerate(xs)]
1316        lr_target_2 = [base_lr_2 + x * diff_lr_2 * gamma**i for i, x in enumerate(xs)]
1317        lr_targets = [lr_target_1, lr_target_2]
1318        momentum_target_1 = [
1319            max_lr_1 - x * diff_lr_1 * gamma**i for i, x in enumerate(xs)
1320        ]
1321        momentum_target_2 = [
1322            max_lr_2 - x * diff_lr_2 * gamma**i for i, x in enumerate(xs)
1323        ]
1324        momentum_targets = [momentum_target_1, momentum_target_2]
1325        scheduler = CyclicLR(
1326            self.opt,
1327            base_lr=[base_lr_1, base_lr_2],
1328            max_lr=[max_lr_1, max_lr_2],
1329            step_size_up=4,
1330            cycle_momentum=True,
1331            base_momentum=[base_lr_1, base_lr_2],
1332            max_momentum=[max_lr_1, max_lr_2],
1333            mode="exp_range",
1334            gamma=gamma,
1335        )
1336        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))
1337
1338    def test_cycle_lr_triangular_mode_step_size_up_down(self):
1339        lr_target = [
1340            1.0,
1341            2.0,
1342            3.0,
1343            4.0,
1344            5.0,
1345            13.0 / 3,
1346            11.0 / 3,
1347            9.0 / 3,
1348            7.0 / 3,
1349            5.0 / 3,
1350            1.0,
1351        ]
1352        lr_targets = [lr_target, lr_target]
1353        momentum_target = [
1354            5.0,
1355            4.0,
1356            3.0,
1357            2.0,
1358            1.0,
1359            5.0 / 3,
1360            7.0 / 3,
1361            3.0,
1362            11.0 / 3,
1363            13.0 / 3,
1364            5.0,
1365        ]
1366        momentum_targets = [momentum_target, momentum_target]
1367
1368        scheduler = CyclicLR(
1369            self.opt,
1370            base_lr=1,
1371            max_lr=5,
1372            step_size_up=4,
1373            step_size_down=6,
1374            cycle_momentum=True,
1375            base_momentum=1,
1376            max_momentum=5,
1377            mode="triangular",
1378        )
1379        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
1380
1381    def test_cycle_lr_triangular2_mode_step_size_up_down(self):
1382        lr_base_target = [
1383            1.0,
1384            3.0,
1385            5.0,
1386            13.0 / 3,
1387            11.0 / 3,
1388            9.0 / 3,
1389            7.0 / 3,
1390            5.0 / 3,
1391            1.0,
1392            2.0,
1393            3.0,
1394            8.0 / 3,
1395            7.0 / 3,
1396            6.0 / 3,
1397            5.0 / 3,
1398            4.0 / 3,
1399            1.0,
1400            3.0 / 2,
1401            2.0,
1402            11.0 / 6,
1403            10.0 / 6,
1404            9.0 / 6,
1405            8.0 / 6,
1406            7.0 / 6,
1407        ]
1408        momentum_base_target = [
1409            5.0,
1410            3.0,
1411            1.0,
1412            5.0 / 3,
1413            7.0 / 3,
1414            3.0,
1415            11.0 / 3,
1416            13.0 / 3,
1417            5.0,
1418            4.0,
1419            3.0,
1420            10.0 / 3,
1421            11.0 / 3,
1422            4.0,
1423            13.0 / 3,
1424            14.0 / 3,
1425            5.0,
1426            4.5,
1427            4.0,
1428            25.0 / 6,
1429            13.0 / 3,
1430            4.5,
1431            14.0 / 3,
1432            29.0 / 6,
1433        ]
1434        deltas = [2 * i for i in range(0, 2)]
1435        base_lrs = [1 + delta for delta in deltas]
1436        max_lrs = [5 + delta for delta in deltas]
1437        lr_targets = [[x + delta for x in lr_base_target] for delta in deltas]
1438        momentum_targets = [
1439            [x + delta for x in momentum_base_target] for delta in deltas
1440        ]
1441        scheduler = CyclicLR(
1442            self.opt,
1443            base_lr=base_lrs,
1444            max_lr=max_lrs,
1445            step_size_up=2,
1446            step_size_down=6,
1447            cycle_momentum=True,
1448            base_momentum=base_lrs,
1449            max_momentum=max_lrs,
1450            mode="triangular2",
1451        )
1452        self._test_cycle_lr(
1453            scheduler, lr_targets, momentum_targets, len(lr_base_target)
1454        )
1455
1456    def test_cycle_lr_exp_range_mode_step_size_up_down(self):
1457        base_lr, max_lr = 1, 5
1458        diff_lr = max_lr - base_lr
1459        gamma = 0.9
1460        xs = [
1461            0.0,
1462            0.5,
1463            1.0,
1464            5.0 / 6,
1465            4.0 / 6,
1466            3.0 / 6,
1467            2.0 / 6,
1468            1.0 / 6,
1469            0.0,
1470            0.5,
1471            1.0,
1472            5.0 / 6,
1473            4.0 / 6,
1474        ]
1475        lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)]
1476        lr_targets = [lr_target, lr_target]
1477        momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)]
1478        momentum_targets = [momentum_target, momentum_target]
1479        scheduler = CyclicLR(
1480            self.opt,
1481            base_lr=base_lr,
1482            max_lr=max_lr,
1483            step_size_up=2,
1484            step_size_down=6,
1485            cycle_momentum=True,
1486            base_momentum=base_lr,
1487            max_momentum=max_lr,
1488            mode="exp_range",
1489            gamma=gamma,
1490        )
1491        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
1492
1493    def test_cycle_lr_with_momentumless_optimizer(self):
1494        # Note [Temporarily set optimizer to Adam]
1495        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1496        # The TestLRScheduler object carries around an SGD optimizer to avoid having to
1497        # instantiate one for every test. That gets in the way of this very specific case,
1498        # which needs Adam (or any optimizer without a `momentum` hyperparameter)
1499        # in order to test that the momentum bug in CyclicLR is fixed (the bug is described
1500        # in more detail in https://github.com/pytorch/pytorch/issues/19003 ).
1501        old_opt = self.opt
1502        self.opt = Adam(
1503            [
1504                {"params": self.net.conv1.parameters()},
1505                {"params": self.net.conv2.parameters(), "lr": 0.5},
1506            ],
1507            lr=0.05,
1508        )
1509
1510        lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
1511        lr_targets = [lr_target, lr_target]
1512        momentum_target = [None] * len(lr_target)
1513        momentum_targets = [momentum_target, momentum_target]
1514        scheduler = CyclicLR(
1515            self.opt,
1516            base_lr=1,
1517            max_lr=5,
1518            step_size_up=4,
1519            cycle_momentum=False,
1520            mode="triangular",
1521        )
1522        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
1523
1524        self.opt = old_opt  # set optimizer back to SGD
1525
1526    def test_cycle_lr_cycle_momentum_fail_with_momentumless_optimizer(self):
1527        with self.assertRaises(ValueError):
1528            rprop_opt = Rprop(self.net.parameters())
1529            scheduler = CyclicLR(rprop_opt, base_lr=1, max_lr=5, cycle_momentum=True)
1530
1531    def test_cycle_lr_cycle_momentum_with_beta1_optimizer(self):
1532        adam_opt = Adam(self.net.parameters())
1533        scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True)
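        # CyclicLR appears to fall back to cycling beta1 (betas[0]) when the optimizer
        # defines `betas` instead of a `momentum` hyperparameter, which is why Adam is
        # accepted here while Rprop in the test above raises a ValueError.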
1534
1535    def test_cycle_lr_removed_after_out_of_scope(self):
1536        import gc
1537        import weakref
1538
1539        gc.disable()
1540
1541        def test():
1542            adam_opt = Adam(self.net.parameters())
1543            scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False)
1544            return weakref.ref(scheduler)
1545
1546        ref = test()
1547        self.assertIsNone(ref())
1548        gc.enable()
1549
1550    def test_cycle_lr_state_dict_picklable(self):
1551        adam_opt = Adam(self.net.parameters())
1552
1553        # Case 1: Built-in mode
1554        scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False)
1555        self.assertIsInstance(scheduler._scale_fn_ref, types.FunctionType)
1556        state = scheduler.state_dict()
1557        self.assertNotIn("_scale_fn_ref", state)
1558        self.assertIs(state["_scale_fn_custom"], None)
1559        pickle.dumps(state)
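        # In built-in modes the scale function lives in `_scale_fn_ref` and is deliberately
        # dropped from the state dict so that the state stays picklable. A custom scale_fn
        # is serialized only when it is a callable object (its __dict__, see Case 3 below);
        # plain functions are skipped (Case 2), mirroring how LambdaLR handles lr_lambdas.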
1560
1561        # Case 2: Custom `scale_fn`, a function object
1562        def scale_fn(_):
1563            return 0.5
1564
1565        scheduler = CyclicLR(
1566            adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn
1567        )
1568        state = scheduler.state_dict()
1569        self.assertNotIn("_scale_fn_ref", state)
1570        self.assertIs(state["_scale_fn_custom"], None)
1571        pickle.dumps(state)
1572
1573        # Case 3: Custom `scale_fn`, a callable class
1574        class ScaleFn:
1575            def __init__(self) -> None:
1576                self.x = 0.5
1577
1578            def __call__(self, _):
1579                return self.x
1580
1581        scale_fn = ScaleFn()
1582
1583        scheduler = CyclicLR(
1584            adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn
1585        )
1586        state = scheduler.state_dict()
1587        self.assertNotIn("_scale_fn_ref", state)
1588        self.assertEqual(state["_scale_fn_custom"], scale_fn.__dict__)
1589        pickle.dumps(state)
1590
1591    def test_cycle_lr_scale_fn_restored_from_state_dict(self):
1592        adam_opt = Adam(self.net.parameters())
1593
1594        # Case 1: Built-in mode
1595        scheduler = CyclicLR(
1596            adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, mode="triangular2"
1597        )
1598        restored_scheduler = CyclicLR(
1599            adam_opt, base_lr=1, max_lr=5, cycle_momentum=False
1600        )
1601        restored_scheduler.load_state_dict(scheduler.state_dict())
1602        self.assertTrue(restored_scheduler.mode == scheduler.mode == "triangular2")
1603        self.assertIsNotNone(restored_scheduler._scale_fn_ref)
1604        self.assertIsNotNone(scheduler._scale_fn_ref)
1606        self.assertIs(restored_scheduler._scale_fn_custom, None)
1607        self.assertIs(scheduler._scale_fn_custom, None)
1608
1609        # Case 2: Custom `scale_fn`
1610        def scale_fn(_):
1611            return 0.5
1612
1613        scheduler = CyclicLR(
1614            adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn
1615        )
1616        restored_scheduler = CyclicLR(
1617            adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn
1618        )
1619        restored_scheduler.load_state_dict(scheduler.state_dict())
1620        self.assertIs(scheduler._scale_fn_custom, scale_fn)
1621        self.assertIs(restored_scheduler._scale_fn_custom, scale_fn)
1622
1623    def test_onecycle_lr_invalid_anneal_strategy(self):
1624        with self.assertRaises(ValueError):
1625            scheduler = OneCycleLR(
1626                self.opt, max_lr=1e-3, total_steps=10, anneal_strategy="CATS"
1627            )
1628
1629    def test_onecycle_lr_invalid_pct_start(self):
1630        with self.assertRaises(ValueError):
1631            scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, pct_start=1.1)
1632
1633    def test_onecycle_lr_cannot_calculate_total_steps(self):
1634        with self.assertRaises(ValueError):
1635            scheduler = OneCycleLR(self.opt, max_lr=1e-3)
1636
1637    def test_onecycle_lr_linear_annealing(self):
1638        lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
1639        momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22]
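        # With max_lr=25 and the default div_factor=25, the initial lr is 25 / 25 = 1;
        # final_div_factor=2 then gives a floor of 1 / 2 = 0.5. With the default
        # pct_start=0.3 and total_steps=10, the first three points ramp up (1, 13, 25)
        # before a linear anneal down to 0.5, while momentum mirrors the lr between 22 and 1.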
1640        lr_targets = [lr_target, lr_target]
1641        momentum_targets = [momentum_target, momentum_target]
1642        scheduler = OneCycleLR(
1643            self.opt,
1644            max_lr=25,
1645            final_div_factor=2,
1646            base_momentum=1,
1647            max_momentum=22,
1648            total_steps=10,
1649            anneal_strategy="linear",
1650        )
1651        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
1652
1653    def test_onecycle_lr_linear_annealing_three_phases(self):
1654        lr_target = [1, 9, 17, 25, 17, 9, 1, 0.75, 0.5, 0.25]
1655        momentum_target = [22, 15, 8, 1, 8, 15, 22, 22, 22, 22]
1656        lr_targets = [lr_target, lr_target]
1657        momentum_targets = [momentum_target, momentum_target]
1658        scheduler = OneCycleLR(
1659            self.opt,
1660            max_lr=25,
1661            div_factor=25,
1662            base_momentum=1,
1663            max_momentum=22,
1664            total_steps=10,
1665            anneal_strategy="linear",
1666            pct_start=0.4,
1667            final_div_factor=4,
1668            three_phase=True,
1669        )
1670        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
1671
1672    def test_onecycle_lr_cosine_annealing(self):
1673        def annealing_cos(start, end, pct):
1674            cos_out = math.cos(math.pi * pct) + 1
1675            return end + (start - end) / 2.0 * cos_out
1676
1677        lr_target = [
1678            1,
1679            13,
1680            25,
1681            annealing_cos(25, 0.5, 1 / 7.0),
1682            annealing_cos(25, 0.5, 2 / 7.0),
1683            annealing_cos(25, 0.5, 3 / 7.0),
1684            annealing_cos(25, 0.5, 4 / 7.0),
1685            annealing_cos(25, 0.5, 5 / 7.0),
1686            annealing_cos(25, 0.5, 6 / 7.0),
1687            0.5,
1688        ]
1689        momentum_target = [
1690            22,
1691            11.5,
1692            1,
1693            annealing_cos(1, 22, 1 / 7.0),
1694            annealing_cos(1, 22, 2 / 7.0),
1695            annealing_cos(1, 22, 3 / 7.0),
1696            annealing_cos(1, 22, 4 / 7.0),
1697            annealing_cos(1, 22, 5 / 7.0),
1698            annealing_cos(1, 22, 6 / 7.0),
1699            22,
1700        ]
1701        lr_targets = [lr_target, lr_target]
1702        momentum_targets = [momentum_target, momentum_target]
1703        scheduler = OneCycleLR(
1704            self.opt,
1705            max_lr=25,
1706            final_div_factor=2,
1707            base_momentum=1,
1708            max_momentum=22,
1709            total_steps=10,
1710        )
1711        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
1712
1713    def test_onecycle_lr_legacy_state_dict(self):
1714        scheduler = OneCycleLR(
1715            self.opt,
1716            max_lr=25,
1717            final_div_factor=2,
1718            base_momentum=1,
1719            max_momentum=22,
1720            total_steps=10,
1721            anneal_strategy="cos",
1722        )
1723        delattr(scheduler, "_anneal_func_type")
1724        state_dict = scheduler.state_dict()
1725        self.assertNotIn("anneal_func_type", state_dict)
1726        state_dict["anneal_func"] = OneCycleLR._annealing_cos
1727        scheduler.load_state_dict(state_dict)
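        # Deleting `_anneal_func_type` and injecting an "anneal_func" entry mimics a
        # checkpoint written by an older OneCycleLR that stored the annealing function
        # itself rather than the `_anneal_func_type` string; loading it should still
        # restore cosine annealing, as verified against the closed-form targets below.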
1728
1729        def annealing_cos(start, end, pct):
1730            cos_out = math.cos(math.pi * pct) + 1
1731            return end + (start - end) / 2.0 * cos_out
1732
1733        lr_target = [
1734            1,
1735            13,
1736            25,
1737            annealing_cos(25, 0.5, 1 / 7.0),
1738            annealing_cos(25, 0.5, 2 / 7.0),
1739            annealing_cos(25, 0.5, 3 / 7.0),
1740            annealing_cos(25, 0.5, 4 / 7.0),
1741            annealing_cos(25, 0.5, 5 / 7.0),
1742            annealing_cos(25, 0.5, 6 / 7.0),
1743            0.5,
1744        ]
1745        momentum_target = [
1746            22,
1747            11.5,
1748            1,
1749            annealing_cos(1, 22, 1 / 7.0),
1750            annealing_cos(1, 22, 2 / 7.0),
1751            annealing_cos(1, 22, 3 / 7.0),
1752            annealing_cos(1, 22, 4 / 7.0),
1753            annealing_cos(1, 22, 5 / 7.0),
1754            annealing_cos(1, 22, 6 / 7.0),
1755            22,
1756        ]
1757        lr_targets = [lr_target, lr_target]
1758        momentum_targets = [momentum_target, momentum_target]
1759        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
1760
1761    def test_cycle_lr_with_adam(self):
1762        old_opt = self.opt
1763        self.opt = Adam(
1764            [
1765                {"params": self.net.conv1.parameters()},
1766                {"params": self.net.conv2.parameters(), "lr": 0.5},
1767            ],
1768            lr=0.05,
1769        )
1770
1771        lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
1772        momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22]
1773        lr_targets = [lr_target, lr_target]
1774        momentum_targets = [momentum_target, momentum_target]
1775        scheduler = OneCycleLR(
1776            self.opt,
1777            max_lr=25,
1778            final_div_factor=2,
1779            base_momentum=1,
1780            max_momentum=22,
1781            total_steps=10,
1782            anneal_strategy="linear",
1783        )
1784        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10, use_beta1=True)
1785        self.opt = old_opt  # set optimizer back to SGD
1786
1787    def test_lambda_lr(self):
1788        epochs = 10
1789        self.opt.param_groups[0]["lr"] = 0.05
1790        self.opt.param_groups[1]["lr"] = 0.4
1791        targets = [
1792            [0.05 * (0.9**x) for x in range(epochs)],
1793            [0.4 * (0.8**x) for x in range(epochs)],
1794        ]
1795        scheduler = LambdaLR(
1796            self.opt, lr_lambda=[lambda x1: 0.9**x1, lambda x2: 0.8**x2]
1797        )
1798        self._test(scheduler, targets, epochs)
1799
1800    def test_multiplicative_lr(self):
1801        epochs = 10
1802        self.opt.param_groups[0]["lr"] = 0.05
1803        self.opt.param_groups[1]["lr"] = 0.4
1804        targets = [
1805            [0.05 * (0.9**x) for x in range(epochs)],
1806            [0.4 * (0.8**x) for x in range(epochs)],
1807        ]
1808        scheduler = MultiplicativeLR(
1809            self.opt, lr_lambda=[lambda x1: 0.9, lambda x2: 0.8]
1810        )
1811        self._test(scheduler, targets, epochs)
1812
1813    @parametrize("T_mult", [1, 2, 4])
1814    def test_CosineAnnealingWarmRestarts_lr1(self, T_mult):
1815        iters = 100
1816        eta_min = 1e-10
1817        T_i = 10
1818        T_cur = 0
1819        targets = [[0.05], [0.5]]
1820        scheduler = CosineAnnealingWarmRestarts(
1821            self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min
1822        )
1823        for _ in range(1, iters, 1):
1824            T_cur += 1
1825            if T_cur >= T_i:
1826                T_cur = T_cur - T_i
1827                T_i = int(T_mult) * T_i
1828            targets[0] += [
1829                eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
1830            ]
1831            targets[1] += [
1832                eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
1833            ]
1834        self._test(scheduler, targets, iters)
1835
1836    def test_CosineAnnealingWarmRestarts_lr2(self):
1837        iters = 30
1838        eta_min = 1e-10
1839        T_mults = [1, 2, 4]
1840        for T_mult in T_mults:
1841            T_i = 10
1842            T_cur = 0
1843            targets = [[0.05], [0.5]]
1844            scheduler = CosineAnnealingWarmRestarts(
1845                self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min
1846            )
1847            for _ in torch.arange(0.1, iters, 0.1):
1848                T_cur = round(T_cur + 0.1, 1)
1849                if T_cur >= T_i:
1850                    T_cur = T_cur - T_i
1851                    T_i = int(T_mult) * T_i
1852                targets[0] += [
1853                    eta_min
1854                    + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
1855                ]
1856                targets[1] += [
1857                    eta_min
1858                    + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
1859                ]
1860            self._test_CosineAnnealingWarmRestarts(scheduler, targets, iters)
1861
1862    def test_CosineAnnealingWarmRestarts_lr3(self):
1863        epochs_for_T_mults = [
1864            [0, 1, 2, 3, 4, 5, 12, 27, 3, 4, 5, 6, 13],
1865            [0, 1, 2, 3, 4, 5, 25, 32, 33, 34, 80, 81, 3],
1866            [0, 0.1, 0.2, 0.3, 1.3, 2.3, 17.5, 18.5, 19.5, 29.5, 30.5, 31.5, 50],
1867        ]
1868        T_curs_for_T_mults = [
1869            [1, 2, 3, 4, 5, 2, 7, 3, 4, 5, 6, 3],
1870            [1, 2, 3, 4, 5, 15, 2, 3, 4, 10, 11, 3],
1871            [0.1, 0.2, 0.3, 1.3, 2.3, 7.5, 8.5, 9.5, 19.5, 20.5, 21.5, 10],
1872        ]
1873        T_is_for_T_mults = [
1874            [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
1875            [10, 10, 10, 10, 10, 20, 40, 40, 40, 80, 80, 10],
1876            [10, 10, 10, 10, 10, 30, 30, 30, 30, 30, 30, 90],
1877        ]
1878        eta_min = 1e-10
1879        T_mults = [1, 2, 3]
1880        for epochs, T_mult, T_curs, T_is in zip(
1881            epochs_for_T_mults, T_mults, T_curs_for_T_mults, T_is_for_T_mults
1882        ):
1883            targets = [[0.05], [0.5]]
1884            scheduler = CosineAnnealingWarmRestarts(
1885                self.opt, T_0=10, T_mult=T_mult, eta_min=eta_min
1886            )
1887            for T_cur, T_i in zip(T_curs, T_is):
1888                targets[0] += [
1889                    eta_min
1890                    + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
1891                ]
1892                targets[1] += [
1893                    eta_min
1894                    + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
1895                ]
1896            self._test_interleaved_CosineAnnealingWarmRestarts(
1897                scheduler, targets, epochs
1898            )
1899
1900    def test_swalr_no_anneal(self):
1901        epochs, swa_start, swa_lr = 10, 5, 0.01
1902        initial_lrs = [group["lr"] for group in self.opt.param_groups]
1903        targets = [
1904            [lr] * (swa_start + 1) + [swa_lr] * (epochs - swa_start - 1)
1905            for lr in initial_lrs
1906        ]
1907        swa_scheduler = SWALR(self.opt, anneal_epochs=1, swa_lr=swa_lr)
1908        self._test_swalr(swa_scheduler, None, targets, swa_start, epochs)
1909
1910    def test_swalr_cosine_anneal_after_multiplicative(self):
1911        # same swa_lr for different param_groups
1912        epochs, swa_start, swa_lr, anneal_epochs = 15, 5, 0.01, 5
1913        mult_factor = 0.9
1914        scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor)
1915        swa_scheduler = SWALR(self.opt, anneal_epochs=anneal_epochs, swa_lr=swa_lr)
1916
1917        def anneal_coef(t):
1918            if t + 1 >= anneal_epochs:
1919                return 0.0
1920            return (1 + math.cos(math.pi * (t + 1) / anneal_epochs)) / 2
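        # Once SWA starts, SWALR anneals each group's lr from its current value towards
        # swa_lr over `anneal_epochs` steps (cosine annealing by default), which is what
        # anneal_coef models when building the expected targets below.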
1921
1922        initial_lrs = [group["lr"] for group in self.opt.param_groups]
1923        targets_before_swa = [
1924            [lr * mult_factor**i for i in range(swa_start + 1)] for lr in initial_lrs
1925        ]
1926        swa_epochs = epochs - swa_start - 1
1927        targets = [
1928            lrs
1929            + [
1930                lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t))
1931                for t in range(swa_epochs)
1932            ]
1933            for lrs in targets_before_swa
1934        ]
1935
1936        self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs)
1937
1938    def test_swalr_linear_anneal_after_multiplicative(self):
1939        # separate swa_lr for different param_groups
1940        epochs, swa_start, swa_lrs, anneal_epochs = 15, 5, [0.01, 0.02], 4
1941        mult_factor = 0.9
1942        scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor)
1943        swa_scheduler = SWALR(
1944            self.opt,
1945            anneal_epochs=anneal_epochs,
1946            anneal_strategy="linear",
1947            swa_lr=swa_lrs,
1948        )
1949
1950        def anneal_coef(t):
1951            if t + 1 >= anneal_epochs:
1952                return 0.0
1953            return 1 - (t + 1) / anneal_epochs
1954
1955        initial_lrs = [group["lr"] for group in self.opt.param_groups]
1956        targets_before_swa = [
1957            [lr * mult_factor**i for i in range(swa_start + 1)] for lr in initial_lrs
1958        ]
1959        swa_epochs = epochs - swa_start - 1
1960        targets = [
1961            lrs
1962            + [
1963                lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t))
1964                for t in range(swa_epochs)
1965            ]
1966            for lrs, swa_lr in zip(targets_before_swa, swa_lrs)
1967        ]
1968
1969        self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs)
1970
1971    def _test_swalr(self, swa_scheduler, scheduler, targets, swa_start, epochs):
1972        for epoch in range(epochs):
1973            for param_group, target in zip(self.opt.param_groups, targets):
1974                self.assertEqual(
1975                    target[epoch],
1976                    param_group["lr"],
1977                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
1978                        epoch, target[epoch], param_group["lr"]
1979                    ),
1980                    atol=1e-5,
1981                    rtol=0,
1982                )
1983            if epoch >= swa_start:
1984                self.opt.step()
1985                swa_scheduler.step()
1986            elif scheduler is not None:
1987                self.opt.step()
1988                scheduler.step()
1989
1990    def test_swalr_hypers(self):
1991        # Test that SWALR raises errors for incorrect hyper-parameters
1992        with self.assertRaisesRegex(ValueError, "anneal_strategy must"):
1993            swa_scheduler = SWALR(self.opt, anneal_strategy="exponential", swa_lr=1.0)
1994
1995        with self.assertRaisesRegex(ValueError, "anneal_epochs must"):
1996            swa_scheduler = SWALR(self.opt, anneal_epochs=-1, swa_lr=1.0)
1997        with self.assertRaisesRegex(ValueError, "anneal_epochs must"):
1998            swa_scheduler = SWALR(self.opt, anneal_epochs=1.7, swa_lr=1.0)
1999        with self.assertRaisesRegex(ValueError, "swa_lr must"):
2000            swa_scheduler = SWALR(self.opt, swa_lr=[1.0, 0.1, 0.01])
2001
2002    def test_step_lr_state_dict(self):
2003        self._check_scheduler_state_dict(
2004            lambda: StepLR(self.opt, gamma=0.1, step_size=3),
2005            lambda: StepLR(self.opt, gamma=0.01 / 2, step_size=1),
2006        )
2007
2008    def test_multi_step_lr_state_dict(self):
2009        self._check_scheduler_state_dict(
2010            lambda: MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]),
2011            lambda: MultiStepLR(self.opt, gamma=0.01, milestones=[1, 4, 6]),
2012        )
2013
2014    def test_exp_step_lr_state_dict(self):
2015        self._check_scheduler_state_dict(
2016            lambda: ExponentialLR(self.opt, gamma=0.1),
2017            lambda: ExponentialLR(self.opt, gamma=0.01),
2018        )
2019
2020    def test_cosine_lr_state_dict(self):
2021        epochs = 10
2022        eta_min = 1e-10
2023        self._check_scheduler_state_dict(
2024            lambda: CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min),
2025            lambda: CosineAnnealingLR(self.opt, T_max=epochs // 2, eta_min=eta_min / 2),
2026            epochs=epochs,
2027        )
2028
2029    def test_reduce_lr_on_plateau_state_dict(self):
2030        scheduler = ReduceLROnPlateau(self.opt, mode="min", factor=0.1, patience=2)
2031        for score in [1.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 3.0, 2.0, 1.0]:
2032            scheduler.step(score)
2033        scheduler_copy = ReduceLROnPlateau(
2034            self.opt, mode="max", factor=0.5, patience=10
2035        )
2036        scheduler_copy.load_state_dict(scheduler.state_dict())
2037        for key in scheduler.__dict__.keys():
2038            if key not in {"optimizer", "is_better"}:
2039                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])
2040
2041    def test_lambda_lr_state_dict_fn(self):
2042        scheduler = LambdaLR(self.opt, lr_lambda=lambda x: x)
2043        state = scheduler.state_dict()
2044        self.assertIsNone(state["lr_lambdas"][0])
2045
2046        scheduler_copy = LambdaLR(self.opt, lr_lambda=lambda x: x)
2047        scheduler_copy.load_state_dict(state)
2048        for key in scheduler.__dict__.keys():
2049            if key not in {"optimizer", "lr_lambdas"}:
2050                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])
2051
2052    def test_lambda_lr_state_dict_obj(self):
2053        scheduler = LambdaLR(self.opt, lr_lambda=self.LambdaLRTestObject(10))
2054        state = scheduler.state_dict()
2055        self.assertIsNotNone(state["lr_lambdas"][0])
2056
2057        scheduler_copy = LambdaLR(self.opt, lr_lambda=self.LambdaLRTestObject(-1))
2058        scheduler_copy.load_state_dict(state)
2059        for key in scheduler.__dict__.keys():
2060            if key not in {"optimizer"}:
2061                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])
2062
2063    def test_CosineAnnealingWarmRestarts_lr_state_dict(self):
2064        self._check_scheduler_state_dict(
2065            lambda: CosineAnnealingWarmRestarts(self.opt, T_0=10, T_mult=2),
2066            lambda: CosineAnnealingWarmRestarts(self.opt, T_0=100),
2067        )
2068
2069    def test_swa_lr_state_dict(self):
2070        self._check_scheduler_state_dict(
2071            lambda: SWALR(self.opt, anneal_epochs=3, swa_lr=0.5),
2072            lambda: SWALR(
2073                self.opt, anneal_epochs=10, anneal_strategy="linear", swa_lr=5.0
2074            ),
2075        )
2076
2077    def _check_scheduler_state_dict(self, constr, constr2, epochs=10):
2078        scheduler = constr()
2079        for _ in range(epochs):
2080            scheduler.optimizer.step()
2081            scheduler.step()
2082        scheduler_copy = constr2()
2083        scheduler_copy.load_state_dict(scheduler.state_dict())
2084        for key in scheduler.__dict__.keys():
2085            if key != "optimizer":
2086                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])
2087        self.assertEqual(scheduler.get_last_lr(), scheduler_copy.get_last_lr())
2088
2089    def _test_get_last_lr(self, schedulers, targets, epochs=10):
2090        if isinstance(schedulers, LRScheduler):
2091            schedulers = [schedulers]
2092        optimizers = {scheduler.optimizer for scheduler in schedulers}
2093        for epoch in range(epochs):
2094            result = [scheduler.get_last_lr() for scheduler in schedulers]
2095            [optimizer.step() for optimizer in optimizers]
2096            [scheduler.step() for scheduler in schedulers]
2097            target = [[t[epoch] for t in targets]] * len(schedulers)
2098            for t, r in zip(target, result):
2099                self.assertEqual(
2100                    t,
2101                    r,
2102                    msg=f"LR is wrong in epoch {epoch}: expected {t}, got {r}",
2103                    atol=1e-5,
2104                    rtol=0,
2105                )
2106
2107    def _test_with_epoch(self, schedulers, targets, epochs=10):
2108        if isinstance(schedulers, LRScheduler):
2109            schedulers = [schedulers]
2110        optimizers = {scheduler.optimizer for scheduler in schedulers}
2111        for epoch in range(epochs):
2112            [optimizer.step() for optimizer in optimizers]
2113            with warnings.catch_warnings(record=True) as w:
2114                [
2115                    scheduler.step(epoch) for scheduler in schedulers
2116                ]  # step before assert: skip initial lr
2117                self._check_warning_is_epoch_deprecation_warning(
2118                    w, num_warnings=len(schedulers)
2119                )
2120            for param_group, target in zip(self.opt.param_groups, targets):
2121                self.assertEqual(
2122                    target[epoch],
2123                    param_group["lr"],
2124                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
2125                        epoch, target[epoch], param_group["lr"]
2126                    ),
2127                    atol=1e-5,
2128                    rtol=0,
2129                )
2130
2131    def _test(self, schedulers, targets, epochs=10):
2132        if isinstance(schedulers, LRScheduler):
2133            schedulers = [schedulers]
2134        for epoch in range(epochs):
2135            for param_group, target in zip(self.opt.param_groups, targets):
2136                self.assertEqual(
2137                    target[epoch],
2138                    param_group["lr"],
2139                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
2140                        epoch, target[epoch], param_group["lr"]
2141                    ),
2142                    atol=1e-5,
2143                    rtol=0,
2144                )
2145            [scheduler.step() for scheduler in schedulers]
2146
2147    def _test_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs=10):
2148        for index, epoch in enumerate(torch.arange(0, epochs, 0.1)):
2149            epoch = round(epoch.item(), 1)
2150            scheduler.step(epoch)
2151            for param_group, target in zip(self.opt.param_groups, targets):
2152                self.assertEqual(
2153                    target[index],
2154                    param_group["lr"],
2155                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
2156                        epoch, target[index], param_group["lr"]
2157                    ),
2158                    atol=1e-5,
2159                    rtol=0,
2160                )
2161
2162    def _test_interleaved_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs):
2163        for index, epoch in enumerate(epochs):
2164            scheduler.step(epoch)
2165            for param_group, target in zip(self.opt.param_groups, targets):
2166                self.assertEqual(
2167                    target[index],
2168                    param_group["lr"],
2169                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
2170                        epoch, target[index], param_group["lr"]
2171                    ),
2172                    atol=1e-5,
2173                    rtol=0,
2174                )
2175
2176    def _test_against_closed_form(self, scheduler, closed_form_scheduler, epochs=10):
2177        self.setUp()
2178        targets = []
2179        for epoch in range(epochs):
2180            closed_form_scheduler.optimizer.step()
2181            with warnings.catch_warnings(record=True) as w:
2182                closed_form_scheduler.step(epoch)
2183                self._check_warning_is_epoch_deprecation_warning(w)
2184            targets.append([group["lr"] for group in self.opt.param_groups])
2185        self.setUp()
2186        for epoch in range(epochs):
2187            self.opt.step()
2188            scheduler.step()
2189            for i, param_group in enumerate(self.opt.param_groups):
2190                self.assertEqual(
2191                    targets[epoch][i],
2192                    param_group["lr"],
2193                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
2194                        epoch, targets[epoch][i], param_group["lr"]
2195                    ),
2196                    atol=1e-5,
2197                    rtol=0,
2198                )
2199
2200    def _test_reduce_lr_on_plateau(
2201        self, schedulers, targets, metrics, epochs=10, verbose=False
2202    ):
2203        if isinstance(schedulers, (LRScheduler, ReduceLROnPlateau)):
2204            schedulers = [schedulers]
2205        for epoch in range(epochs):
2206            self.opt.step()
2207            for scheduler in schedulers:
2208                if isinstance(scheduler, ReduceLROnPlateau):
2209                    scheduler.step(metrics[epoch])
2210                else:
2211                    scheduler.step()
2212            if verbose:
2213                print("epoch{}:\tlr={}".format(epoch, self.opt.param_groups[0]["lr"]))
2214            for param_group, target in zip(self.opt.param_groups, targets):
2215                self.assertEqual(
2216                    target[epoch],
2217                    param_group["lr"],
2218                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
2219                        epoch, target[epoch], param_group["lr"]
2220                    ),
2221                    atol=1e-5,
2222                    rtol=0,
2223                )
2224
2225    def _test_cycle_lr(
2226        self,
2227        scheduler,
2228        lr_targets,
2229        momentum_targets,
2230        batch_iterations,
2231        verbose=False,
2232        use_beta1=False,
2233    ):
2234        for batch_num in range(batch_iterations):
2235            if verbose:
2236                if "momentum" in self.opt.param_groups[0].keys():
2237                    print(
2238                        "batch{}:\tlr={},momentum={}".format(
2239                            batch_num,
2240                            self.opt.param_groups[0]["lr"],
2241                            self.opt.param_groups[0]["momentum"],
2242                        )
2243                    )
2244                elif use_beta1 and "betas" in self.opt.param_groups[0].keys():
2245                    print(
2246                        "batch{}:\tlr={},beta1={}".format(
2247                            batch_num,
2248                            self.opt.param_groups[0]["lr"],
2249                            self.opt.param_groups[0]["betas"][0],
2250                        )
2251                    )
2252                else:
2253                    print(
2254                        "batch{}:\tlr={}".format(
2255                            batch_num, self.opt.param_groups[0]["lr"]
2256                        )
2257                    )
2258
2259            for param_group, lr_target, momentum_target in zip(
2260                self.opt.param_groups, lr_targets, momentum_targets
2261            ):
2262                self.assertEqual(
2263                    lr_target[batch_num],
2264                    param_group["lr"],
2265                    msg="LR is wrong in batch_num {}: expected {}, got {}".format(
2266                        batch_num, lr_target[batch_num], param_group["lr"]
2267                    ),
2268                    atol=1e-5,
2269                    rtol=0,
2270                )
2271
2272                if use_beta1 and "betas" in param_group.keys():
2273                    self.assertEqual(
2274                        momentum_target[batch_num],
2275                        param_group["betas"][0],
2276                        msg="Beta1 is wrong in batch_num {}: expected {}, got {}".format(
2277                            batch_num,
2278                            momentum_target[batch_num],
2279                            param_group["betas"][0],
2280                        ),
2281                        atol=1e-5,
2282                        rtol=0,
2283                    )
2284                elif "momentum" in param_group.keys():
2285                    self.assertEqual(
2286                        momentum_target[batch_num],
2287                        param_group["momentum"],
2288                        msg="Momentum is wrong in batch_num {}: expected {}, got {}".format(
2289                            batch_num,
2290                            momentum_target[batch_num],
2291                            param_group["momentum"],
2292                        ),
2293                        atol=1e-5,
2294                        rtol=0,
2295                    )
2296            self.opt.step()
2297            scheduler.step()
2298
2299    def test_cosine_then_cyclic(self):
2300        # https://github.com/pytorch/pytorch/issues/21965
2301
2302        max_lr = 0.3
2303        base_lr = 0.1
2304        optim_lr = 0.5
2305
2306        model = torch.nn.Linear(2, 1)
2307        optimizer = SGD(model.parameters(), lr=optim_lr)
2308        lr_scheduler_1 = torch.optim.lr_scheduler.CosineAnnealingLR(
2309            optimizer, T_max=20, eta_min=0.1
2310        )
2311        lr_scheduler_2 = torch.optim.lr_scheduler.CyclicLR(
2312            optimizer, base_lr=base_lr, max_lr=max_lr, step_size_up=1, step_size_down=3
2313        )
2314
2315        for i in range(40):
2316            optimizer.step()
2317            if i <= lr_scheduler_1.T_max:
2318                lr_scheduler_1.step()
2319            else:
2320                lr_scheduler_2.step()
2321            last_lr = optimizer.param_groups[0]["lr"]
2322
2323        self.assertLessEqual(last_lr, max_lr)
2324
2325    @parametrize(
2326        "LRClass",
2327        [
2328            partial(LambdaLR, lr_lambda=lambda e: e // 10),
2329            partial(MultiplicativeLR, lr_lambda=lambda e: 0.95),
2330            partial(StepLR, step_size=30),
2331            partial(MultiStepLR, milestones=[30, 80]),
2332            ConstantLR,
2333            LinearLR,
2334            partial(ExponentialLR, gamma=0.9),
2335            lambda opt, **kwargs: SequentialLR(
2336                opt,
2337                schedulers=[ConstantLR(opt), ConstantLR(opt)],
2338                milestones=[2],
2339                **kwargs,
2340            ),
2341            PolynomialLR,
2342            partial(CosineAnnealingLR, T_max=10),
2343            ReduceLROnPlateau,
2344            partial(CyclicLR, base_lr=0.01, max_lr=0.1),
2345            partial(CosineAnnealingWarmRestarts, T_0=20),
2346            partial(OneCycleLR, max_lr=0.01, total_steps=10),
2347        ],
2348    )
2349    def test_lr_scheduler_verbose_deprecation_warning(self, LRClass):
2350        """Check that passing the verbose parameter raises a deprecation warning."""
2351        with self.assertWarnsOnceRegex(
2352            UserWarning, "The verbose parameter is deprecated"
2353        ):
2354            LRClass(self.opt, verbose=True)
2355
2356        with self.assertWarnsOnceRegex(
2357            UserWarning, "The verbose parameter is deprecated"
2358        ):
2359            LRClass(self.opt, verbose=False)
2360
2361        # No warning is raised when verbose is the default value.
2362        with warnings.catch_warnings():
2363            warnings.simplefilter("error", UserWarning)
2364            LRClass(self.opt)
2365
2366    @parametrize(
2367        "LRClass",
2368        [
2369            partial(LambdaLR, lr_lambda=lambda e: e // 10),
2370            partial(MultiplicativeLR, lr_lambda=lambda e: 0.95),
2371            partial(StepLR, step_size=30),
2372            partial(MultiStepLR, milestones=[30, 80]),
2373            ConstantLR,
2374            LinearLR,
2375            partial(ExponentialLR, gamma=0.9),
2376            PolynomialLR,
2377            partial(CosineAnnealingLR, T_max=10),
2378            lambda opt, **kwargs: ChainedScheduler(
2379                schedulers=[ConstantLR(opt), ConstantLR(opt)], **kwargs
2380            ),
2381            lambda opt, **kwargs: SequentialLR(
2382                opt,
2383                schedulers=[ConstantLR(opt), ConstantLR(opt)],
2384                milestones=[2],
2385                **kwargs,
2386            ),
2387            ReduceLROnPlateau,
2388            partial(CyclicLR, base_lr=0.01, max_lr=0.1),
2389            partial(OneCycleLR, max_lr=0.01, total_steps=10, anneal_strategy="linear"),
2390            partial(CosineAnnealingWarmRestarts, T_0=20),
2391        ],
2392    )
2393    @parametrize("weights_only", [True, False])
2394    def test_lr_scheduler_state_dict_load(self, LRClass, weights_only):
2395        scheduler = LRClass(self.opt)
2396        state_dict = scheduler.state_dict()
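        # weights_only=True exercises torch.load's restricted unpickler, so this also
        # checks that scheduler state dicts round-trip without relying on arbitrary
        # pickled Python objects.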
2397
2398        with tempfile.TemporaryFile() as f:
2399            torch.save(state_dict, f)
2400            f.seek(0)
2401            state_dict_loaded = torch.load(f, weights_only=weights_only)
2402            self.assertEqual(state_dict, state_dict_loaded)
2403            # Make sure state_dict can be loaded
2404            scheduler2 = LRClass(self.opt)
2405            scheduler2.load_state_dict(state_dict_loaded)
2406            self.assertEqual(scheduler2.state_dict(), state_dict)
2407
2408    @parametrize(
2409        "LRClass",
2410        [
2411            partial(LambdaLR, lr_lambda=lambda e: e // 10),
2412            partial(MultiplicativeLR, lr_lambda=lambda e: 0.95),
2413            partial(StepLR, step_size=30),
2414            partial(MultiStepLR, milestones=[30, 80]),
2415            ConstantLR,
2416            LinearLR,
2417            partial(ExponentialLR, gamma=0.9),
2418            PolynomialLR,
2419            partial(CosineAnnealingLR, T_max=10),
2420            partial(CosineAnnealingWarmRestarts, T_0=20),
2421        ],
2422    )
2423    def test_constant_initial_lr(self, LRClass):
2424        # Test that the initial learning rate is constant
2425        lr = torch.as_tensor(0.1)
2426        opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr)
2427        sch = LRClass(opt)
2428
2429        ori_param_groups = copy.deepcopy(opt.param_groups)
2430
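        # `lr` is a tensor, so the in-place multiplies below mutate it directly; the
        # assertions check that the param groups and the scheduler captured the initial
        # value rather than holding a live reference to that tensor.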
2431        for i in range(2):
2432            opt.step()
2433            sch.step(i)
2434            lr.multiply_(0.1)
2435            for group, ori_group in zip(opt.param_groups, ori_param_groups):
2436                self.assertEqual(group["initial_lr"], ori_group["initial_lr"])
2437                self.assertEqual(sch.base_lrs, [0.1])
2438
2439    def test_constant_initial_params_cyclelr(self):
2440        # Test that the initial lr and momentum hyperparameters stay constant
2441        lr = torch.as_tensor(0.1)
2442        max_lr = torch.as_tensor(0.2)
2443        base_momentum = torch.as_tensor(0.8)
2444        max_momentum = torch.as_tensor(0.9)
2445        opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr)
2446        sch = CyclicLR(
2447            opt,
2448            base_lr=lr,
2449            max_lr=max_lr,
2450            base_momentum=base_momentum,
2451            max_momentum=max_momentum,
2452        )
2453        ori_param_groups = copy.deepcopy(opt.param_groups)
2454
2455        for i in range(2):
2456            lr.multiply_(0.5)
2457            max_lr.multiply_(0.5)
2458            base_momentum.multiply_(0.5)
2459            max_momentum.multiply_(0.5)
2460            opt.step()
2461            sch.step(i)
2462            for group, ori_group in zip(opt.param_groups, ori_param_groups):
2463                self.assertEqual(group["initial_lr"], ori_group["initial_lr"])
2464                self.assertEqual(group["max_momentum"], ori_group["max_momentum"])
2465                self.assertEqual(group["base_momentum"], ori_group["base_momentum"])
2466                self.assertEqual(sch.base_lrs, [0.1])
2467                self.assertEqual(sch.max_lrs, [0.2])
2468                self.assertEqual(group["max_momentum"], 0.9)
2469                self.assertEqual(group["base_momentum"], 0.8)
2470
2471    def test_constant_initial_params_onecyclelr(self):
2472        # Test that the initial lr and momentum hyperparameters stay constant
2473        lr = torch.as_tensor(0.1)
2474        base_momentum = torch.as_tensor(0.85)
2475        max_momentum = torch.as_tensor(0.95)
2476        opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr)
2477        sch = OneCycleLR(
2478            opt,
2479            max_lr=lr,
2480            total_steps=10,
2481            base_momentum=base_momentum,
2482            max_momentum=max_momentum,
2483        )
2484        ori_param_groups = copy.deepcopy(opt.param_groups)
2485
2486        for i in range(2):
2487            lr.multiply_(0.5)
2488            base_momentum.multiply_(0.5)
2489            max_momentum.multiply_(0.5)
2490            opt.step()
2491            sch.step(i)
2492
2493            for group, ori_group in zip(opt.param_groups, ori_param_groups):
2494                self.assertEqual(group["initial_lr"], ori_group["initial_lr"])
2495                self.assertEqual(group["max_lr"], ori_group["max_lr"])
2496                self.assertEqual(group["min_lr"], ori_group["min_lr"])
2497                self.assertEqual(group["max_momentum"], ori_group["max_momentum"])
2498                self.assertEqual(group["base_momentum"], ori_group["base_momentum"])
2499                self.assertEqual(group["max_momentum"], 0.95)
2500                self.assertEqual(group["base_momentum"], 0.85)
2501
2502    def test_constant_initial_params_swalr(self):
2503        # Test that the initial lr and swa_lr stay constant
2504        lr = torch.as_tensor(0.1)
2505        swa_lr = torch.as_tensor(0.05)
2506        opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr)
2507        sch = SWALR(opt, swa_lr=swa_lr)
2508        ori_param_groups = copy.deepcopy(opt.param_groups)
2509
2510        for i in range(2):
2511            lr.multiply_(0.5)
2512            swa_lr.multiply_(0.5)
2513            opt.step()
2514            sch.step()
2515            for group, ori_group in zip(opt.param_groups, ori_param_groups):
2516                self.assertEqual(group["initial_lr"], ori_group["initial_lr"])
2517                self.assertEqual(group["swa_lr"], ori_group["swa_lr"])
2518                self.assertEqual(group["swa_lr"], 0.05)
2519                self.assertEqual(sch.base_lrs, [0.1])
2520
2521
2522instantiate_parametrized_tests(TestLRScheduler)
2523
2524
2525if __name__ == "__main__":
2526    print("These tests should be run through test/test_optim.py instead")
2527