1# Owner(s): ["module: optimizer"]
2import functools
3import math
4import tempfile
5import unittest
6from copy import deepcopy
7from typing import Any, Dict, Tuple
8from unittest.mock import patch
9
10from optim.test_lrscheduler import TestLRScheduler  # noqa: F401
11from optim.test_optim import TestDifferentiableOptimizer  # noqa: F401
12from optim.test_swa_utils import TestSWAUtils  # noqa: F401
13
14import torch
15from torch.nn import Parameter
16from torch.optim import Optimizer, SGD
17from torch.optim.lr_scheduler import ReduceLROnPlateau
18from torch.optim.optimizer import (
19    register_optimizer_step_post_hook,
20    register_optimizer_step_pre_hook,
21)
22from torch.testing._internal.common_cuda import TEST_MULTIGPU
23from torch.testing._internal.common_device_type import (
24    instantiate_device_type_tests,
25    largeTensorTest,
26    onlyCPU,
27    onlyCUDA,
28    onlyNativeDeviceTypes,
29    skipMPS,
30    TEST_WITH_ROCM,
31)
32from torch.testing._internal.common_dtype import floating_types_and
33from torch.testing._internal.common_optimizers import (
34    _get_device_type,
35    _get_optim_inputs_including_global_cliquey_kwargs,
36    optim_db,
37    OptimizerErrorEnum,
38    optims,
39    TensorTracker,
40)
41from torch.testing._internal.common_utils import (
42    markDynamoStrictTest,
43    parametrize,
44    run_tests,
45    TEST_WITH_TORCHDYNAMO,
46    TestCase,
47)
48
49
50FP16_REDUCED_PRECISION = {"atol": 1e-5, "rtol": 1e-4}
51
52
53def rosenbrock(tensor):
54    assert tensor.size() == torch.Size(
55        [2]
56    ), f"Requires tensor with 2 scalars but got {tensor.size()}"
57    x, y = tensor
58    return (1 - x) ** 2 + 100 * (y - x**2) ** 2
59
60
61def drosenbrock(tensor):
62    assert tensor.size() == torch.Size(
63        [2]
64    ), f"Requires tensor with 2 scalars but got {tensor.size()}"
65    x, y = tensor
66    return torch.stack((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2)))
67
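# A minimal sanity sketch (illustrative only, not executed by the test suite): the analytic
# gradient above should agree with autograd's gradient of `rosenbrock`, e.g.
#   t = torch.tensor([1.5, 1.5], dtype=torch.float64, requires_grad=True)
#   rosenbrock(t).backward()
#   torch.testing.assert_close(t.grad, drosenbrock(t.detach()))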
68
69@markDynamoStrictTest
70class TestOptimRenewed(TestCase):
71    """
72    This test class validates the core optimizers. It is structured around testing the correctness of:
73    - The update algorithms (forloop implementation)
74        * Every optimizer's algorithm is most readably implemented through a big for-loop
75          over all the parameters, which is what we refer to as the forloop or single tensor
76          implementation. These algorithms are manually validated by comparing to the paper
77          and systematically validated by ensuring that the loss moves in the right direction
78          when the optimizer has been applied.
79        * This implementation should compose with optimizer hyperparameters well, such as
80          supporting Tensor LRs, the capturable API, and sparse and complex parameters.
81    - Each varying implementation
82        * We then have implementations that improve upon the performance of the forloop
83          implementation by leveraging fusion, namely our foreach (multi_tensor) and fused
84          implementations.
85        * These variations are validated numerically by comparing with the forloop version
86          of the optimizer. In fact, we test most variations this way--we see the forloop
87          implementation as the ground truth and expect that improvements to it in any way
88          should be just as correct.
89        * Both params and optimizer states should be validated numerically.
90    - state_dict APIs
91        * The optimizer instance should be serializable
92        * Calling save and load should be deterministic
93        * Moving between devices should be seamless
94        * BC - load_state_dict should be able to handle older optimizer states
95    - Hook APIs (everything should fire in the right order)
96    - LR Scheduler integration (composing should not error + should go the right direction)
97    - Parameter groups (should be equivalent to having multiple optimizers)
98    - Erroring (what should error should error)
99
100    We also cover different ways of generating parameters and grads:
101    - With parameters, we either generate them randomly given specific shapes or we take
102      them from a sample NN module.
103        * Variety is important here because NN module parameters have type Parameter and randomly
104          generated tensors have type Tensor.
105        * Parameters can be sparse for a subset of the optimizers (check out OptimizerInfo)
106        * Complex parameters should be handled using view_as_real
107        * Parameters can be spread across different devices and different dtypes for any
108          given optimizer
109        * Parameters can be contiguous and noncontiguous
110    - With grads, we follow the same strategy as for the parameters.
111        * Grads can also be None, empty, or zero-valued, and this should not disrupt training.
112    """
113
114    @onlyCPU
115    @optims(optim_db)
116    def test_optim_infos_do_not_specify_global_cliquey_kwargs(
117        self, device, dtype, optim_info
118    ):
119        global_cliquey_flags = ["foreach", "fused", "differentiable"]
120        for optim_input in optim_info.optim_inputs_func(device=device):
121            self.assertFalse(
122                any(f for f in global_cliquey_flags if f in optim_input.kwargs)
123            )
124
125    @optims([optim for optim in optim_db if optim.optim_error_inputs_func is not None])
126    def test_errors(self, device, dtype, optim_info):
127        optim_cls = optim_info.optim_cls
128        error_inputs = optim_info.optim_error_inputs_func(device=device, dtype=dtype)
129
130        for error_input in error_inputs:
131            optim_input = error_input.optimizer_error_input
132            params, kwargs = optim_input.params, optim_input.kwargs
133            if error_input.error_on == OptimizerErrorEnum.CONSTRUCTION_ERROR:
134                if issubclass(error_input.error_type, Warning):
135                    with self.assertWarnsRegex(
136                        error_input.error_type, error_input.error_regex
137                    ):
138                        optim_cls(params, **kwargs)
139                else:
140                    with self.assertRaisesRegex(
141                        error_input.error_type, error_input.error_regex
142                    ):
143                        optim_cls(params, **kwargs)
144            elif error_input.error_on == OptimizerErrorEnum.STEP_ERROR:
145                optim = optim_cls(params, **kwargs)
146                if issubclass(error_input.error_type, Warning):
147                    with self.assertWarnsRegex(
148                        error_input.error_type, error_input.error_regex
149                    ):
150                        optim.step()
151                else:
152                    with self.assertRaisesRegex(
153                        error_input.error_type, error_input.error_regex
154                    ):
155                        optim.step()
156            else:
157                raise NotImplementedError(f"Unknown error type {error_input.error_on}")
158
159    @parametrize("contiguous", [True, False])
160    @parametrize("with_lrsched", [True, False])
161    @optims(optim_db, dtypes=[torch.float32])
162    def test_forloop_goes_right_direction(
163        self, device, dtype, optim_info, contiguous, with_lrsched
164    ):
165        optim_cls = optim_info.optim_cls
166        schedulers_constructors = (
167            optim_info.scheduler_inputs if with_lrsched else [None]
168        )
169
170        for schedulers_constructor in schedulers_constructors:
171            # With a tensor LR we need fresh inputs for each scheduler,
172            # otherwise mutations will carry across iterations.
173            optim_inputs = optim_info.optim_inputs_func(device=device)
174            for optim_input in optim_inputs:
175                if "foreach" in optim_info.supported_impls:
176                    optim_input.kwargs["foreach"] = False  # force forloop
177                if contiguous:
178                    weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
179                    bias = Parameter(torch.randn((10), device=device, dtype=dtype))
180                else:
181                    weight = Parameter(
182                        torch.randn((10, 5, 2), device=device, dtype=dtype)[..., 0]
183                    )
184                    bias = Parameter(
185                        torch.randn((10, 2), device=device, dtype=dtype)[..., 0]
186                    )
187                input = torch.randn(5, device=device, dtype=dtype)
188
189                optimizer = optim_cls([weight, bias], **optim_input.kwargs)
190                schedulers = [
191                    s(optimizer)
192                    for s in (schedulers_constructor if schedulers_constructor else [])
193                ]
194
195                def closure():
196                    optimizer.zero_grad()
197                    loss = (weight.mv(input) + bias).pow(2).sum()
198                    loss.backward()
199                    if optim_info.only_supports_sparse_grads:
200                        # For this test, we naively convert the Tensor layout, which we know does
201                        # NOT represent the expected use case for optims like SparseAdam!
202                        weight.grad = weight.grad.to_sparse()
203                        bias.grad = bias.grad.to_sparse()
204                    return loss
205
206                initial_value = closure().item()
207                for _ in range(20):
208                    if optim_info.step_requires_closure:
209                        loss = optimizer.step(closure)
210                    else:
211                        loss = closure()
212                        optimizer.step()
213
214                    for scheduler in schedulers:
215                        if isinstance(scheduler, ReduceLROnPlateau):
216                            scheduler.step(loss)
217                        else:
218                            scheduler.step()
219
220                if optim_input.kwargs.get("maximize", False):
221                    self.assertGreater(closure().item(), initial_value)
222                else:
223                    self.assertLess(closure().item(), initial_value)
224
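    # The scheduler-stepping pattern used in the direction tests above and below (illustrative):
    # ReduceLROnPlateau is metric-driven and consumes the observed loss, while other schedulers
    # advance unconditionally, e.g.
    #   plateau_sched.step(loss)   # ReduceLROnPlateau
    #   step_sched.step()          # e.g. StepLR, stepping on its own schedule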
225    @onlyCUDA
226    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
227    @parametrize("with_lrsched", [True, False])
228    @optims(optim_db, dtypes=[torch.float32])
229    def test_forloop_goes_right_direction_multigpu(
230        self, device, dtype, optim_info, with_lrsched
231    ):
232        optim_cls = optim_info.optim_cls
233        schedulers_constructors = (
234            optim_info.scheduler_inputs if with_lrsched else [None]
235        )
236        for schedulers_constructor in schedulers_constructors:
237            # We need a fresh set of inputs if we have a tensor LR
238            # to not carry mutations across iterations.
239            optim_inputs = optim_info.optim_inputs_func(device=device)
240            for optim_input in optim_inputs:
241                if "foreach" in optim_info.supported_impls:
242                    optim_input.kwargs["foreach"] = False  # force forloop
243
244                weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))
245                bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype))
246                inpt = torch.randn(5, device="cuda:0", dtype=dtype)
247
248                optimizer = optim_cls([weight, bias], **optim_input.kwargs)
249                schedulers = [
250                    s(optimizer)
251                    for s in (schedulers_constructor if schedulers_constructor else [])
252                ]
253
254                def closure():
255                    optimizer.zero_grad()
256                    loss = (weight.mv(inpt).cuda(1) + bias).pow(2).sum()
257                    loss.backward()
258                    if optim_info.only_supports_sparse_grads:
259                        # For this test, we naively convert the Tensor layout, which we know does
260                        # NOT represent the expected use case for optims like SparseAdam!
261                        weight.grad = weight.grad.to_sparse()
262                        bias.grad = bias.grad.to_sparse()
263                    return loss
264
265                initial_value = closure().item()
266                for _ in range(20):
267                    loss = optimizer.step(closure)
268                    for scheduler in schedulers:
269                        if isinstance(scheduler, ReduceLROnPlateau):
270                            scheduler.step(loss)
271                        else:
272                            scheduler.step()
273
274                if optim_input.kwargs.get("maximize", False):
275                    self.assertGreater(closure().item(), initial_value)
276                else:
277                    self.assertLess(closure().item(), initial_value)
278
279    @optims(optim_db, dtypes=[torch.float32])
280    def test_param_group_with_lrscheduler_goes_right_direction(
281        self, device, dtype, optim_info
282    ):
283        optim_cls = optim_info.optim_cls
284
285        for schedulers_c in optim_info.scheduler_inputs:
286            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
287            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
288            inpt = torch.randn(5, device=device, dtype=dtype)
289
290            # avoid endless recompiles by wrapping LR in a tensor if we're compiling
291            lr = torch.tensor(0.01) if torch._utils.is_compiling() else 0.01
292            optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": lr}])
293            schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c]
294
295            def closure():
296                optimizer.zero_grad()
297                loss = (weight.mv(inpt) + bias).pow(2).sum()
298                loss.backward()
299                if optim_info.only_supports_sparse_grads:
300                    # For this test, we naively convert the Tensor layout, which we know does
301                    # NOT represent the expected use case for optims like SparseAdam!
302                    weight.grad = weight.grad.to_sparse()
303                    bias.grad = bias.grad.to_sparse()
304                return loss
305
306            initial_value = closure().item()
307            for _ in range(20):
308                loss = optimizer.step(closure)
309                for scheduler in schedulers:
310                    if isinstance(scheduler, ReduceLROnPlateau):
311                        scheduler.step(loss)
312                    else:
313                        scheduler.step()
314
315            self.assertLess(closure().item(), initial_value)
316
317    @optims(optim_db, dtypes=[torch.float32])
318    def test_tensor_lr(self, device, dtype, optim_info):
319        optim_cls = optim_info.optim_cls
320
321        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
322        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
323            device, dtype, optim_info, skip=("differentiable",)
324        )
325        for optim_input in all_optim_inputs:
326            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
327            weight_c = weight.clone().detach().requires_grad_(True)
328            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
329            bias_c = bias.clone().detach().requires_grad_(True)
330            inpt = torch.randn(5, device=device, dtype=dtype)
331
332            kwargs = optim_input.kwargs
333            if "lr" in kwargs:
334                del kwargs["lr"]
335
336            kwargs["lr"] = 1.0 if optim_info.step_requires_closure else 1e-3
337            optimizer_r = optim_cls([weight, bias], **kwargs)
338
339            try:
340                kwargs["lr"] = torch.tensor(kwargs["lr"])
341                optimizer = optim_cls([weight_c, bias_c], **kwargs)
342            except ValueError as e:
343                self.assertRegex(str(e), ".*lr as a Tensor is not supported.*")
344                continue
345
346            def closure(optim, w, b, i):
347                optim.zero_grad()
348                loss = (w.mv(i) + b).pow(2).sum()
349                loss.backward()
350                if optim_info.only_supports_sparse_grads:
351                    # For this test, we naively convert the Tensor layout, which we know does
352                    # NOT represent the expected use case for optims like SparseAdam!
353                    w.grad = w.grad.to_sparse()
354                    b.grad = b.grad.to_sparse()
355                return loss
356
357            for _ in range(5):
358                if optim_info.step_requires_closure:
359                    optimizer_r.step(
360                        functools.partial(closure, optimizer_r, weight, bias, inpt)
361                    )
362                    optimizer.step(
363                        functools.partial(closure, optimizer, weight_c, bias_c, inpt)
364                    )
365                else:
366                    closure(optimizer_r, weight, bias, inpt)
367                    closure(optimizer, weight_c, bias_c, inpt)
368
369                self.assertEqual(weight, weight_c)
370                self.assertEqual(bias, bias_c)
371
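    # Note on the Tensor LRs exercised by test_tensor_lr above (illustrative): a tensor lr is
    # constructed the way a user would, e.g.
    #   torch.optim.SGD(model.parameters(), lr=torch.tensor(1e-3))  # supported by most core optims
    # and the test asserts the resulting trajectory matches the plain-float lr exactly.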
372    @parametrize("with_lrsched", [True, False])
373    @optims(
374        [o for o in optim_db if o.supports_sparse or o.only_supports_sparse_grads],
375        dtypes=[torch.float64],
376    )
377    def test_rosenbrock_sparse(self, device, dtype, optim_info, with_lrsched):
378        optim_cls = optim_info.optim_cls
379
380        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
381        # Fused impls do not support sparse gradients
382        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
383            device, dtype, optim_info, skip=("differentiable", "fused")
384        )
385        kwarg_updates, schedulers_constructors = optim_info.metadata_for_sparse
386
387        if with_lrsched and len(schedulers_constructors) == 0:
388            return
389
390        supported_inputs = []
391        if len(kwarg_updates) != 0:
392            seen = set()
393            for i in all_optim_inputs:
394                for k in kwarg_updates:
395                    if k in i.kwargs:
396                        del i.kwargs[k]
397                hashable_kwargs = tuple(sorted(i.kwargs.items()))
398                if len(i.kwargs) > 0 and hashable_kwargs not in seen:
399                    supported_inputs.append(i)
400                    seen.add(hashable_kwargs)
401                    if "lr" in kwarg_updates:
402                        i.kwargs["lr"] = kwarg_updates["lr"]
403        else:
404            supported_inputs = all_optim_inputs
405
406        for optim_input in supported_inputs:
407            kwargs = optim_input.kwargs
408            multi_tensor = kwargs.get("foreach", False)
409
410            # The rosenbrock tests require each param to be a tensor with exactly 2 elements
411            if multi_tensor:
412                params_t = [
413                    torch.tensor([1.5, 1.5]),
414                    torch.tensor([1.5, 1.5], dtype=dtype),
415                ]
416            else:
417                params_t = [torch.tensor([1.5, 1.5])]
418
419            params = [Parameter(param_t) for param_t in params_t]
420            optimizer = optim_cls(params, **kwargs)
421            schedulers = [
422                s(optimizer) for s in (schedulers_constructors if with_lrsched else [])
423            ]
424
425            if not optim_info.only_supports_sparse_grads:
426                params_c = [Parameter(param_t.clone()) for param_t in params_t]
427                optimizer_c = optim_cls(params_c, **kwargs)
428                schedulers_c = [
429                    s(optimizer_c)
430                    for s in (schedulers_constructors if with_lrsched else [])
431                ]
432
433            solution = torch.tensor([1, 1])
434            with torch.no_grad():
435                initial_dist = sum(param.dist(solution) for param in params)
436
437            def get_grad(param, sparse_grad, w):
438                grad = drosenbrock(param)
439                # NB: We torture test the optimizer by returning an
440                # uncoalesced sparse tensor
441
442                # Depending on w, provide only the x or y gradient
443                if sparse_grad:
444                    if w:
445                        i = torch.tensor([[0, 0]], dtype=torch.int64)
446                        x = grad[0]
447                        v = torch.tensor([x / 4.0, x - x / 4.0])
448                    else:
449                        i = torch.tensor([[1, 1]], dtype=torch.int64)
450                        y = grad[1]
451                        v = torch.tensor([y - y / 4.0, y / 4.0])
452                    grad_out = torch.sparse_coo_tensor(i, v, (2,), dtype=v.dtype)
453                else:
454                    if w:
455                        grad_out = torch.tensor([grad[0], 0], dtype=param.dtype)
456                    else:
457                        grad_out = torch.tensor([0, grad[1]], dtype=param.dtype)
458                return grad_out
459
460            def eval(params, sparse_grad, w):
461                optimizer.zero_grad()
462                if multi_tensor:
463                    loss = sum(rosenbrock(param) for param in params)
464                else:
465                    loss = rosenbrock(params[0])
466                loss.backward()
467
468                grads_out = [get_grad(param, sparse_grad, w) for param in params]
469                with torch.no_grad():
470                    params[0].grad = grads_out[0]
471                    if multi_tensor:
472                        params[1].grad = grads_out[1].to(dtype=dtype)
473                return loss
474
475            for i in range(1800):
476                # Do cyclic coordinate descent
477                w = i % 2
478                optimizer.step(functools.partial(eval, params, True, w))
479                for scheduler in schedulers:
480                    if isinstance(scheduler, ReduceLROnPlateau):
481                        scheduler.step(rosenbrock(params[0]))
482                    else:
483                        scheduler.step()
484                if not optim_info.only_supports_sparse_grads:
485                    optimizer_c.step(functools.partial(eval, params_c, False, w))
486                    for scheduler in schedulers_c:
487                        if isinstance(scheduler, ReduceLROnPlateau):
488                            scheduler.step(rosenbrock(params_c[0]))
489                        else:
490                            scheduler.step()
491                    # Tolerance is increased due to floating point error from the different
492                    # code path for the dense case: x vs. x - x / 4.0 + x / 4.0
493                    self.assertEqual(params, params_c, atol=5e-6, rtol=5e-6)
494
495            if not kwargs.get("maximize", False):
496                self.assertLessEqual(
497                    sum(param.dist(solution) for param in params), initial_dist
498                )
499            else:
500                self.assertGreaterEqual(
501                    sum(rosenbrock(param) for param in params),
502                    sum(rosenbrock(param_t) for param_t in params_t),
503                )
504
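    # Note on the uncoalesced gradients used in test_rosenbrock_sparse above (illustrative):
    # coalesce() sums duplicate indices, which is how x is reconstructed from x/4 and x - x/4, e.g.
    #   i = torch.tensor([[0, 0]]); v = torch.tensor([1.0, 3.0])
    #   torch.sparse_coo_tensor(i, v, (2,)).coalesce().values()  # tensor([4.])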
505    @skipMPS
506    @optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
507    def test_complex(self, device, dtype, optim_info):
508        optim_cls = optim_info.optim_cls
509        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
510        # Also skip fused, since our fused kernels do not support complex
511        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
512            device, dtype, optim_info, skip=("differentiable", "fused")
513        )
514        for optim_input in all_optim_inputs:
515            # Last param is intentionally real to test that we can mix real and complex
516            complex_params = [
517                torch.randn(10, 5, device=device, dtype=dtype, requires_grad=True),
518                torch.randn(10, device=device, dtype=dtype, requires_grad=True),
519                torch.randn(
520                    10, 5, device=device, dtype=torch.float32, requires_grad=True
521                ),
522            ]
523            real_params = [
524                (
525                    torch.view_as_real(param).detach().clone().requires_grad_()
526                    if param.is_complex()
527                    else param.detach().clone().requires_grad_()
528                )
529                for param in complex_params
530            ]
531
532            complex_optimizer = optim_cls(complex_params, **optim_input.kwargs)
533            real_optimizer = optim_cls(real_params, **optim_input.kwargs)
534            real_steps = []
535            complex_steps = []
536            grads_losses = []
537
538            def real_closure():
539                for param in real_params:
540                    grad = torch.randn_like(param)
541                    param.grad = grad
542                    real_steps.append(param.detach().clone())
543                    grads_losses.append(grad.clone())
544                loss = torch.randn(1)
545                grads_losses.append(loss.clone())
546                return loss
547
548            def complex_closure():
549                for param in complex_params:
550                    if torch.is_complex(param):
551                        grad = torch.view_as_complex(grads_losses.pop(0))
552                        complex_steps.append(torch.view_as_real_copy(param.detach()))
553                    else:
554                        grad = grads_losses.pop(0)
555                        complex_steps.append(param.detach().clone())
556                    param.grad = grad
557                return grads_losses.pop(0)
558
559            for _ in range(3):
560                if optim_info.step_requires_closure:
561                    # LBFGS, for example, requires closure and calls it internally
562                    real_optimizer.step(real_closure)
563                    complex_optimizer.step(complex_closure)
564                else:
565                    # For other optimizers, we call closure explicitly to set the gradients
566                    real_closure()
567                    complex_closure()
568                    real_optimizer.step()
569                    complex_optimizer.step()
570
571            # Final Parameters should be the same
572            complex_params_asreal = [
573                torch.view_as_real(param) if param.is_complex() else param
574                for param in complex_params
575            ]
576            self.assertEqual(real_params, complex_params_asreal)
577
578            # All intermediate steps should also be the same
579            # also checks steps taken within for example a line search
580            self.assertEqual(complex_steps, real_steps)
581
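    # Note on the complex handling exercised in test_complex above (illustrative): complex params
    # are optimized through their real view, so shapes gain a trailing dimension of 2, e.g.
    #   z = torch.randn(10, 5, dtype=torch.complex64)
    #   torch.view_as_real(z).shape  # torch.Size([10, 5, 2])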
582    @skipMPS
583    @optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
584    def test_complex_2d(self, device, dtype, optim_info):
585        optim_cls = optim_info.optim_cls
586        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
587        # Also skip fused, since our fused kernels do not support complex
588        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
589            device, dtype, optim_info, skip=("differentiable", "fused")
590        )
591        for optim_input in all_optim_inputs:
592            if optim_info.step_requires_closure:
593                # Why? The way we implement complex is by turning complex params into view_as_real
594                # alternatives. For example, a size (M,N) tensor will become (M,N,2). In this test,
595                # we break apart a tensor into its real and imaginary parts, which would be 2x(M,N).
596                # For pointwise optimizers, this distinction is trivial, but for LBFGS, where
597                # there are reductions across all parameters (and all the grads get flattened into
598                # one long Tensor), this ordering matters. Why? Reductions are not deterministic
599                # because floating point addition is not associative, i.e.,
600                # (a + b) + c != a + (b + c). Thus, we add a seed here to control the discrepancy that
601                # will happen with LBFGS. Note that in test_complex above, there is no need for a seed
602                # nor for increased tolerance, because results should be bitwise equivalent.
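                # Concrete instance (illustrative): in IEEE-754 double precision,
                # (0.1 + 0.2) + 0.3 == 0.6000000000000001 while 0.1 + (0.2 + 0.3) == 0.6,
                # so reordering a reduction can shift results by an ULP or more.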
603                torch.manual_seed(2024)
604
605            a1 = torch.randn(2, device=device, dtype=dtype, requires_grad=True)
606            a1_real = a1.real.clone().detach()
607            a1_imag = a1.imag.clone().detach()
608            a1_real.requires_grad_()
609            a1_imag.requires_grad_()
610            optim1 = optim_cls([a1], **optim_input.kwargs)
611            optim2 = optim_cls([a1_real, a1_imag], **optim_input.kwargs)
612
613            a1_reals = TensorTracker()
614            a1_imags = TensorTracker()
615            a1_grad_reals = TensorTracker()
616            a1_grad_imags = TensorTracker()
617            losses = TensorTracker()
618
619            def closure1():
620                optim1.zero_grad()
621                loss = rosenbrock(a1).abs()
622                loss.backward()
623
624                # Track clones to compare intermediate values as accurately as possible
625                a1_reals.add(a1.real)
626                a1_imags.add(a1.imag)
627                a1_grad_reals.add(a1.grad.real)
628                a1_grad_imags.add(a1.grad.imag)
629
630                losses.add(loss)
631
632                return loss
633
634            def closure2():
635                optim2.zero_grad()
636                a1_reals.pop_check_set(a1_real, self)
637                a1_imags.pop_check_set(a1_imag, self)
638                a2 = torch.complex(a1_real, a1_imag)
639                loss = rosenbrock(a2).abs()
640                losses.pop_check_set(loss, self)
641                loss.backward()
642                a1_grad_reals.pop_check_set(a1_real.grad, self)
643                a1_grad_imags.pop_check_set(a1_imag.grad, self)
644                return loss
645
646            for _ in range(3):
647                if optim_info.step_requires_closure:
648                    # LBFGS, for example, requires closure and calls it internally
649                    optim1.step(closure1)
650                    optim2.step(closure2)
651                else:
652                    closure1()
653                    closure2()
654                    optim1.step()
655                    optim2.step()
656
657                self.assertEqual(a1.real, a1_real)
658                self.assertEqual(a1.imag, a1_imag)
659
660            self.assertTrue(a1_reals.all_popped())
661            self.assertTrue(a1_imags.all_popped())
662            self.assertTrue(a1_grad_reals.all_popped())
663            self.assertTrue(a1_grad_imags.all_popped())
664            self.assertTrue(losses.all_popped())
665
666    def _compare_between(
667        self, inputs, models, optimizers, assert_eq_kwargs=None, assert_step_dtype=None
668    ):
669        # why 7? iteration 7 is where we start to see differences for RAdam
670        # params interacting with the small eps value, because that's right
671        # after rho_t becomes greater than 5 in step 6.
672        if assert_eq_kwargs is None:
673            assert_eq_kwargs = {}
674        kIterations = 7
675        tracker = TensorTracker(assert_eq_kwargs)
676        for i in range(kIterations):
677            state, updated_params = [], []
678            if not isinstance(inputs, list):
679                inputs = [inputs, inputs]
680            for input, model, optimizer in zip(inputs, models, optimizers):
681                optimizer.zero_grad()
682
683                if i == 3:
684                    # Freeze a layer to test if the step of this layer in 'fused' or 'foreach'
685                    # is the same as the step in 'forloop'.
686                    model[2].requires_grad_(False)
687                if i == 5:
688                    # Unfreeze the layer after 2 iters.
689                    model[2].requires_grad_(True)
690
691                # Test that step behaves as expected (a no-op) when grads are set to None
692                if i != 2:
693                    output = model(input)
694                    loss = output.sum()
695                    loss.backward()
696
697                optimizer.step()
698                state.append(optimizer.state)
699                updated_params.append(model.parameters())
700
701            og_state, new_state = state
702            for og_p, new_p in zip(updated_params[0], updated_params[1]):
703                tracker.add(og_p)
704                tracker.pop_check_set(new_p, self)
705
706                # check that optimizer states are the same
707                og_p_state = og_state[og_p]
708                new_p_state = new_state[new_p]
709                if assert_step_dtype is not None:
710                    if torch.is_tensor(og_p_state.get("step", None)):
711                        self.assertEqual(og_p_state["step"].dtype, assert_step_dtype)
712                    if torch.is_tensor(new_p_state.get("step", None)):
713                        self.assertEqual(new_p_state["step"].dtype, assert_step_dtype)
714                for k in og_p_state:
715                    tracker.add(og_p_state[k])
716                    tracker.pop_check_set(new_p_state[k], self)
717
718            self.assertTrue(tracker.all_popped())
719
720    def _test_derived_optimizers(
721        self,
722        device,
723        dtype,
724        optim_info,
725        flag,
726        reduced_precision=False,
727        assert_step_dtype=None,
728    ):
729        """
730        Given a flag 'fused' or 'foreach', test for parity of optimizer state
731        and updated parameters between when the flag is set to True and False
732        for provided optimizer configurations.
733        """
734        assert flag in ("foreach", "fused")
735        assert_eq_kwargs = {} if not reduced_precision else FP16_REDUCED_PRECISION
736
737        optim_inputs = optim_info.optim_inputs_func(device=device, dtype=dtype)
738        optim_cls = optim_info.optim_cls
739        for optim_input in optim_inputs:
740            models, optimizers = [], []
741            kwargs = deepcopy(optim_input.kwargs)
742            if kwargs.get("capturable", False) and _get_device_type(device) == "cpu":
743                # capturable is not supported on CPU
744                continue
745            for flag_value in (False, True):
746                kwargs[flag] = flag_value
747                input = torch.tensor(
748                    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=device
749                ).reshape(3, 2)
750
751                torch.manual_seed(1)
752                model = torch.nn.Sequential(
753                    torch.nn.Linear(2, 3),
754                    torch.nn.Sigmoid(),
755                    torch.nn.Linear(3, 1),
756                    torch.nn.Sigmoid(),
757                )
758                model.to(dtype=dtype, device=device)
759
760                # foreach/fused optimizers should be tested with a
761                # zero_size tensor as their last param.
762                # ref: https://github.com/pytorch/pytorch/issues/100701
763                empty_param = torch.empty(
764                    (), device=device, dtype=dtype, requires_grad=True
765                )
766                empty_param.grad = torch.rand_like(empty_param)
767                params = list(model.parameters()) + [empty_param]
768
769                optimizer = optim_cls(params, **kwargs)
770                models.append(model)
771                optimizers.append(optimizer)
772
773            self._compare_between(
774                input, models, optimizers, assert_eq_kwargs, assert_step_dtype
775            )
776
777    @skipMPS  # MPS doesn't support torch.float64, see https://github.com/pytorch/pytorch/issues/115350
778    @optims(
779        [optim for optim in optim_db if "foreach" in optim.supported_impls],
780        dtypes=[torch.float64],
781    )
782    def test_foreach_matches_forloop(self, device, dtype, optim_info):
783        self._test_derived_optimizers(device, dtype, optim_info, "foreach")
784
785    @onlyCUDA
786    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
787    @parametrize("impl", ["foreach", "fused"])
788    @optims(
789        [
790            optim
791            for optim in optim_db
792            if "foreach" in optim.supported_impls or "fused" in optim.supported_impls
793        ]
794    )
795    def test_mixed_device_dtype(self, device, dtype, optim_info, impl):
796        """
797        Similar in essence to _test_derived_optimizers above. The main difference is that
798        _test_derived_optimizers uses model parameters whereas we randomly pass in
799        parameters of different dtypes and devices here. We need multiple GPUs (vs just a
800        CPU and GPU) because fused adam only works on GPUs. (Thus we only run the tests
801        that call into this helper when TEST_MULTIGPU.)
802        """
803        assert impl in ("foreach", "fused")
804        if impl == "foreach" and "foreach" not in optim_info.supported_impls:
805            self.skipTest(
806                f"foreach not supported for {optim_info.optim_cls.__name__}"
807            )
808        elif impl == "fused" and "cuda" not in optim_info.supports_fused_on:
809            self.skipTest(
810                f"fused not supported for {optim_info.optim_cls.__name__} on cuda"
811            )
812
813        params = [
814            torch.rand(2, 3, dtype=torch.float64, device="cuda:0", requires_grad=True),
815            torch.rand(2, 3, dtype=torch.float32, device="cuda:0", requires_grad=True),
816            torch.rand(2, 3, dtype=torch.float16, device="cuda:0", requires_grad=True),
817            torch.rand(2, 3, dtype=torch.bfloat16, device="cuda:0", requires_grad=True),
818            torch.rand(2, 3, dtype=torch.float64, device="cuda:1", requires_grad=True),
819            torch.rand(2, 3, dtype=torch.float32, device="cuda:1", requires_grad=True),
820            torch.rand(2, 3, dtype=torch.float16, device="cuda:1", requires_grad=True),
821            torch.rand(2, 3, dtype=torch.bfloat16, device="cuda:1", requires_grad=True),
822            torch.randint(
823                1024, (2, 3), dtype=torch.int64, device="cuda:1", requires_grad=False
824            ),
825        ]
826
827        for p in params:
828            if p.requires_grad:
829                p.grad = torch.rand_like(p, device=p.device, dtype=p.dtype)
830
831        kIterations = 7 if impl == "foreach" else 1
832        optim_inputs = optim_info.optim_inputs_func(device=device)
833        optim_cls = optim_info.optim_cls
834        for optim_input in optim_inputs:
835            updated_params, state = [], []
836            kwargs = deepcopy(optim_input.kwargs)
837            if kwargs.get("capturable", False) and _get_device_type(device) == "cpu":
838                # capturable is not supported on CPU
839                continue
840            for use_impl in (False, True):
841                kwargs[impl] = use_impl
842                params_clone = []
843                for p in params:
844                    p_clone = p.clone().detach()
845                    if p.requires_grad:
846                        p_clone.requires_grad = True
847                        p_clone.grad = p.grad.clone().detach()
848                        params_clone.append(p_clone)
849
850                optimizer = optim_cls(params_clone, **kwargs)
851                for _ in range(kIterations):
852                    optimizer.step()
853
854                state.append(optimizer.state)
855                updated_params.append(params_clone)
856
857            og_state, new_state = state
858            for og_p, new_p in zip(updated_params[0], updated_params[1]):
859                # Increase the tolerance: optimizers collate lots of ops together, and
860                # the designated tolerances are for a single op only.
861                single_rtol, single_atol = torch.testing._comparison.get_tolerances(
862                    new_p.dtype, rtol=None, atol=None
863                )
864                rtol = 5 * single_rtol
865                atol = 5 * single_atol
866
867                self.assertEqual(og_p, new_p, rtol=rtol, atol=atol)
868
869                # check that optimizer states are the same
870                og_p_state = og_state[og_p]
871                new_p_state = new_state[new_p]
872
873                for k in og_p_state:
874                    actual = new_p_state[k]
875                    self.assertEqual(og_p_state[k], actual, rtol=rtol, atol=atol)
876
877    @onlyCUDA
878    @optims(
879        [optim for optim in optim_db if "foreach" in optim.supported_impls],
880        dtypes=[torch.float64],
881    )
882    def test_set_default_dtype_works_with_foreach(self, device, dtype, optim_info):
883        # https://github.com/pytorch/pytorch/issues/110940
884        # We coerce step to always be float32 unless the
885        # default dtype is the higher-precision float64
886        old_default_dtype = torch.get_default_dtype()
887        for default_dtype in [torch.float64, torch.float16]:
888            try:
889                torch.set_default_dtype(default_dtype)
890                self._test_derived_optimizers(
891                    device,
892                    dtype,
893                    optim_info,
894                    "foreach",
895                    reduced_precision=default_dtype == torch.float16,
896                    assert_step_dtype=(
897                        torch.float64
898                        if default_dtype == torch.float64
899                        else torch.float32
900                    ),
901                )
902            finally:
903                torch.set_default_dtype(old_default_dtype)
904
905    @onlyCUDA
906    @largeTensorTest("72GB", "cuda")
907    @optims(
908        [optim for optim in optim_db if "foreach" in optim.supported_impls],
909        dtypes=[torch.float16],
910    )
911    def test_foreach_large_tensor(self, device, dtype, optim_info):
912        optim_cls = optim_info.optim_cls
913        optim_inputs = optim_info.optim_inputs_func(device=device)
914        for optim_input in optim_inputs:
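            # Back-of-envelope (illustrative): 2**32 float16 elements is 8 GiB for the param
            # alone; the grad and any per-param state multiply that, hence @largeTensorTest above.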
915            params = [torch.ones(2**32, device=device, dtype=dtype)]
916            params[0].grad = torch.zeros_like(params[0])
917            optimizer = optim_cls(params, foreach=True, **optim_input.kwargs)
918            optimizer.step()
919
920    @onlyCUDA
921    @optims(
922        [optim for optim in optim_db if "foreach" in optim.supported_impls],
923        dtypes=[torch.float32],
924    )
925    def test_peak_memory_foreach(self, device, dtype, optim_info):
926        nparams = 10
927        optim_inputs = optim_info.optim_inputs_func(device=device)
928        optim_cls = optim_info.optim_cls
929        for optim_input in optim_inputs:
930            kwargs = deepcopy(optim_input.kwargs)
931            max_mems = []
932            for flag_value in (False, True):
933                kwargs["foreach"] = flag_value
934                # The 16 * 8 = 128 element count is critical here! Our CUDACachingAllocator allocates
935                # in blocks of 512 bytes, meaning any tensor that occupies <512 bytes of memory will
936                # allocate a whole 512 bytes anyway. We use 128 elements (4 bytes each for float32) so
937                # that each param is exactly 512 bytes, making our later calculations for intermediate_size easy.
938                param = torch.rand(16, 8, device=device, dtype=dtype)
939                params = [torch.rand_like(param) for _ in range(nparams)]
940
941                optimizer = optim_cls(params, **kwargs)
942
943                for p in params:
944                    p.grad = torch.rand_like(p)
945
946                optimizer.step()
947                import gc
948
949                gc.collect()
950                torch.cuda.reset_peak_memory_stats()
951                optimizer.step()
952                gc.collect()
953                max_mems.append(torch.cuda.max_memory_allocated())
954
955            st_max_mem, mt_max_mem = max_mems
956            intermediate_size = nparams * param.nelement() * param.element_size()
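            # (Illustrative) with the 16x8 float32 params above: 10 * 128 * 4 bytes = 5120 bytes.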
957            nintermediates = 1  # we expect a budget of 1 intermediate most of the time
958
959            # Check the param group directly to handle the case where the compiler set capturable
960            if optimizer.param_groups[0].get(
961                "capturable", False
962            ) or optim_cls.__name__ in ["Adadelta", "ASGD", "RAdam"]:
963                # with capturable in Adam(W), we have 2 extra intermediates for the bias_corrections
964                # with Adadelta, we have 2 extra for (acc_delta + eps) and (square_avg + eps)
965                # ASGD allocates axs, 2x mus, 2x etas, and grads at the same time
966                nintermediates = 3
967                if optim_cls.__name__ == "NAdam":
968                    # with capturable in NAdam, we have 3 extra intermediates for the
969                    # bias_correction, mus, and mu_nexts
970                    if TEST_WITH_TORCHDYNAMO:
971                        # With dynamo, the eager/FX backend appears to hold memory longer than
972                        # vanilla eager: https://github.com/pytorch/pytorch/issues/125511
973                        nintermediates = 8
974                    else:
975                        nintermediates = 5
976
977                if optim_cls.__name__ == "RAdam":
978                    # RAdam has four intermediates with capturable
979                    # num, unrect_step_size, buffer, grouped_grads
980                    if TEST_WITH_TORCHDYNAMO:
981                        # With dynamo, the eager/FX backend appears to hold memory longer than
982                        # vanilla eager: https://github.com/pytorch/pytorch/issues/125511
983                        nintermediates = 6
984                    else:
985                        nintermediates = 4
986
987            elif optim_cls.__name__ in ["NAdam", "Adagrad", "RMSprop", "Adafactor"]:
988                # NAdam uses two intermediates at the same time (grads & exp_avg_sq_sqrt)
989                # Adagrad uses std and grads at the same time
990                # RMSprop uses avg and grads
991                # Adafactor uses row/col var and its mean
992                nintermediates = 2
993
994                if optim_cls.__name__ == "Adafactor" and kwargs.get("maximize", False):
995                    # When maximize is True, Adafactor also tracks device_grad
996                    nintermediates = 3
997
998            # Dynamo single-tensor (ST) uses less memory than eager for Adam/Adagrad/NAdam/RAdam,
999            # which makes the foreach memory check fail
1000            if TEST_WITH_TORCHDYNAMO:
1001                st_max_mem += 6000
1002
1003            expected_max_mem = st_max_mem + intermediate_size * nintermediates
1004            # hipcc currently can't generate efficient code for the small buffer optimization
1005            # code path (see Note [small buffer optimization] for details), thus we always
1006            # dynamically allocate the tensor metadata for ROCM. Adjusting the expected max
1007            # memory usage to account for this.
1008            if TEST_WITH_ROCM:
1009                expected_max_mem *= 1.02
1010
1011            self.assertLessEqual(mt_max_mem, expected_max_mem)
1012
1013    @optims(
1014        [optim for optim in optim_db if "fused" in optim.supported_impls],
1015        dtypes=floating_types_and(
1016            torch.bfloat16,
1017            torch.float16,
1018        ),
1019    )
1020    def test_fused_matches_forloop(self, device, dtype, optim_info):
1021        if _get_device_type(device) not in optim_info.supports_fused_on:
1022            self.skipTest(
1023                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
1024            )
1025        if _get_device_type(device) == "mps" and dtype not in (
1026            torch.float16,
1027            torch.float32,
1028        ):
1029            self.skipTest("MPS supports only torch.float16 and torch.float32")
1030        self._test_derived_optimizers(device, dtype, optim_info, "fused")
1031
1032    @optims(
1033        [optim for optim in optim_db if "fused" in optim.supported_impls],
1034        dtypes=(torch.float32,),
1035    )
1036    def test_fused_error_on_params_on_meta(self, device, dtype, optim_info):
1037        if _get_device_type(device) not in optim_info.supports_fused_on:
1038            self.skipTest(
1039                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
1040            )
1041
1042        with torch.device("meta"):
1043            model = torch.nn.Sequential(
1044                torch.nn.Linear(2, 3),
1045                torch.nn.Sigmoid(),
1046                torch.nn.Linear(3, 1),
1047                torch.nn.Sigmoid(),
1048            ).to(dtype)
1049
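        # (Illustrative) this mirrors deferred initialization: the module is built under
        # torch.device("meta") with no real storage, and only becomes steppable after
        # model.to_empty(device=...) materializes it below.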
1050        optimizer = optim_info.optim_cls(model.parameters(), fused=True)
1051        with torch.device("meta"):
1052            for p in model.parameters():
1053                p.grad = torch.rand_like(p)
1054
1055        with self.assertRaisesRegex(
1056            RuntimeError,
1057            "`fused=True` requires all the params to be floating point Tensors",
1058        ):
1059            optimizer.step()
1060
1061        optimizer.zero_grad(set_to_none=True)
1062        model.to_empty(device=device)
1063        for p in model.parameters():
1064            p.grad = torch.rand_like(p)
1065        optimizer.step()
1066
1067    @onlyNativeDeviceTypes
1068    @largeTensorTest("64GB")
1069    @optims(
1070        [optim for optim in optim_db if "fused" in optim.supported_impls],
1071        dtypes=[torch.float16],
1072    )
1073    def test_fused_large_tensor(self, device, dtype, optim_info):
1074        if device not in optim_info.supports_fused_on:
1075            self.skipTest(
1076                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
1077            )
1078        optim_cls = optim_info.optim_cls
1079        optim_inputs = optim_info.optim_inputs_func(device=device)
1080        for optim_input in optim_inputs:
1081            params = [torch.ones(2**32, device=device, dtype=dtype)]
1082            params[0].grad = torch.zeros_like(params[0])
1083            optimizer = optim_cls(params, fused=True, **optim_input.kwargs)
1084            optimizer.step()
1085
1086    @onlyCUDA
1087    @optims(
1088        [optim for optim in optim_db if "fused" in optim.supported_impls],
1089        dtypes=[torch.float32],
1090    )
1091    def test_fused_does_not_step_if_foundinf(self, device, dtype, optim_info):
1092        if device not in optim_info.supports_fused_on:
1093            self.skipTest(
1094                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
1095            )
1096        optim_cls = optim_info.optim_cls
1097        optim_inputs = optim_info.optim_inputs_func(device=device)
1098        num_params = 5
1099        for optim_input in optim_inputs:
1100            for no_grad_scale in (False, True):
1101                params = [
1102                    torch.ones((1,), device=device, dtype=dtype)
1103                    for _ in range(num_params)
1104                ]
1105                params_c = [param.clone().detach() for param in params]
1106                for p in params:
1107                    p.grad = torch.ones_like(p)
1108                optimizer = optim_cls(params, fused=True, **optim_input.kwargs)
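                # (Illustrative) grad_scale / found_inf emulate the attributes torch.amp's
                # GradScaler supplies to fused (_step_supports_amp_scaling) optimizers;
                # a nonzero found_inf tells the fused kernel to skip the parameter update.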
1109                optimizer.grad_scale = (
1110                    None
1111                    if no_grad_scale
1112                    else torch.ones((1,), dtype=dtype, device=device)
1113                )
1114                optimizer.found_inf = torch.ones((), dtype=dtype, device=device)
1115                optimizer.step()
1116                for p in params:
1117                    if "step" in optimizer.state[p]:
1118                        self.assertEqual(
1119                            torch.zeros((), dtype=dtype, device=device),
1120                            optimizer.state[p]["step"],
1121                        )
1122                self.assertEqual(params, params_c)
1123
1124    @parametrize("impl", ["fused", "capturable"])
1125    @optims(
1126        [optim for optim in optim_db if "fused" in optim.supported_impls],
1127        dtypes=[torch.float32],
1128    )
1129    def test_cpu_load_state_dict(self, device, dtype, impl, optim_info):
1130        # NOTE: This SIMULATES a fused/capturable optimizer with state moved to CPU, issue 103256
1131        # How do we get there? Users typically train CUDA models with fused optimizers and then, since
1132        # CUDA memory is limited, reload their checkpoints on CPU with torch.load(..., map_location="cpu").
1133        # Since this is a unit test, it is more expedient to simulate what the state_dict
1134        # would look like, which is basically CPU tensors with fused/capturable flag = True.
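        # Sketch of the real-world flow being simulated (illustrative):
        #   torch.save(optimizer.state_dict(), "ckpt.pt")            # saved from the CUDA run
        #   sd = torch.load("ckpt.pt", map_location="cpu")           # reloaded on a CPU-only box
        #   new_optimizer.load_state_dict(sd)                        # state tensors now live on CPU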
1135        optim_cls = optim_info.optim_cls
1136        opt_name = optim_cls.__name__
1137        if opt_name in ("SGD", "Adagrad") and impl == "capturable":
1138            # Capturable SGD/Adagrad does not exist
1139            self.skipTest("SGD/Adagrad does not currently support capturable")
1140        if _get_device_type(device) == "cpu":
1141            self.skipTest("Test is only for non-cpu devices")
1142        elif (
1143            impl == "fused"
1144            and _get_device_type(device) not in optim_info.supports_fused_on
1145        ):
1146            self.skipTest(f"{device} is not supported for fused on {opt_name}")
1147        elif impl == "capturable" and _get_device_type(device) == "mps":
1148            self.skipTest("MPS does not support capturable")
1149
1150        cpu_optim_inputs = optim_info.optim_inputs_func(device="cpu")
1151        for optim_input in cpu_optim_inputs:
1152            param = torch.tensor([0.1, 0.2], dtype=dtype, device="cpu")
1153            optimizer = optim_cls([param], **optim_input.kwargs)
1154            param.grad = torch.rand_like(param)
1155            optimizer.step()
1156            optim_state_dict_cpu = deepcopy(optimizer.state_dict())
1157            optim_state_dict_cpu["param_groups"][0][impl] = True
1158
1159            # load
1160            optim_input.kwargs[impl] = True
1161            param_device = param.clone().detach().to(device=device)
1162            optimizer_device = optim_cls([param_device], **optim_input.kwargs)
1163            optimizer_device.load_state_dict(optim_state_dict_cpu)
1164            optimizer_device.zero_grad()
1165            param_device.grad = torch.rand_like(param_device)
1166            optimizer_device.step()
1167
1168    @optims(optim_db, dtypes=[torch.float32])
1169    def test_param_groups_weight_decay(self, device, dtype, optim_info):
1170        optim_cls = optim_info.optim_cls
1171        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
1172        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1173            device, dtype, optim_info, skip=("differentiable",)
1174        )
1175        for optim_input in all_optim_inputs:
1176            weight_kwargs = optim_input.kwargs
1177            bias_kwargs = deepcopy(optim_input.kwargs)
1178            bias_kwargs["weight_decay"] = 0.0
1179
1180            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
1181            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
1182            input = torch.randn(5, device=device, dtype=dtype)
1183
1184            optimizer = optim_cls(
1185                [
1186                    dict(params=[weight], **weight_kwargs),
1187                    dict(params=[bias], **bias_kwargs),
1188                ]
1189            )
1190
1191            loss = (weight.mv(input) + bias).pow(2).sum()
1192            initial_value = loss.item()
1193            for _ in range(20):
1194                optimizer.zero_grad()
1195                loss = (weight.mv(input) + bias).pow(2).sum()
1196                loss.backward()
1197                if optim_info.only_supports_sparse_grads:
1198                    # For this test, we naively convert the Tensor layout, which we know does
1199                    # NOT represent the expected use case for optims like SparseAdam!
1200                    weight.grad = weight.grad.to_sparse()
1201                    bias.grad = bias.grad.to_sparse()
1202                optimizer.step()
1203
1204            # Test that the loss moved in the appropriate direction
1205            if optim_input.kwargs.get("maximize", False):
1206                self.assertGreater(loss.item(), initial_value)
1207            else:
1208                self.assertLess(loss.item(), initial_value)
1209
1210    @optims(optim_db, dtypes=[torch.float32])
1211    def test_param_groups_lr(self, device, dtype, optim_info):
1212        optim_cls = optim_info.optim_cls
1213        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
1214        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1215            device, dtype, optim_info, skip=("differentiable",)
1216        )
1217        for optim_input in all_optim_inputs:
1218            # optim_input.kwargs will be the param group kwargs, which should have >0 lr
1219            if "lr" not in optim_input.kwargs or optim_input.kwargs["lr"] == 0:
1220                optim_input.kwargs["lr"] = 1e-3
1221            outer_kwargs = {"lr": 1e-28}
1222            if optim_cls.__name__ == "Rprop":
1223                # Allow min step size to be 0
1224                outer_kwargs["step_sizes"] = (0, 50)
1225
1226            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
1227            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
1228            irrelevant = Parameter(torch.randn(2, device=device, dtype=dtype))
1229            irrelevant_clone = irrelevant.clone()
1230            input = torch.randn(5, device=device, dtype=dtype)
1231            optimizer = optim_cls(
1232                [
1233                    dict(params=[weight, bias], **optim_input.kwargs),
1234                    dict(params=[irrelevant]),
1235                ],
1236                **outer_kwargs,
1237            )
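            # Per-group kwargs override the constructor-level defaults, so weight/bias step with a
            # real lr while `irrelevant` falls back to the near-zero outer lr (1e-28).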
1238
1239            loss = (weight.mv(input) + bias).pow(2).sum()
1240            initial_value = loss.item()
1241            for _ in range(20):
1242                optimizer.zero_grad()
1243                loss = (weight.mv(input) + bias).pow(2).sum()
1244                loss.backward()
1245                irrelevant.grad = torch.rand_like(irrelevant)
1246                if optim_info.only_supports_sparse_grads:
1247                    # For this test, we naively convert the Tensor layout, which we know does
1248                    # NOT represent the expected use case for optims like SparseAdam!
1249                    weight.grad = weight.grad.to_sparse()
1250                    bias.grad = bias.grad.to_sparse()
1251                    irrelevant.grad = irrelevant.grad.to_sparse()
1252                optimizer.step()
1253
1254            # Test that the loss moved in the appropriate direction
1255            if optim_input.kwargs.get("maximize", False):
1256                self.assertGreater(loss.item(), initial_value)
1257            else:
1258                self.assertLess(loss.item(), initial_value)
1259
1260            # Test that irrelevant parameters were not updated since lr was almost 0
1261            self.assertEqual(irrelevant, irrelevant_clone)
1262
1263    @optims(optim_db, dtypes=[torch.float32])
1264    def test_step_is_noop_when_params_have_no_grad(self, device, dtype, optim_info):
1265        optim_cls = optim_info.optim_cls
1266        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1267            device, dtype, optim_info
1268        )
1269        params = [
1270            torch.randn(2, 3, requires_grad=False, device=device, dtype=dtype)
1271            for _ in range(2)
1272        ]
1273        old_params = [p.clone().detach() for p in params]
1274
1275        def closure():
1276            return torch.tensor([1], device=device, dtype=dtype)
1277
1278        for optim_input in all_optim_inputs:
1279            optimizer = optim_cls(params, **optim_input.kwargs)
1280            optimizer.step(closure)
            # params have no grads, so stepping must leave them untouched
            self.assertEqual(old_params, params)
1281
1282    @optims(optim_db, dtypes=[torch.float32])
1283    def test_step_is_noop_for_zero_grads(self, device, dtype, optim_info):
1284        optim_cls = optim_info.optim_cls
1285        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1286            device, dtype, optim_info
1287        )
1288        param = torch.randn((5, 1), device=device, dtype=dtype, requires_grad=True)
1289        old_param = param.clone().detach()
1290
1291        def closure():
1292            return torch.tensor([1], device=device, dtype=dtype)
1293
1294        for optim_input in all_optim_inputs:
1295            kwargs = optim_input.kwargs
1296
1297            # params will still decay when weight_decay != 0 even if grads are zero/empty,
1298            # so skip those configs (capturable also doesn't work for CPU tensors)
1299            if kwargs.get("weight_decay", 0) != 0:
1300                continue
1301
1302            # AdamW's default decoupled weight_decay (1e-2) is scaled by lr, so params move even with zero grads; shrink lr to keep the change negligible
1303            if optim_cls.__name__ == "AdamW":
1304                kwargs["lr"] = (
1305                    torch.tensor(1e-5)
1306                    if isinstance(kwargs.get("lr", 1e-5), torch.Tensor)
1307                    else 1e-5
1308                )
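            # Rough arithmetic: decoupled decay scales params by (1 - lr * weight_decay)
            # = 1 - 1e-5 * 1e-2 = 1 - 1e-7 per step, which is negligible for this check.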
1309
1310            if kwargs.get("differentiable", False):
1311                params = [param.clone()]
1312            else:
1313                params = [param]
1314
1315            optimizer = optim_cls(params, **kwargs)
1316            if optim_info.only_supports_sparse_grads:
1317                # Intentionally construct a multidimensional empty v for the sparse grad
1318                # A single-dim v passes the test, while a multidim v correctly reproduces the issue:
1319                # https://github.com/pytorch/pytorch/issues/82486
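                # In COO layout, indices have shape (sparse_dims, nnz) and values have shape
                # (nnz, *dense_dims); here (1, 0) and (0, 1) give a 5x1 sparse grad with zero
                # nonzeros and one dense dim.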
1320                i = torch.empty((1, 0), device=device, dtype=dtype)
1321                v = torch.empty((0, 1), device=device, dtype=dtype)
1322                params[0].grad = torch.sparse_coo_tensor(
1323                    i, v, (5, 1), device=device, dtype=dtype
1324                )
1325            else:
1326                params[0].grad = torch.zeros_like(params[0])
1327            optimizer.step(closure)
1328            self.assertEqual(old_param, params[0])
1329
1330    @optims(optim_db, dtypes=[torch.float32])
1331    def test_optimizer_can_be_printed(self, device, dtype, optim_info):
1332        optim_cls = optim_info.optim_cls
1333        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1334            device, dtype, optim_info
1335        )
1336        params = [
1337            Parameter(torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype))
1338            for _ in range(2)
1339        ]
1340        for optim_input in all_optim_inputs:
1341            optimizer = optim_cls(params, **optim_input.kwargs)
1342            repr(optimizer)
1343
1344    @optims(optim_db, dtypes=[torch.float32])
1345    def test_state_dict_deterministic(self, device, dtype, optim_info):
1346        optim_cls = optim_info.optim_cls
1347
1348        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
1349        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1350            device, dtype, optim_info, skip=("differentiable",)
1351        )
1352        weight = Parameter(
1353            torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)
1354        )
1355        bias = Parameter(torch.randn(2, requires_grad=True, device=device, dtype=dtype))
1356        input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
1357        params = [weight, bias]
1358
1359        def fwd_bwd(optim, w, b, i):
1360            optim.zero_grad()
1361            loss = (w.mv(i) + b).pow(2).sum()
1362            loss.backward()
1363            if optim_info.only_supports_sparse_grads:
1364                if w.grad is not None:
1365                    w.grad = w.grad.to_sparse()
1366                if b.grad is not None:
1367                    b.grad = b.grad.to_sparse()
1368            return loss
1369
1370        for optim_input in all_optim_inputs:
1371            optimizer = optim_cls(params, **optim_input.kwargs)
1372            closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)
1373
1374            # Prime the optimizer
1375            for _ in range(10):
1376                if optim_info.step_requires_closure:
1377                    optimizer.step(closure)
1378                else:
1379                    closure()
1380                    optimizer.step()
1381
1382            # Clone the weights and construct a new optimizer for them
1383            with torch.no_grad():
1384                weight_c = Parameter(weight.clone())
1385                bias_c = Parameter(bias.clone())
1386
1387            optimizer_c = optim_cls([weight_c, bias_c], **optim_input.kwargs)
1388            closure_c = functools.partial(fwd_bwd, optimizer_c, weight_c, bias_c, input)
1389
1390            # Load the state dict from the original optimizer into the new one
1391            optimizer_c.load_state_dict(deepcopy(optimizer.state_dict()))
1392
1393            # Run both optimizers in parallel
1394            for _ in range(10):
1395                if optim_info.step_requires_closure:
1396                    optimizer.step(closure)
1397                    optimizer_c.step(closure_c)
1398                else:
1399                    closure()
1400                    closure_c()
1401                    optimizer.step()
1402                    optimizer_c.step()
1403
1404                self.assertEqual(weight, weight_c)
1405                self.assertEqual(bias, bias_c)
1406
1407            # Make sure state dict is deterministic with equal (not identical) parameters
1408            self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
1409
1410            # Make sure repeated parameters have identical representation (see #36831)
1411            optimizer_c.param_groups.extend(optimizer_c.param_groups)
1412            self.assertEqual(
1413                optimizer.state_dict()["param_groups"][-1],
1414                optimizer_c.state_dict()["param_groups"][-1],
1415            )
1416
1417    @optims(optim_db, dtypes=[torch.float32])
1418    def test_can_load_older_state_dict(self, device, dtype, optim_info):
1419        optim_cls = optim_info.optim_cls
1420
1421        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
1422        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1423            device, dtype, optim_info, skip=("differentiable",)
1424        )
1425        for optim_input in all_optim_inputs:
1426            torch.manual_seed(1)
1427            model = torch.nn.Sequential(
1428                torch.nn.Conv2d(4, 2, 1, stride=2),
1429                torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1),
1430            )
1431            model.to(dtype=dtype, device=device)
1432            input = torch.rand(1, 4, 16, 16, device=device, dtype=dtype)
1433            optimizer = optim_cls(model.parameters(), **optim_input.kwargs)
1434
1435            def fwd_bwd(optim, mod, i):
1436                optim.zero_grad()
1437                loss = mod(i).sum()
1438                loss.backward()
1439                return loss
1440
1441            for _ in range(3):
1442                if optim_info.step_requires_closure:
1443                    optimizer.step(functools.partial(fwd_bwd, optimizer, model, input))
1444                else:
1445                    fwd_bwd(optimizer, model, input)
1446                    optimizer.step()
1447
1448            # Mimic an older checkpoint: delete all flags the optimizer did not originally support
1449            old_state_dict = deepcopy(optimizer.state_dict())
1450            old_state_dict_pg = old_state_dict["param_groups"]
1451            for group in old_state_dict_pg:
1452                for flag in optim_info.not_og_supported_flags:
1453                    if flag in group:
1454                        del group[flag]
1455
1456            optimizer.load_state_dict(old_state_dict)
1457
1458            # Make sure we can still step
1459            if optim_info.step_requires_closure:
1460                optimizer.step(functools.partial(fwd_bwd, optimizer, model, input))
1461            else:
1462                fwd_bwd(optimizer, model, input)
1463                optimizer.step()
1464
1465    @optims(optim_db, dtypes=[torch.float32])
1466    def test_save_load_equality_with_weights_only(self, device, dtype, optim_info):
1467        optim_cls = optim_info.optim_cls
1468
1469        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
1470        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1471            device, dtype, optim_info, skip=("differentiable",)
1472        )
1473        weight = Parameter(
1474            torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)
1475        )
1476        bias = Parameter(torch.randn(2, requires_grad=True, device=device, dtype=dtype))
1477        input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
1478        params = [weight, bias]
1479
1480        def fwd_bwd(optim, w, b, i):
1481            optim.zero_grad()
1482            loss = (w.mv(i) + b).pow(2).sum()
1483            loss.backward()
1484            if optim_info.only_supports_sparse_grads:
1485                weight.grad = weight.grad.to_sparse()
1486                bias.grad = bias.grad.to_sparse()
1487            return loss
1488
1489        for optim_input in all_optim_inputs:
1490            optimizer = optim_cls(params, **optim_input.kwargs)
1491            closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)
1492
1493            # Prime the optimizer
1494            for _ in range(3):
1495                optimizer.step(closure)
1496
1497            sd = optimizer.state_dict()
1498
1499            # === Check saved/loaded state_dict are the same (including weights_only load). ===
1500            with tempfile.TemporaryFile() as f:
1501                torch.save(sd, f)
1502                f.seek(0)
1503                sd_copy = torch.load(f)
1504                self.assertEqual(sd_copy, sd)
1505                del sd_copy
1506                f.seek(0)
1507                sd_copy_wo = torch.load(f, weights_only=True)
1508                self.assertEqual(sd_copy_wo, sd)
1509
1510    @optims(optim_db, dtypes=[torch.float32])
1511    def test_load_nontensor_step(self, device, dtype, optim_info):
1512        optim_cls = optim_info.optim_cls
1513
1514        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
1515        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1516            device, dtype, optim_info, skip=("differentiable",)
1517        )
1518        params = [
1519            Parameter(torch.randn(2, 3, device=device, dtype=dtype)) for _ in range(2)
1520        ]
1521        for p in params:
1522            p.grad = torch.rand_like(p)
1523            if optim_info.only_supports_sparse_grads:
1524                # For this test, we naively convert the Tensor layout, which we know does
1525                # NOT represent the expected use case for optims like SparseAdam!
1526                p.grad = p.grad.to_sparse()
1527
1528        # Needed for second order optims like LBFGS
1529        closure_loss = torch.rand(1, device=device, dtype=dtype)
1530
1531        def closure():
1532            return closure_loss if optim_info.step_requires_closure else None
1533
1534        for optim_input in all_optim_inputs:
1535            kwargs = optim_input.kwargs
1536            optimizer = optim_cls(params, **kwargs)
1537            for _ in range(3):
1538                optimizer.step(closure)
1539            state_dict = deepcopy(optimizer.state_dict())
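            # Simulate an older checkpoint where "step" was saved as a plain Python number
            # rather than a tensor; load_state_dict should accept either form.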
1540            for p_state in state_dict["state"].values():
1541                if "step" in p_state and torch.is_tensor(p_state["step"]):
1542                    p_state["step"] = p_state["step"].item()
1543            optimizer.load_state_dict(state_dict)
1544            optimizer.step(closure)
1545
1546    @onlyCUDA
1547    @optims(optim_db, dtypes=[torch.float32])
1548    def test_state_dict_with_cuda_params(self, device, dtype, optim_info):
1549        optim_cls = optim_info.optim_cls
1550
1551        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
1552        # We limit our configs to CPU only, because we will be moving them to CUDA later
1553        cpu_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1554            "cpu", dtype, optim_info, skip=("differentiable",)
1555        )
1556
1557        # Needed for second order optims like LBFGS
1558        closure_loss = torch.rand(1, device=device, dtype=dtype)
1559
1560        def closure():
1561            return closure_loss if optim_info.step_requires_closure else None
1562
1563        for optim_input in cpu_optim_inputs:
1564            if (
1565                "fused" in optim_input.kwargs
1566                and "cuda" not in optim_info.supports_fused_on
1567            ):
1568                self.skipTest(
1569                    f"cuda is not supported for fused on {optim_cls.__name__}"
1570                )
1571            params = [
1572                Parameter(torch.randn(2, 3, device="cpu", dtype=dtype))
1573                for _ in range(2)
1574            ]
1575            for p in params:
1576                p.grad = torch.randn_like(p)
1577                if optim_info.only_supports_sparse_grads:
1578                    # For this test, we naively convert the Tensor layout, which we know does
1579                    # NOT represent the expected use case for optims like SparseAdam!
1580                    p.grad = p.grad.to_sparse()
1581
1582            optimizer = optim_cls(params, **optim_input.kwargs)
1583
1584            for _ in range(3):
1585                optimizer.step(closure)
1586
1587            with torch.no_grad():
1588                params_cuda = [p.to(device="cuda") for p in params]
1589                for i, p in enumerate(params_cuda):
1590                    p.grad = params[i].grad.to(device="cuda")
1591            optimizer_cuda = optim_cls(params_cuda, **optim_input.kwargs)
1592
1593            state_dict_cpu = deepcopy(optimizer.state_dict())
1594            state_dict_cuda = deepcopy(optimizer.state_dict())
1595            optimizer_cuda.load_state_dict(state_dict_cuda)
1596
1597            # Make sure state_dict_cuda isn't modified by merely calling load_state_dict
1598            self.assertEqual(state_dict_cpu, state_dict_cuda)
1599
1600            # Make sure the device of state['step'] is still CPU _unless_ capturable or fused is set (e.g. by torch.compile())
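            # Roughly: capturable keeps `step` on the device so the update can be CUDA-graph
            # captured without host syncs, and fused kernels likewise expect `step` on device;
            # otherwise `step` remains a cheap CPU tensor.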
1601            capturable = state_dict_cpu["param_groups"][0].get("capturable", False)
1602            fused = state_dict_cpu["param_groups"][0].get("fused", False)
1603            new_state_dict = optimizer_cuda.state_dict()
1604            for state_cpu, state_cuda in zip(
1605                state_dict_cpu["state"].values(), new_state_dict["state"].values()
1606            ):
1607                if "step" in state_cpu and torch.is_tensor(state_cpu["step"]):
1608                    self.assertEqual(
1609                        state_cuda["step"].device.type,
1610                        "cuda" if capturable or fused else "cpu",
1611                    )
1612
1613            for _ in range(5):
1614                optimizer.step(closure)
1615                optimizer_cuda.step(closure)
1616                self.assertEqual(params, params_cuda)
1617                self.assertEqual(optimizer.state_dict(), optimizer_cuda.state_dict())
1618
1619    @staticmethod
1620    def _state_dict_pre_hook(optimizer: Optimizer) -> None:
1621        optimizer.state["test"] = 1
1622
1623    @staticmethod
1624    def _state_dict_post_hook(
1625        optimizer: Optimizer, state_dict: Dict[str, Any]
1626    ) -> Dict[str, Any]:
1627        if "test" in state_dict["state"]:
1628            state_dict["state"].pop("test")
1629            state_dict["ran_state_dict_pre_hook"] = True
1630        else:
1631            state_dict["ran_state_dict_pre_hook"] = False
1632        return state_dict
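    # Note on semantics: a state_dict pre-hook runs before state_dict() is assembled (here it
    # injects optimizer.state["test"]), while a post-hook receives the assembled dict and may
    # return a modified one, which is then what state_dict() returns.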
1633
1634    @optims(optim_db, dtypes=[torch.float32])
1635    def test_state_dict_pre_hook(self, device, dtype, optim_info):
1636        optim_cls = optim_info.optim_cls
1637        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1638            device, dtype, optim_info
1639        )
1640        for optim_input in all_optim_inputs:
1641            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
1642            optim = optim_cls([param], **optim_input.kwargs)
1643            optim.register_state_dict_pre_hook(self.__class__._state_dict_pre_hook)
1644            state_dict = optim.state_dict()
1645            self.assertEqual(state_dict["state"]["test"], 1)
1646
1647    @optims(optim_db, dtypes=[torch.float32])
1648    def test_state_dict_post_hook(self, device, dtype, optim_info):
1649        optim_cls = optim_info.optim_cls
1650        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1651            device, dtype, optim_info
1652        )
1653        for optim_input in all_optim_inputs:
1654            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
1655            optim = optim_cls([param], **optim_input.kwargs)
1656            optim.register_state_dict_post_hook(self.__class__._state_dict_post_hook)
1657            state_dict = optim.state_dict()
1658            self.assertFalse(state_dict["ran_state_dict_pre_hook"])
1659
1660    @optims(optim_db, dtypes=[torch.float32])
1661    def test_state_dict_pre_post_hook(self, device, dtype, optim_info):
1662        optim_cls = optim_info.optim_cls
1663        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1664            device, dtype, optim_info
1665        )
1666        for optim_input in all_optim_inputs:
1667            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
1668            optim = optim_cls([param], **optim_input.kwargs)
1669            optim.register_state_dict_pre_hook(self.__class__._state_dict_pre_hook)
1670            optim.register_state_dict_post_hook(self.__class__._state_dict_post_hook)
1671            state_dict = optim.state_dict()
1672            self.assertFalse("test" in state_dict["state"])
1673            self.assertTrue(state_dict["ran_state_dict_pre_hook"])
1674
1675    @staticmethod
1676    def _load_state_dict_pre_hook1(
1677        optimizer: Optimizer, state_dict: Dict[str, Any]
1678    ) -> None:
1679        state_dict["param_groups"][0]["lr"] = 0.002
1680
1681    @staticmethod
1682    def _load_state_dict_pre_hook2(
1683        optimizer: Optimizer, state_dict: Dict[str, Any]
1684    ) -> Dict[str, Any]:
1685        # The typical use case for returning a state_dict from a pre-hook is to heavily modify it.
1686        # We simulate that by making a deep copy and verifying that my_state_dict is the one actually used.
1687        my_state_dict = deepcopy(state_dict)
1688        my_state_dict["param_groups"][0]["lr"] = 0.003
1689        return my_state_dict
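    # Note: a load_state_dict pre-hook may either mutate state_dict in place and return None
    # (hook1 above) or return a new dict, in which case the returned dict is the one loaded
    # (hook2 above).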
1690
1691    @staticmethod
1692    def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
1693        optimizer.state["ran_load_state_dict_pre_hook2"] = (
1694            optimizer.param_groups[0]["lr"] == 0.003
1695        )
1696        optimizer.state["ran_load_state_dict_post_hook"] = True
1697
1698    @optims(optim_db, dtypes=[torch.float32])
1699    def test_load_state_dict_pre_hook_and_prepend(self, device, dtype, optim_info):
1700        optim_cls = optim_info.optim_cls
1701        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1702            device, dtype, optim_info
1703        )
1704        for optim_input in all_optim_inputs:
1705            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
1706            optim = optim_cls([param], **optim_input.kwargs)
1707            state_dict = optim.state_dict()
1708
1709            # Usually one would load into a fresh optimizer instance, but reusing the same one suffices for this test
1710            optim.register_load_state_dict_pre_hook(
1711                self.__class__._load_state_dict_pre_hook1
1712            )
1713            optim.load_state_dict(state_dict)
1714            self.assertEqual(optim.param_groups[0]["lr"], 0.002)
1715
1716            optim.register_load_state_dict_pre_hook(
1717                self.__class__._load_state_dict_pre_hook2, prepend=True
1718            )
1719            optim.load_state_dict(state_dict)
1720            # If prepend were False, lr would be 0.003; since prepend is True, hook2 runs first and hook1 then overrides lr back to 0.002
1721            self.assertEqual(optim.param_groups[0]["lr"], 0.002)
1722
1723    @optims(optim_db, dtypes=[torch.float32])
1724    def test_load_state_dict_post_hook(self, device, dtype, optim_info):
1725        optim_cls = optim_info.optim_cls
1726        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1727            device, dtype, optim_info
1728        )
1729        for optim_input in all_optim_inputs:
1730            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
1731            optim = optim_cls([param], **optim_input.kwargs)
1732
1733            optim.register_load_state_dict_post_hook(
1734                self.__class__._load_state_dict_post_hook
1735            )
1736            optim.load_state_dict(optim.state_dict())
1737            self.assertFalse(optim.state["ran_load_state_dict_pre_hook2"])
1738            self.assertTrue(optim.state["ran_load_state_dict_post_hook"])
1739
1740    @optims(optim_db, dtypes=[torch.float32])
1741    def test_load_state_dict_pre_post_hook(self, device, dtype, optim_info):
1742        optim_cls = optim_info.optim_cls
1743        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1744            device, dtype, optim_info
1745        )
1746        for optim_input in all_optim_inputs:
1747            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
1748            optim = optim_cls([param], **optim_input.kwargs)
1749
1750            optim.register_load_state_dict_pre_hook(
1751                self.__class__._load_state_dict_pre_hook2
1752            )
1753            optim.register_load_state_dict_post_hook(
1754                self.__class__._load_state_dict_post_hook
1755            )
1756            optim.load_state_dict(optim.state_dict())
1757            self.assertTrue(optim.state["ran_load_state_dict_pre_hook2"])
1758            self.assertTrue(optim.state["ran_load_state_dict_post_hook"])
1759
1760    @optims(optim_db, dtypes=[torch.float32])
1761    def test_step_post_hook(self, device, dtype, optim_info):
1762        def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
1763            nonlocal data
1764            data += 2
1765
1766        params = [torch.tensor([1, 1], device=device, dtype=dtype)]
1767
1768        def dummy_closure():
1769            return 1
1770
1771        closure = dummy_closure if optim_info.step_requires_closure else None
1772
1773        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1774            device, dtype, optim_info
1775        )
1776        for optim_input in all_optim_inputs:
1777            optim = optim_info.optim_cls(params, **optim_input.kwargs)
1778            data = 2
1779            hook_handle = optim.register_step_post_hook(post_hook)
1780
1781            optim.step(closure)
1782            optim.step(closure)
1783            # verify the post hook ran on both steps (data: 2 -> 4 -> 6)
1784            self.assertEqual(data, 6)
1785
1786            # remove the handle, take a step, and verify the hook no longer runs
1787            hook_handle.remove()
1788
1789            optim.step(closure)
1790            self.assertEqual(data, 6)
1791
1792    @optims(optim_db, dtypes=[torch.float32])
1793    def test_step_pre_hook(self, device, dtype, optim_info):
1794        def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
1795            nonlocal data
1796            data += 2
1797
1798        params = [torch.tensor([1, 1], device=device, dtype=dtype)]
1799
1800        def dummy_closure():
1801            return 1
1802
1803        closure = dummy_closure if optim_info.step_requires_closure else None
1804
1805        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1806            device, dtype, optim_info
1807        )
1808        for optim_input in all_optim_inputs:
1809            optim = optim_info.optim_cls(params, **optim_input.kwargs)
1810            data = 5
1811            hook_handle = optim.register_step_pre_hook(pre_hook)
1812
1813            optim.step(closure)
1814            optim.step(closure)
1815            # verify the pre hook ran on both steps (data: 5 -> 7 -> 9)
1816            self.assertEqual(data, 9)
1817
1818            # remove the handle, take a step, and verify the hook no longer runs
1819            hook_handle.remove()
1820
1821            optim.step(closure)
1822            self.assertEqual(data, 9)
1823
1824    @optims(optim_db, dtypes=[torch.float32])
1825    def test_step_all_hooks(self, device, dtype, optim_info):
1826        def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
1827            nonlocal data
1828            data.append(0)
1829
1830        def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
1831            nonlocal data
1832            data.append(5)
1833
1834        def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
1835            nonlocal data
1836            data.append(1)
1837
1838        def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
1839            nonlocal data
1840            data.append(2)
1841
1842        params = [torch.tensor([1, 1], device=device, dtype=dtype)]
1843
1844        def dummy_closure():
1845            return 1
1846
1847        closure = dummy_closure if optim_info.step_requires_closure else None
1848
1849        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1850            device, dtype, optim_info
1851        )
1852        for optim_input in all_optim_inputs:
1853            optim = optim_info.optim_cls(params, **optim_input.kwargs)
1854            optim2 = SGD(params)
1855            data = []
1856
1857            # register global hooks, which apply to every optimizer (both optim and optim2)
1858            global_pre_handle = register_optimizer_step_pre_hook(global_pre_hook)
1859            global_post_handle = register_optimizer_step_post_hook(global_post_hook)
1860
1861            # register local hooks
1862            first_pre_handle = optim.register_step_pre_hook(local_pre_hook)
1863            first_post_handle = optim.register_step_post_hook(local_post_hook)
1864            second_pre_handle = optim2.register_step_pre_hook(local_pre_hook)
1865            second_post_handle = optim2.register_step_post_hook(local_post_hook)
1866
1867            optim.step(closure)
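            # Expected firing order per step: global pre hook (0), local pre hook (1),
            # then after the step the local post hook (2) and global post hook (5).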
1868            self.assertListEqual(data, [0, 1, 2, 5])
1869            optim2.step(closure)
1870            self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5])
1871            optim.step(closure)
1872            self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
1873
1874            # remove all hooks
1875            global_pre_handle.remove()
1876            global_post_handle.remove()
1877            first_pre_handle.remove()
1878            first_post_handle.remove()
1879            second_pre_handle.remove()
1880            second_post_handle.remove()
1881
1882            optim.step(closure)
1883            optim2.step(closure)
1884            self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
1885
1886    @optims(optim_db, dtypes=[torch.float32])
1887    def test_deepcopy_copies_all_public_attrs(self, device, dtype, optim_info):
1888        optim_cls = optim_info.optim_cls
1889
1890        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
1891        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1892            device, dtype, optim_info, skip=("differentiable",)
1893        )
1894
1895        params = [
1896            Parameter(torch.randn(2, 3, device=device, dtype=dtype)) for _ in range(2)
1897        ]
1898        for p in params:
1899            p.grad = torch.rand_like(p)
1900            if optim_info.only_supports_sparse_grads:
1901                # For this test, we naively convert the Tensor layout, which we know does
1902                # NOT represent the expected use case for optims like SparseAdam!
1903                p.grad = p.grad.to_sparse()
1904
1905        # Needed for second order optims like LBFGS
1906        def closure():
1907            return 1 if optim_info.step_requires_closure else None
1908
1909        def getPublicAttrs(obj):
1910            return {k for k in obj.__dict__ if not k.startswith("_")}
1911
1912        for optim_input in all_optim_inputs:
1913            optimizer = optim_cls(params, **optim_input.kwargs)
1914
1915            # Make some state
1916            for _ in range(3):
1917                if optim_info.step_requires_closure:
1918                    optimizer.step(closure)
1919                else:
1920                    closure()
1921                    optimizer.step()
1922
1923            self.assertEqual(
1924                getPublicAttrs(optimizer), getPublicAttrs(deepcopy(optimizer))
1925            )
1926
1927    @optims(
1928        [optim for optim in optim_db if optim.step_requires_closure],
1929        dtypes=[torch.float32],
1930    )
1931    def test_second_order_optims_return_consistent_types(
1932        self, device, dtype, optim_info
1933    ):
1934        # Motivated by #7586
1935        optim_cls = optim_info.optim_cls
1936        params = [
1937            torch.randn(10, 5, device=device, dtype=dtype),
1938            torch.randn(10, device=device, dtype=dtype),
1939        ]
1940
1941        def closure():
1942            return torch.tensor([10], device=device, dtype=dtype)
1943
1944        for optim_input in optim_info.optim_inputs_func(device=device):
1945            # Currently, the only second order optim is LBFGS, so we just go ahead and modify
1946            # "tolerance_grad", but this may not scale if we add second order optims in the future
1947            kwargs = optim_input.kwargs
1948            kwargs["tolerance_grad"] = math.inf
1949            optim_inf = optim_cls(params, **kwargs)
1950            kwargs["tolerance_grad"] = -math.inf
1951            optim_neg_inf = optim_cls(params, **kwargs)
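            # tolerance_grad=inf makes LBFGS hit its early-exit convergence check immediately,
            # while -inf forces the full optimization loop; both code paths should return the
            # same type (the closure's loss).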
1952
1953            res1 = optim_inf.step(closure)
1954            res2 = optim_neg_inf.step(closure)
1955            self.assertEqual(type(res1), type(res2))
1956
1957    @onlyCUDA
1958    @optims(
1959        [
1960            optim
1961            for optim in optim_db
1962            if "cpu" in optim.supports_fused_on and "cuda" in optim.supports_fused_on
1963        ],
1964        dtypes=floating_types_and(
1965            torch.bfloat16,
1966            torch.float16,
1967        ),
1968    )
1969    def test_fused_cpu_matches_cuda(self, device, dtype, optim_info):
1970        optim_cls = optim_info.optim_cls
1971        optim_inputs = optim_info.optim_inputs_func(device="cpu")
1972        for optim_input in optim_inputs:
1973            inpts, models, optimizers = [], [], []
1974            for dev in ("cpu", "cuda"):
1975                kwargs = optim_input.kwargs
1976                kwargs["fused"] = True
1977                inpt = torch.tensor(
1978                    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=dev
1979                ).reshape(3, 2)
1980
1981                torch.manual_seed(1)
1982                model = torch.nn.Sequential(
1983                    torch.nn.Linear(2, 3),
1984                    torch.nn.Sigmoid(),
1985                    torch.nn.Linear(3, 1),
1986                    torch.nn.Sigmoid(),
1987                )
1988                model.to(dtype=dtype, device=dev)
1989
1990                # foreach/fused optimizers should be tested with a
1991                # zero_size tensor as its last param.
1992                # ref: https://github.com/pytorch/pytorch/issues/100701
1993                empty_param = torch.empty(
1994                    (), device=dev, dtype=dtype, requires_grad=True
1995                )
1996                empty_param.grad = torch.rand_like(empty_param)
1997                params = list(model.parameters()) + [empty_param]
1998
1999                optimizer = optim_cls(params, **kwargs)
2000                inpts.append(inpt)
2001                models.append(model)
2002                optimizers.append(optimizer)
2003            self._compare_between(inpts, models, optimizers)
2004
2005    @onlyCUDA
2006    @optims(
2007        [
2008            o
2009            for o in optim_db
2010            if ("foreach" in o.supported_impls and o.optim_cls.__name__ != "Adafactor")
2011        ],
2012        dtypes=[torch.float32],
2013    )
2014    def test_defaults_changed_to_foreach(self, device, dtype, optim_info):
2015        # Test that optimizers default to the foreach implementation when available,
2016        # except Adafactor, which defaults to the single tensor impl for memory efficiency.
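        # The check works by patching the module-level _multi_tensor_<name> helper (e.g.
        # torch.optim.adam._multi_tensor_adam) and asserting that step() calls it when
        # no explicit foreach/fused flag is passed.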
2017        optim_cls = optim_info.optim_cls
2018        model = torch.nn.Linear(5, 5)
2019        model.to(dtype=dtype, device=device)
2020        inpt = torch.rand(2, 5, dtype=dtype, device=device)
2021
2022        import inspect
2023
2024        module = inspect.getmodule(optim_cls)
2025
2026        for optim_input in optim_info.optim_inputs_func(device=device):
2027            optim = optim_cls(model.parameters(), **optim_input.kwargs)
2028            optim.zero_grad()
2029            output = model(inpt)
2030            loss = output.sum()
2031            loss.backward()
2032            with patch.object(
2033                module, f"_multi_tensor_{optim_cls.__name__.lower()}"
2034            ) as mocked_foreach_impl:
2035                optim.step()
2036                self.assertTrue(mocked_foreach_impl.called)
2037
2038    @optims(optim_db, dtypes=[torch.float32])
2039    def test_non_empty_state(self, device, dtype, optim_info):
2040        # There are internal tests that check that the state is not empty
2041        optim_cls = optim_info.optim_cls
2042        model = torch.nn.Linear(5, 5)
2043        model.to(dtype=dtype, device=device)
2044        inpt = torch.rand(2, 5, dtype=dtype, device=device)
2045
2046        for optim_input in optim_info.optim_inputs_func(device=device):
2047            optim = optim_cls(model.parameters(), **optim_input.kwargs)
2048            optim.zero_grad()
2049            output = model(inpt)
2050            loss = output.sum()
2051            loss.backward()
2052
2053            if optim_info.only_supports_sparse_grads:
2054                for param in model.parameters():
2055                    if param.grad is not None:
2056                        param.grad = param.grad.to_sparse()
2057
2058            if optim_info.step_requires_closure:
2059                optim.step(lambda: 1.0)
2060            else:
2061                optim.step()
2062
2063            for state in optim.state.values():
2064                self.assertGreater(len(state), 0)
2065
2066
2067instantiate_device_type_tests(TestOptimRenewed, globals(), allow_mps=True)
2068
2069
2070if __name__ == "__main__":
2071    run_tests()
2072