# Owner(s): ["module: optimizer"]
import functools
import math
import tempfile
import unittest
from copy import deepcopy
from typing import Any, Dict, Tuple
from unittest.mock import patch

from optim.test_lrscheduler import TestLRScheduler  # noqa: F401
from optim.test_optim import TestDifferentiableOptimizer  # noqa: F401
from optim.test_swa_utils import TestSWAUtils  # noqa: F401

import torch
from torch.nn import Parameter
from torch.optim import Optimizer, SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.optimizer import (
    register_optimizer_step_post_hook,
    register_optimizer_step_pre_hook,
)
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    largeTensorTest,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
    skipMPS,
    TEST_WITH_ROCM,
)
from torch.testing._internal.common_dtype import floating_types_and
from torch.testing._internal.common_optimizers import (
    _get_device_type,
    _get_optim_inputs_including_global_cliquey_kwargs,
    optim_db,
    OptimizerErrorEnum,
    optims,
    TensorTracker,
)
from torch.testing._internal.common_utils import (
    markDynamoStrictTest,
    parametrize,
    run_tests,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
)


FP16_REDUCED_PRECISION = {"atol": 1e-5, "rtol": 1e-4}


def rosenbrock(tensor):
    assert tensor.size() == torch.Size(
        [2]
    ), f"Requires tensor with 2 scalars but got {tensor.size()}"
    x, y = tensor
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


def drosenbrock(tensor):
    assert tensor.size() == torch.Size(
        [2]
    ), f"Requires tensor with 2 scalars but got {tensor.size()}"
    x, y = tensor
    return torch.stack((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2)))


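# A minimal sanity sketch added for illustration (not part of the original test suite):
# drosenbrock is intended to be the analytic gradient of rosenbrock, so it should agree
# with autograd at any 2-element point. The helper name below is hypothetical.
def _drosenbrock_matches_autograd_sketch():
    point = torch.tensor([1.5, 1.5], requires_grad=True)
    rosenbrock(point).backward()
    torch.testing.assert_close(point.grad, drosenbrock(point.detach()))

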
@markDynamoStrictTest
class TestOptimRenewed(TestCase):
    """
    This test class validates the core optimizers and is structured around validating the
    correctness of:
    - The update algorithms (forloop implementation)
        * Every optimizer's algorithm is most readably implemented through a big for-loop
          over all the parameters, which is what we refer to as the forloop or single tensor
          implementation. These algorithms are manually validated by comparing to the paper
          and systematically validated by ensuring that the loss goes the right direction
          when the optimizer has been applied.
        * This implementation should compose well with optimizer hyperparameters, such as
          supporting Tensor LRs, the capturable API, and sparse and complex parameters.
    - Each varying implementation
        * We then have implementations that improve upon the performance of the forloop
          implementation by leveraging fusion, namely our foreach (multi_tensor) and fused
          implementations.
        * These variations are validated numerically by comparing with the forloop version
          of the optimizer. In fact, we test most variations this way: we treat the forloop
          implementation as the ground truth and expect that improvements to it in any way
          should be just as correct.
        * Both params and optimizer states should be validated numerically.
    - state_dict APIs
        * The optimizer instance should be serializable
        * Calling save and load should be deterministic
        * Moving between devices should be seamless
        * BC - load_state_dict should be able to handle older optimizer states
    - Hook APIs (everything should fire in the right order)
    - LR Scheduler integration (composing should not error + should go the right direction)
    - Parameter groups (should be equivalent to having multiple optimizers)
    - Erroring (what should error should error)

    We also cover different ways of generating parameters and grads:
    - With parameters, we either generate them randomly given specific shapes or we take
      them from a sample NN module.
        * Variety is important here because NN modules have type Parameter and randomly
          generated tensors have type Tensor.
        * Parameters can be sparse for a subset of the optimizers (check out OptimizerInfo)
        * Complex parameters should be handled using view_as_real
        * Parameters can be spread across different devices and different dtypes for any
          given optimizer
        * Parameters can be contiguous and noncontiguous
    - With grads, we follow suit from the parameters.
        * Grads can also be None, empty, or zero-valued, and this should not disrupt training.
    """

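    # Illustrative sketch added for this write-up (not part of the original suite): the
    # forloop-vs-foreach comparison pattern described in the docstring above, shown with
    # SGD as an assumed example. The helper name is hypothetical and is not collected as
    # a test.
    def _example_forloop_matches_foreach_sgd(self):
        params_forloop = [Parameter(torch.ones(2, 2)) for _ in range(3)]
        params_foreach = [p.clone().detach().requires_grad_(True) for p in params_forloop]
        for p, q in zip(params_forloop, params_foreach):
            p.grad = torch.full_like(p, 0.5)
            q.grad = torch.full_like(q, 0.5)
        SGD(params_forloop, lr=0.1, foreach=False).step()  # single tensor (ground truth)
        SGD(params_foreach, lr=0.1, foreach=True).step()  # multi tensor
        self.assertEqual(params_forloop, params_foreach)
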
    @onlyCPU
    @optims(optim_db)
    def test_optim_infos_do_not_specify_global_cliquey_kwargs(
        self, device, dtype, optim_info
    ):
        global_cliquey_flags = ["foreach", "fused", "differentiable"]
        for optim_input in optim_info.optim_inputs_func(device=device):
            self.assertFalse(
                any(f for f in global_cliquey_flags if f in optim_input.kwargs)
            )

    @optims([optim for optim in optim_db if optim.optim_error_inputs_func is not None])
    def test_errors(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        error_inputs = optim_info.optim_error_inputs_func(device=device, dtype=dtype)

        for error_input in error_inputs:
            optim_input = error_input.optimizer_error_input
            params, kwargs = optim_input.params, optim_input.kwargs
            if error_input.error_on == OptimizerErrorEnum.CONSTRUCTION_ERROR:
                if issubclass(error_input.error_type, Warning):
                    with self.assertWarnsRegex(
                        error_input.error_type, error_input.error_regex
                    ):
                        optim_cls(params, **kwargs)
                else:
                    with self.assertRaisesRegex(
                        error_input.error_type, error_input.error_regex
                    ):
                        optim_cls(params, **kwargs)
            elif error_input.error_on == OptimizerErrorEnum.STEP_ERROR:
                optim = optim_cls(params, **kwargs)
                if issubclass(error_input.error_type, Warning):
                    with self.assertWarnsRegex(
                        error_input.error_type, error_input.error_regex
                    ):
                        optim.step()
                else:
                    with self.assertRaisesRegex(
                        error_input.error_type, error_input.error_regex
                    ):
                        optim.step()
            else:
                raise NotImplementedError(f"Unknown error type {error_input.error_on}")

    @parametrize("contiguous", [True, False])
    @parametrize("with_lrsched", [True, False])
    @optims(optim_db, dtypes=[torch.float32])
    def test_forloop_goes_right_direction(
        self, device, dtype, optim_info, contiguous, with_lrsched
    ):
        optim_cls = optim_info.optim_cls
        schedulers_constructors = (
            optim_info.scheduler_inputs if with_lrsched else [None]
        )

        for schedulers_constructor in schedulers_constructors:
            # With a tensor LR we need fresh inputs for each scheduler,
            # or mutations will carry across iterations.
            optim_inputs = optim_info.optim_inputs_func(device=device)
            for optim_input in optim_inputs:
                if "foreach" in optim_info.supported_impls:
                    optim_input.kwargs["foreach"] = False  # force forloop
                if contiguous:
                    weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
                    bias = Parameter(torch.randn((10), device=device, dtype=dtype))
                else:
                    weight = Parameter(
                        torch.randn((10, 5, 2), device=device, dtype=dtype)[..., 0]
                    )
                    bias = Parameter(
                        torch.randn((10, 2), device=device, dtype=dtype)[..., 0]
                    )
                input = torch.randn(5, device=device, dtype=dtype)

                optimizer = optim_cls([weight, bias], **optim_input.kwargs)
                schedulers = [
                    s(optimizer)
                    for s in (schedulers_constructor if schedulers_constructor else [])
                ]

                def closure():
                    optimizer.zero_grad()
                    loss = (weight.mv(input) + bias).pow(2).sum()
                    loss.backward()
                    if optim_info.only_supports_sparse_grads:
                        # For this test, we naively convert the Tensor layout, which we know does
                        # NOT represent the expected use case for optims like SparseAdam!
                        weight.grad = weight.grad.to_sparse()
                        bias.grad = bias.grad.to_sparse()
                    return loss

                initial_value = closure().item()
                for _ in range(20):
                    if optim_info.step_requires_closure:
                        loss = optimizer.step(closure)
                    else:
                        loss = closure()
                        optimizer.step()

                    for scheduler in schedulers:
                        if isinstance(scheduler, ReduceLROnPlateau):
                            scheduler.step(loss)
                        else:
                            scheduler.step()

                if optim_input.kwargs.get("maximize", False):
                    self.assertGreater(closure().item(), initial_value)
                else:
                    self.assertLess(closure().item(), initial_value)

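    # Hedged sketch added for illustration (not in the original file): maximize=True makes
    # the optimizer ascend the objective, which is why the assertion above flips from
    # assertLess to assertGreater. SGD is an assumed example; the helper name is hypothetical.
    def _example_maximize_flips_direction(self):
        p_min = Parameter(torch.tensor([1.0]))
        p_max = Parameter(torch.tensor([1.0]))
        p_min.grad = torch.tensor([1.0])
        p_max.grad = torch.tensor([1.0])
        SGD([p_min], lr=0.1, maximize=False).step()
        SGD([p_max], lr=0.1, maximize=True).step()
        self.assertLess(p_min.item(), 1.0)  # descends: 1.0 - 0.1 * 1.0
        self.assertGreater(p_max.item(), 1.0)  # ascends: 1.0 + 0.1 * 1.0
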
    @onlyCUDA
    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    @parametrize("with_lrsched", [True, False])
    @optims(optim_db, dtypes=[torch.float32])
    def test_forloop_goes_right_direction_multigpu(
        self, device, dtype, optim_info, with_lrsched
    ):
        optim_cls = optim_info.optim_cls
        schedulers_constructors = (
            optim_info.scheduler_inputs if with_lrsched else [None]
        )
        for schedulers_constructor in schedulers_constructors:
            # We need a fresh set of inputs if we have a tensor LR
            # to not carry mutations across iterations.
            optim_inputs = optim_info.optim_inputs_func(device=device)
            for optim_input in optim_inputs:
                if "foreach" in optim_info.supported_impls:
                    optim_input.kwargs["foreach"] = False  # force forloop

                weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))
                bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype))
                inpt = torch.randn(5, device="cuda:0", dtype=dtype)

                optimizer = optim_cls([weight, bias], **optim_input.kwargs)
                schedulers = [
                    s(optimizer)
                    for s in (schedulers_constructor if schedulers_constructor else [])
                ]

                def closure():
                    optimizer.zero_grad()
                    loss = (weight.mv(inpt).cuda(1) + bias).pow(2).sum()
                    loss.backward()
                    if optim_info.only_supports_sparse_grads:
                        # For this test, we naively convert the Tensor layout, which we know does
                        # NOT represent the expected use case for optims like SparseAdam!
                        weight.grad = weight.grad.to_sparse()
                        bias.grad = bias.grad.to_sparse()
                    return loss

                initial_value = closure().item()
                for _ in range(20):
                    loss = optimizer.step(closure)
                    for scheduler in schedulers:
                        if isinstance(scheduler, ReduceLROnPlateau):
                            scheduler.step(loss)
                        else:
                            scheduler.step()

                if optim_input.kwargs.get("maximize", False):
                    self.assertGreater(closure().item(), initial_value)
                else:
                    self.assertLess(closure().item(), initial_value)

    @optims(optim_db, dtypes=[torch.float32])
    def test_param_group_with_lrscheduler_goes_right_direction(
        self, device, dtype, optim_info
    ):
        optim_cls = optim_info.optim_cls

        for schedulers_c in optim_info.scheduler_inputs:
            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
            inpt = torch.randn(5, device=device, dtype=dtype)

            # avoid endless recompiles by wrapping LR in a tensor if we're compiling
            lr = torch.tensor(0.01) if torch._utils.is_compiling() else 0.01
            optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": lr}])
            schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c]

            def closure():
                optimizer.zero_grad()
                loss = (weight.mv(inpt) + bias).pow(2).sum()
                loss.backward()
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    weight.grad = weight.grad.to_sparse()
                    bias.grad = bias.grad.to_sparse()
                return loss

            initial_value = closure().item()
            for _ in range(20):
                loss = optimizer.step(closure)
                for scheduler in schedulers:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(loss)
                    else:
                        scheduler.step()

            self.assertLess(closure().item(), initial_value)

    @optims(optim_db, dtypes=[torch.float32])
    def test_tensor_lr(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        for optim_input in all_optim_inputs:
            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
            weight_c = weight.clone().detach().requires_grad_(True)
            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
            bias_c = bias.clone().detach().requires_grad_(True)
            inpt = torch.randn(5, device=device, dtype=dtype)

            kwargs = optim_input.kwargs
            if "lr" in kwargs:
                del kwargs["lr"]

            kwargs["lr"] = 1.0 if optim_info.step_requires_closure else 1e-3
            optimizer_r = optim_cls([weight, bias], **kwargs)

            try:
                kwargs["lr"] = torch.tensor(kwargs["lr"])
                optimizer = optim_cls([weight_c, bias_c], **kwargs)
            except ValueError as e:
                self.assertRegex(str(e), ".*lr as a Tensor is not supported.*")
                continue

            def closure(optim, w, b, i):
                optim.zero_grad()
                loss = (w.mv(i) + b).pow(2).sum()
                loss.backward()
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    w.grad = w.grad.to_sparse()
                    b.grad = b.grad.to_sparse()
                return loss

            for _ in range(5):
                if optim_info.step_requires_closure:
                    optimizer_r.step(
                        functools.partial(closure, optimizer_r, weight, bias, inpt)
                    )
                    optimizer.step(
                        functools.partial(closure, optimizer, weight_c, bias_c, inpt)
                    )
                else:
                    closure(optimizer_r, weight, bias, inpt)
                    closure(optimizer, weight_c, bias_c, inpt)

                self.assertEqual(weight, weight_c)
                self.assertEqual(bias, bias_c)

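    # Hedged sketch added for illustration (not in the original file): the invariant that
    # test_tensor_lr checks, namely that a Tensor lr produces the same update as the
    # equivalent float lr. SGD and the helper name are assumed examples.
    def _example_tensor_lr_matches_float_lr(self):
        p_float = Parameter(torch.ones(3))
        p_tensor = Parameter(torch.ones(3))
        p_float.grad = torch.ones(3)
        p_tensor.grad = torch.ones(3)
        SGD([p_float], lr=1e-3).step()
        SGD([p_tensor], lr=torch.tensor(1e-3)).step()
        self.assertEqual(p_float, p_tensor)
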
    @parametrize("with_lrsched", [True, False])
    @optims(
        [o for o in optim_db if o.supports_sparse or o.only_supports_sparse_grads],
        dtypes=[torch.float64],
    )
    def test_rosenbrock_sparse(self, device, dtype, optim_info, with_lrsched):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        # Fused impls do not support sparse gradients
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable", "fused")
        )
        kwarg_updates, schedulers_constructors = optim_info.metadata_for_sparse

        if with_lrsched and len(schedulers_constructors) == 0:
            return

        supported_inputs = []
        if len(kwarg_updates) != 0:
            seen = set()
            for i in all_optim_inputs:
                for k in kwarg_updates:
                    if k in i.kwargs:
                        del i.kwargs[k]
                hashable_kwargs = tuple(sorted(i.kwargs.items()))
                if len(i.kwargs) > 0 and hashable_kwargs not in seen:
                    supported_inputs.append(i)
                    seen.add(hashable_kwargs)
                    if "lr" in kwarg_updates:
                        i.kwargs["lr"] = kwarg_updates["lr"]
        else:
            supported_inputs = all_optim_inputs

        for optim_input in supported_inputs:
            kwargs = optim_input.kwargs
            multi_tensor = kwargs.get("foreach", False)

            # For rosenbrock tests, it is mandated that the param is a tensor with 2 numbers
            if multi_tensor:
                params_t = [
                    torch.tensor([1.5, 1.5]),
                    torch.tensor([1.5, 1.5], dtype=dtype),
                ]
            else:
                params_t = [torch.tensor([1.5, 1.5])]

            params = [Parameter(param_t) for param_t in params_t]
            optimizer = optim_cls(params, **kwargs)
            schedulers = [
                s(optimizer) for s in (schedulers_constructors if with_lrsched else [])
            ]

            if not optim_info.only_supports_sparse_grads:
                params_c = [Parameter(param_t.clone()) for param_t in params_t]
                optimizer_c = optim_cls(params_c, **kwargs)
                schedulers_c = [
                    s(optimizer_c)
                    for s in (schedulers_constructors if with_lrsched else [])
                ]

            solution = torch.tensor([1, 1])
            with torch.no_grad():
                initial_dist = sum(param.dist(solution) for param in params)

            def get_grad(param, sparse_grad, w):
                grad = drosenbrock(param)
                # NB: We torture test the optimizer by returning an
                # uncoalesced sparse tensor

                # Depending on w, provide only the x or y gradient
                if sparse_grad:
                    if w:
                        i = torch.tensor([[0, 0]], dtype=torch.int64)
                        x = grad[0]
                        v = torch.tensor([x / 4.0, x - x / 4.0])
                    else:
                        i = torch.tensor([[1, 1]], dtype=torch.int64)
                        y = grad[1]
                        v = torch.tensor([y - y / 4.0, y / 4.0])
                    grad_out = torch.sparse_coo_tensor(i, v, (2,), dtype=v.dtype)
                else:
                    if w:
                        grad_out = torch.tensor([grad[0], 0], dtype=param.dtype)
                    else:
                        grad_out = torch.tensor([0, grad[1]], dtype=param.dtype)
                return grad_out

            def eval(params, sparse_grad, w):
                optimizer.zero_grad()
                if multi_tensor:
                    loss = sum(rosenbrock(param) for param in params)
                else:
                    loss = rosenbrock(params[0])
                loss.backward()

                grads_out = [get_grad(param, sparse_grad, w) for param in params]
                with torch.no_grad():
                    params[0].grad = grads_out[0]
                    if multi_tensor:
                        params[1].grad = grads_out[1].to(dtype=dtype)
                return loss

            for i in range(1800):
                # Do cyclic coordinate descent
                w = i % 2
                optimizer.step(functools.partial(eval, params, True, w))
                for scheduler in schedulers:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(rosenbrock(params[0]))
                    else:
                        scheduler.step()
                if not optim_info.only_supports_sparse_grads:
                    optimizer_c.step(functools.partial(eval, params_c, False, w))
                    for scheduler in schedulers_c:
                        if isinstance(scheduler, ReduceLROnPlateau):
                            scheduler.step(rosenbrock(params_c[0]))
                        else:
                            scheduler.step()
                    # Tolerance is increased due to floating point error from the different
                    # code path for the dense case: x vs. x - x / 4.0 + x / 4.0
                    self.assertEqual(params, params_c, atol=5e-6, rtol=5e-6)

            if not kwargs.get("maximize", False):
                self.assertLessEqual(
                    sum(param.dist(solution) for param in params), initial_dist
                )
            else:
                self.assertGreaterEqual(
                    sum(rosenbrock(param) for param in params),
                    sum(rosenbrock(param_t) for param_t in params_t),
                )

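    # Illustrative sketch added for this write-up (not part of the original suite): the
    # "torture test" gradient above is an uncoalesced sparse COO tensor that stores two
    # values at the same index; coalescing sums them back to the single dense entry.
    # The helper name is hypothetical.
    def _example_uncoalesced_coo_gradient(self):
        i = torch.tensor([[0, 0]], dtype=torch.int64)
        v = torch.tensor([0.25, 0.75])  # x / 4.0 and x - x / 4.0 for x == 1.0
        grad = torch.sparse_coo_tensor(i, v, (2,))
        self.assertFalse(grad.is_coalesced())
        self.assertEqual(grad.coalesce().to_dense(), torch.tensor([1.0, 0.0]))
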
    @skipMPS
    @optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
    def test_complex(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        # Also skip fused, since our fused kernels do not support complex
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable", "fused")
        )
        for optim_input in all_optim_inputs:
            # Last param is intentionally real to test that we can mix real and complex
            complex_params = [
                torch.randn(10, 5, device=device, dtype=dtype, requires_grad=True),
                torch.randn(10, device=device, dtype=dtype, requires_grad=True),
                torch.randn(
                    10, 5, device=device, dtype=torch.float32, requires_grad=True
                ),
            ]
            real_params = [
                (
                    torch.view_as_real(param).detach().clone().requires_grad_()
                    if param.is_complex()
                    else param.detach().clone().requires_grad_()
                )
                for param in complex_params
            ]

            complex_optimizer = optim_cls(complex_params, **optim_input.kwargs)
            real_optimizer = optim_cls(real_params, **optim_input.kwargs)
            real_steps = []
            complex_steps = []
            grads_losses = []

            def real_closure():
                for param in real_params:
                    grad = torch.randn_like(param)
                    param.grad = grad
                    real_steps.append(param.detach().clone())
                    grads_losses.append(grad.clone())
                loss = torch.randn(1)
                grads_losses.append(loss.clone())
                return loss

            def complex_closure():
                for param in complex_params:
                    if torch.is_complex(param):
                        grad = torch.view_as_complex(grads_losses.pop(0))
                        complex_steps.append(torch.view_as_real_copy(param.detach()))
                    else:
                        grad = grads_losses.pop(0)
                        complex_steps.append(param.detach().clone())
                    param.grad = grad
                return grads_losses.pop(0)

            for _ in range(3):
                if optim_info.step_requires_closure:
                    # LBFGS, for example, requires closure and calls it internally
                    real_optimizer.step(real_closure)
                    complex_optimizer.step(complex_closure)
                else:
                    # For other optimizers, we call closure explicitly to set the gradients
                    real_closure()
                    complex_closure()
                    real_optimizer.step()
                    complex_optimizer.step()

            # Final Parameters should be the same
            complex_params_asreal = [
                torch.view_as_real(param) if param.is_complex() else param
                for param in complex_params
            ]
            self.assertEqual(real_params, complex_params_asreal)

            # All intermediate steps should also be the same;
            # this also checks steps taken internally, e.g., within a line search.
            self.assertEqual(complex_steps, real_steps)

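    # Hedged sketch added for illustration (not in the original file): the invariant behind
    # test_complex, namely that stepping a complex parameter matches stepping its
    # view_as_real counterpart. SGD and the helper name are assumed examples.
    def _example_complex_matches_view_as_real_sgd(self):
        c = torch.randn(4, dtype=torch.complex64, requires_grad=True)
        r = torch.view_as_real(c).detach().clone().requires_grad_()
        grad_c = torch.randn_like(c)
        c.grad = grad_c
        r.grad = torch.view_as_real(grad_c).clone()
        SGD([c], lr=0.1).step()
        SGD([r], lr=0.1).step()
        self.assertEqual(torch.view_as_real(c), r)
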
    @skipMPS
    @optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
    def test_complex_2d(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        # Also skip fused, since our fused kernels do not support complex
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable", "fused")
        )
        for optim_input in all_optim_inputs:
            if optim_info.step_requires_closure:
                # Why? The way we implement complex is by turning complex params into view_as_real
                # alternatives. For example, a size (M,N) tensor will become (M,N,2). In this test,
                # we break apart a tensor into its real and imaginary parts, which would be 2x(M,N).
                # For pointwise optimizers, this distinction is trivial, but for LBFGS, where
                # there are reductions across all parameters (and all the grads get flattened into
                # one long Tensor), this ordering matters. Why? Reductions are not deterministic
                # because addition between floating point numbers is not associative, i.e.,
                # (a + b) + c != a + (b + c). Thus, we add a seed here to control the discrepancy
                # that will happen with LBFGS. Note that in test_complex above, there is no need
                # for a seed nor for increased tolerance, because results should be bitwise
                # equivalent.
                torch.manual_seed(2024)

605*da0073e9SAndroid Build Coastguard Worker            a1 = torch.randn(2, device=device, dtype=dtype, requires_grad=True)
606*da0073e9SAndroid Build Coastguard Worker            a1_real = a1.real.clone().detach()
607*da0073e9SAndroid Build Coastguard Worker            a1_imag = a1.imag.clone().detach()
608*da0073e9SAndroid Build Coastguard Worker            a1_real.requires_grad_()
609*da0073e9SAndroid Build Coastguard Worker            a1_imag.requires_grad_()
610*da0073e9SAndroid Build Coastguard Worker            optim1 = optim_cls([a1], **optim_input.kwargs)
611*da0073e9SAndroid Build Coastguard Worker            optim2 = optim_cls([a1_real, a1_imag], **optim_input.kwargs)
612*da0073e9SAndroid Build Coastguard Worker
613*da0073e9SAndroid Build Coastguard Worker            a1_reals = TensorTracker()
614*da0073e9SAndroid Build Coastguard Worker            a1_imags = TensorTracker()
615*da0073e9SAndroid Build Coastguard Worker            a1_grad_reals = TensorTracker()
616*da0073e9SAndroid Build Coastguard Worker            a1_grad_imags = TensorTracker()
617*da0073e9SAndroid Build Coastguard Worker            losses = TensorTracker()
618*da0073e9SAndroid Build Coastguard Worker
619*da0073e9SAndroid Build Coastguard Worker            def closure1():
620*da0073e9SAndroid Build Coastguard Worker                optim1.zero_grad()
621*da0073e9SAndroid Build Coastguard Worker                loss = rosenbrock(a1).abs()
622*da0073e9SAndroid Build Coastguard Worker                loss.backward()
623*da0073e9SAndroid Build Coastguard Worker
624*da0073e9SAndroid Build Coastguard Worker                # Track clones to best test accuracy
625*da0073e9SAndroid Build Coastguard Worker                a1_reals.add(a1.real)
626*da0073e9SAndroid Build Coastguard Worker                a1_imags.add(a1.imag)
627*da0073e9SAndroid Build Coastguard Worker                a1_grad_reals.add(a1.grad.real)
628*da0073e9SAndroid Build Coastguard Worker                a1_grad_imags.add(a1.grad.imag)
629*da0073e9SAndroid Build Coastguard Worker
630*da0073e9SAndroid Build Coastguard Worker                losses.add(loss)
631*da0073e9SAndroid Build Coastguard Worker
632*da0073e9SAndroid Build Coastguard Worker                return loss
633*da0073e9SAndroid Build Coastguard Worker
634*da0073e9SAndroid Build Coastguard Worker            def closure2():
635*da0073e9SAndroid Build Coastguard Worker                optim2.zero_grad()
636*da0073e9SAndroid Build Coastguard Worker                a1_reals.pop_check_set(a1_real, self)
637*da0073e9SAndroid Build Coastguard Worker                a1_imags.pop_check_set(a1_imag, self)
638*da0073e9SAndroid Build Coastguard Worker                a2 = torch.complex(a1_real, a1_imag)
639*da0073e9SAndroid Build Coastguard Worker                loss = rosenbrock(a2).abs()
640*da0073e9SAndroid Build Coastguard Worker                losses.pop_check_set(loss, self)
641*da0073e9SAndroid Build Coastguard Worker                loss.backward()
642*da0073e9SAndroid Build Coastguard Worker                a1_grad_reals.pop_check_set(a1_real.grad, self)
643*da0073e9SAndroid Build Coastguard Worker                a1_grad_imags.pop_check_set(a1_imag.grad, self)
644*da0073e9SAndroid Build Coastguard Worker                return loss
645*da0073e9SAndroid Build Coastguard Worker
646*da0073e9SAndroid Build Coastguard Worker            for _ in range(3):
647*da0073e9SAndroid Build Coastguard Worker                if optim_info.step_requires_closure:
648*da0073e9SAndroid Build Coastguard Worker                    # LBFGS, for example, requires a closure and calls it internally
649*da0073e9SAndroid Build Coastguard Worker                    optim1.step(closure1)
650*da0073e9SAndroid Build Coastguard Worker                    optim2.step(closure2)
651*da0073e9SAndroid Build Coastguard Worker                else:
652*da0073e9SAndroid Build Coastguard Worker                    closure1()
653*da0073e9SAndroid Build Coastguard Worker                    closure2()
654*da0073e9SAndroid Build Coastguard Worker                    optim1.step()
655*da0073e9SAndroid Build Coastguard Worker                    optim2.step()
656*da0073e9SAndroid Build Coastguard Worker
657*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a1.real, a1_real)
658*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a1.imag, a1_imag)
659*da0073e9SAndroid Build Coastguard Worker
660*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(a1_reals.all_popped())
661*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(a1_imags.all_popped())
662*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(a1_grad_reals.all_popped())
663*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(a1_grad_imags.all_popped())
664*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(losses.all_popped())
665*da0073e9SAndroid Build Coastguard Worker
666*da0073e9SAndroid Build Coastguard Worker    def _compare_between(
667*da0073e9SAndroid Build Coastguard Worker        self, inputs, models, optimizers, assert_eq_kwargs=None, assert_step_dtype=None
668*da0073e9SAndroid Build Coastguard Worker    ):
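        """
        Step each (input, model, optimizer) triple in lockstep and assert that the updated
        parameters and per-param optimizer state stay equal (within assert_eq_kwargs).
        """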
669*da0073e9SAndroid Build Coastguard Worker        # Why 7 iterations? Iteration 7 is where we start to see differences for RAdam,
670*da0073e9SAndroid Build Coastguard Worker        # whose params begin interacting with the small eps value right after rho_t
671*da0073e9SAndroid Build Coastguard Worker        # becomes greater than 5 at step 6.
672*da0073e9SAndroid Build Coastguard Worker        if assert_eq_kwargs is None:
673*da0073e9SAndroid Build Coastguard Worker            assert_eq_kwargs = {}
674*da0073e9SAndroid Build Coastguard Worker        kIterations = 7
675*da0073e9SAndroid Build Coastguard Worker        tracker = TensorTracker(assert_eq_kwargs)
676*da0073e9SAndroid Build Coastguard Worker        for i in range(kIterations):
677*da0073e9SAndroid Build Coastguard Worker            state, updated_params = [], []
678*da0073e9SAndroid Build Coastguard Worker            if not isinstance(inputs, list):
679*da0073e9SAndroid Build Coastguard Worker                inputs = [inputs, inputs]
680*da0073e9SAndroid Build Coastguard Worker            for input, model, optimizer in zip(inputs, models, optimizers):
681*da0073e9SAndroid Build Coastguard Worker                optimizer.zero_grad()
682*da0073e9SAndroid Build Coastguard Worker
683*da0073e9SAndroid Build Coastguard Worker                if i == 3:
684*da0073e9SAndroid Build Coastguard Worker                    # Freeze a layer to test if the step of this layer in 'fused' or 'foreach'
685*da0073e9SAndroid Build Coastguard Worker                    # is the same as the step in 'forloop'.
686*da0073e9SAndroid Build Coastguard Worker                    model[2].requires_grad_(False)
687*da0073e9SAndroid Build Coastguard Worker                if i == 5:
688*da0073e9SAndroid Build Coastguard Worker                    # Unfreeze the layer after 2 iters.
689*da0073e9SAndroid Build Coastguard Worker                    model[2].requires_grad_(True)
690*da0073e9SAndroid Build Coastguard Worker
691*da0073e9SAndroid Build Coastguard Worker                # Test that step behaves as expected (a no-op) when grads are set to None
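                # (zero_grad() above defaults to set_to_none=True, so skipping backward on
                # iteration 2 leaves these grads as None.)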
692*da0073e9SAndroid Build Coastguard Worker                if i != 2:
693*da0073e9SAndroid Build Coastguard Worker                    output = model(input)
694*da0073e9SAndroid Build Coastguard Worker                    loss = output.sum()
695*da0073e9SAndroid Build Coastguard Worker                    loss.backward()
696*da0073e9SAndroid Build Coastguard Worker
697*da0073e9SAndroid Build Coastguard Worker                optimizer.step()
698*da0073e9SAndroid Build Coastguard Worker                state.append(optimizer.state)
699*da0073e9SAndroid Build Coastguard Worker                updated_params.append(model.parameters())
700*da0073e9SAndroid Build Coastguard Worker
701*da0073e9SAndroid Build Coastguard Worker            og_state, new_state = state
702*da0073e9SAndroid Build Coastguard Worker            for og_p, new_p in zip(updated_params[0], updated_params[1]):
703*da0073e9SAndroid Build Coastguard Worker                tracker.add(og_p)
704*da0073e9SAndroid Build Coastguard Worker                tracker.pop_check_set(new_p, self)
705*da0073e9SAndroid Build Coastguard Worker
706*da0073e9SAndroid Build Coastguard Worker                # check that optimizer states are the same
707*da0073e9SAndroid Build Coastguard Worker                og_p_state = og_state[og_p]
708*da0073e9SAndroid Build Coastguard Worker                new_p_state = new_state[new_p]
709*da0073e9SAndroid Build Coastguard Worker                if assert_step_dtype is not None:
710*da0073e9SAndroid Build Coastguard Worker                    if torch.is_tensor(og_p_state.get("step", None)):
711*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(og_p_state["step"].dtype, assert_step_dtype)
712*da0073e9SAndroid Build Coastguard Worker                    if torch.is_tensor(new_p_state.get("step", None)):
713*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(new_p_state["step"].dtype, assert_step_dtype)
714*da0073e9SAndroid Build Coastguard Worker                for k in og_p_state:
715*da0073e9SAndroid Build Coastguard Worker                    tracker.add(og_p_state[k])
716*da0073e9SAndroid Build Coastguard Worker                    tracker.pop_check_set(new_p_state[k], self)
717*da0073e9SAndroid Build Coastguard Worker
718*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(tracker.all_popped())
719*da0073e9SAndroid Build Coastguard Worker
720*da0073e9SAndroid Build Coastguard Worker    def _test_derived_optimizers(
721*da0073e9SAndroid Build Coastguard Worker        self,
722*da0073e9SAndroid Build Coastguard Worker        device,
723*da0073e9SAndroid Build Coastguard Worker        dtype,
724*da0073e9SAndroid Build Coastguard Worker        optim_info,
725*da0073e9SAndroid Build Coastguard Worker        flag,
726*da0073e9SAndroid Build Coastguard Worker        reduced_precision=False,
727*da0073e9SAndroid Build Coastguard Worker        assert_step_dtype=None,
728*da0073e9SAndroid Build Coastguard Worker    ):
729*da0073e9SAndroid Build Coastguard Worker        """
730*da0073e9SAndroid Build Coastguard Worker        Given a flag 'fused' or 'foreach', test for parity of optimizer state
731*da0073e9SAndroid Build Coastguard Worker        and updated parameters between when the flag is set to True and False
732*da0073e9SAndroid Build Coastguard Worker        for provided optimizer configurations.
733*da0073e9SAndroid Build Coastguard Worker        """
734*da0073e9SAndroid Build Coastguard Worker        assert flag in ("foreach", "fused")
735*da0073e9SAndroid Build Coastguard Worker        assert_eq_kwargs = {} if not reduced_precision else FP16_REDUCED_PRECISION
736*da0073e9SAndroid Build Coastguard Worker
737*da0073e9SAndroid Build Coastguard Worker        optim_inputs = optim_info.optim_inputs_func(device=device, dtype=dtype)
738*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
739*da0073e9SAndroid Build Coastguard Worker        for optim_input in optim_inputs:
740*da0073e9SAndroid Build Coastguard Worker            models, optimizers = [], []
741*da0073e9SAndroid Build Coastguard Worker            kwargs = deepcopy(optim_input.kwargs)
742*da0073e9SAndroid Build Coastguard Worker            if kwargs.get("capturable", False) and _get_device_type(device) == "cpu":
743*da0073e9SAndroid Build Coastguard Worker                # capturable is not supported on CPU
744*da0073e9SAndroid Build Coastguard Worker                continue
745*da0073e9SAndroid Build Coastguard Worker            for flag_value in (False, True):
746*da0073e9SAndroid Build Coastguard Worker                kwargs[flag] = flag_value
747*da0073e9SAndroid Build Coastguard Worker                input = torch.tensor(
748*da0073e9SAndroid Build Coastguard Worker                    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=device
749*da0073e9SAndroid Build Coastguard Worker                ).reshape(3, 2)
750*da0073e9SAndroid Build Coastguard Worker
751*da0073e9SAndroid Build Coastguard Worker                torch.manual_seed(1)
752*da0073e9SAndroid Build Coastguard Worker                model = torch.nn.Sequential(
753*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Linear(2, 3),
754*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Sigmoid(),
755*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Linear(3, 1),
756*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Sigmoid(),
757*da0073e9SAndroid Build Coastguard Worker                )
758*da0073e9SAndroid Build Coastguard Worker                model.to(dtype=dtype, device=device)
759*da0073e9SAndroid Build Coastguard Worker
760*da0073e9SAndroid Build Coastguard Worker                # foreach/fused optimizers should be tested with a
761*da0073e9SAndroid Build Coastguard Worker                # zero-size tensor as their last param.
762*da0073e9SAndroid Build Coastguard Worker                # ref: https://github.com/pytorch/pytorch/issues/100701
763*da0073e9SAndroid Build Coastguard Worker                empty_param = torch.empty(
764*da0073e9SAndroid Build Coastguard Worker                    (), device=device, dtype=dtype, requires_grad=True
765*da0073e9SAndroid Build Coastguard Worker                )
766*da0073e9SAndroid Build Coastguard Worker                empty_param.grad = torch.rand_like(empty_param)
767*da0073e9SAndroid Build Coastguard Worker                params = list(model.parameters()) + [empty_param]
768*da0073e9SAndroid Build Coastguard Worker
769*da0073e9SAndroid Build Coastguard Worker                optimizer = optim_cls(params, **kwargs)
770*da0073e9SAndroid Build Coastguard Worker                models.append(model)
771*da0073e9SAndroid Build Coastguard Worker                optimizers.append(optimizer)
772*da0073e9SAndroid Build Coastguard Worker
773*da0073e9SAndroid Build Coastguard Worker            self._compare_between(
774*da0073e9SAndroid Build Coastguard Worker                input, models, optimizers, assert_eq_kwargs, assert_step_dtype
775*da0073e9SAndroid Build Coastguard Worker            )
776*da0073e9SAndroid Build Coastguard Worker
777*da0073e9SAndroid Build Coastguard Worker    @skipMPS  # MPS doesn't support torch.float64, see https://github.com/pytorch/pytorch/issues/115350
778*da0073e9SAndroid Build Coastguard Worker    @optims(
779*da0073e9SAndroid Build Coastguard Worker        [optim for optim in optim_db if "foreach" in optim.supported_impls],
780*da0073e9SAndroid Build Coastguard Worker        dtypes=[torch.float64],
781*da0073e9SAndroid Build Coastguard Worker    )
782*da0073e9SAndroid Build Coastguard Worker    def test_foreach_matches_forloop(self, device, dtype, optim_info):
783*da0073e9SAndroid Build Coastguard Worker        self._test_derived_optimizers(device, dtype, optim_info, "foreach")
784*da0073e9SAndroid Build Coastguard Worker
785*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
786*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
787*da0073e9SAndroid Build Coastguard Worker    @parametrize("impl", ["foreach", "fused"])
788*da0073e9SAndroid Build Coastguard Worker    @optims(
789*da0073e9SAndroid Build Coastguard Worker        [
790*da0073e9SAndroid Build Coastguard Worker            optim
791*da0073e9SAndroid Build Coastguard Worker            for optim in optim_db
792*da0073e9SAndroid Build Coastguard Worker            if "foreach" in optim.supported_impls or "fused" in optim.supported_impls
793*da0073e9SAndroid Build Coastguard Worker        ]
794*da0073e9SAndroid Build Coastguard Worker    )
795*da0073e9SAndroid Build Coastguard Worker    def test_mixed_device_dtype(self, device, dtype, optim_info, impl):
796*da0073e9SAndroid Build Coastguard Worker        """
797*da0073e9SAndroid Build Coastguard Worker        Similar in essence to _test_derived_optimizers above. The main difference is that
798*da0073e9SAndroid Build Coastguard Worker        _test_derived_optimizers uses model parameters whereas we randomly pass in
799*da0073e9SAndroid Build Coastguard Worker        parameters of different dtypes and devices here. We need multiple GPUs (vs just a
800*da0073e9SAndroid Build Coastguard Worker        CPU and GPU) because fused Adam only works on GPUs. (Thus we only run the tests
801*da0073e9SAndroid Build Coastguard Worker        that call into this helper when TEST_MULTIGPU.)
802*da0073e9SAndroid Build Coastguard Worker        """
803*da0073e9SAndroid Build Coastguard Worker        assert impl in ("foreach", "fused")
804*da0073e9SAndroid Build Coastguard Worker        if impl == "foreach" and "foreach" not in optim_info.supported_impls:
805*da0073e9SAndroid Build Coastguard Worker            self.skipTest(
806*da0073e9SAndroid Build Coastguard Worker                f"foreach not supported for {optim_info.optim_cls.__name__}"
807*da0073e9SAndroid Build Coastguard Worker            )
808*da0073e9SAndroid Build Coastguard Worker        elif impl == "fused" and "cuda" not in optim_info.supports_fused_on:
809*da0073e9SAndroid Build Coastguard Worker            self.skipTest(
810*da0073e9SAndroid Build Coastguard Worker                f"fused not supported for {optim_info.optim_cls.__name__} on cuda"
811*da0073e9SAndroid Build Coastguard Worker            )
812*da0073e9SAndroid Build Coastguard Worker
813*da0073e9SAndroid Build Coastguard Worker        params = [
814*da0073e9SAndroid Build Coastguard Worker            torch.rand(2, 3, dtype=torch.float64, device="cuda:0", requires_grad=True),
815*da0073e9SAndroid Build Coastguard Worker            torch.rand(2, 3, dtype=torch.float32, device="cuda:0", requires_grad=True),
816*da0073e9SAndroid Build Coastguard Worker            torch.rand(2, 3, dtype=torch.float16, device="cuda:0", requires_grad=True),
817*da0073e9SAndroid Build Coastguard Worker            torch.rand(2, 3, dtype=torch.bfloat16, device="cuda:0", requires_grad=True),
818*da0073e9SAndroid Build Coastguard Worker            torch.rand(2, 3, dtype=torch.float64, device="cuda:1", requires_grad=True),
819*da0073e9SAndroid Build Coastguard Worker            torch.rand(2, 3, dtype=torch.float32, device="cuda:1", requires_grad=True),
820*da0073e9SAndroid Build Coastguard Worker            torch.rand(2, 3, dtype=torch.float16, device="cuda:1", requires_grad=True),
821*da0073e9SAndroid Build Coastguard Worker            torch.rand(2, 3, dtype=torch.bfloat16, device="cuda:1", requires_grad=True),
822*da0073e9SAndroid Build Coastguard Worker            torch.randint(
823*da0073e9SAndroid Build Coastguard Worker                1024, (2, 3), dtype=torch.int64, device="cuda:1", requires_grad=False
824*da0073e9SAndroid Build Coastguard Worker            ),
825*da0073e9SAndroid Build Coastguard Worker        ]
826*da0073e9SAndroid Build Coastguard Worker
827*da0073e9SAndroid Build Coastguard Worker        for p in params:
828*da0073e9SAndroid Build Coastguard Worker            if p.requires_grad:
829*da0073e9SAndroid Build Coastguard Worker                p.grad = torch.rand_like(p, device=p.device, dtype=p.dtype)
830*da0073e9SAndroid Build Coastguard Worker
831*da0073e9SAndroid Build Coastguard Worker        kIterations = 7 if impl == "foreach" else 1
832*da0073e9SAndroid Build Coastguard Worker        optim_inputs = optim_info.optim_inputs_func(device=device)
833*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
834*da0073e9SAndroid Build Coastguard Worker        for optim_input in optim_inputs:
835*da0073e9SAndroid Build Coastguard Worker            updated_params, state = [], []
836*da0073e9SAndroid Build Coastguard Worker            kwargs = deepcopy(optim_input.kwargs)
837*da0073e9SAndroid Build Coastguard Worker            if kwargs.get("capturable", False) and _get_device_type(device) == "cpu":
838*da0073e9SAndroid Build Coastguard Worker                # capturable is not supported on CPU
839*da0073e9SAndroid Build Coastguard Worker                continue
840*da0073e9SAndroid Build Coastguard Worker            for use_impl in (False, True):
841*da0073e9SAndroid Build Coastguard Worker                kwargs[impl] = use_impl
842*da0073e9SAndroid Build Coastguard Worker                params_clone = []
843*da0073e9SAndroid Build Coastguard Worker                for p in params:
844*da0073e9SAndroid Build Coastguard Worker                    p_clone = p.clone().detach()
845*da0073e9SAndroid Build Coastguard Worker                    if p.requires_grad:
846*da0073e9SAndroid Build Coastguard Worker                        p_clone.requires_grad = True
847*da0073e9SAndroid Build Coastguard Worker                        p_clone.grad = p.grad.clone().detach()
848*da0073e9SAndroid Build Coastguard Worker                        params_clone.append(p_clone)
849*da0073e9SAndroid Build Coastguard Worker
850*da0073e9SAndroid Build Coastguard Worker                optimizer = optim_cls(params_clone, **kwargs)
851*da0073e9SAndroid Build Coastguard Worker                for _ in range(kIterations):
852*da0073e9SAndroid Build Coastguard Worker                    optimizer.step()
853*da0073e9SAndroid Build Coastguard Worker
854*da0073e9SAndroid Build Coastguard Worker                state.append(optimizer.state)
855*da0073e9SAndroid Build Coastguard Worker                updated_params.append(params_clone)
856*da0073e9SAndroid Build Coastguard Worker
857*da0073e9SAndroid Build Coastguard Worker            og_state, new_state = state
858*da0073e9SAndroid Build Coastguard Worker            for og_p, new_p in zip(updated_params[0], updated_params[1]):
859*da0073e9SAndroid Build Coastguard Worker                # Increase the tolerance since we are collating lots of ops together for the
860*da0073e9SAndroid Build Coastguard Worker                # optimizers, whereas the designated tolerances are for a single op only.
861*da0073e9SAndroid Build Coastguard Worker                single_rtol, single_atol = torch.testing._comparison.get_tolerances(
862*da0073e9SAndroid Build Coastguard Worker                    new_p.dtype, rtol=None, atol=None
863*da0073e9SAndroid Build Coastguard Worker                )
864*da0073e9SAndroid Build Coastguard Worker                rtol = 5 * single_rtol
865*da0073e9SAndroid Build Coastguard Worker                atol = 5 * single_atol
866*da0073e9SAndroid Build Coastguard Worker
867*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(og_p, new_p, rtol=rtol, atol=atol)
868*da0073e9SAndroid Build Coastguard Worker
869*da0073e9SAndroid Build Coastguard Worker                # check that optimizer states are the same
870*da0073e9SAndroid Build Coastguard Worker                og_p_state = og_state[og_p]
871*da0073e9SAndroid Build Coastguard Worker                new_p_state = new_state[new_p]
872*da0073e9SAndroid Build Coastguard Worker
873*da0073e9SAndroid Build Coastguard Worker                for k in og_p_state:
874*da0073e9SAndroid Build Coastguard Worker                    actual = new_p_state[k]
875*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(og_p_state[k], actual, rtol=rtol, atol=atol)
876*da0073e9SAndroid Build Coastguard Worker
877*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
878*da0073e9SAndroid Build Coastguard Worker    @optims(
879*da0073e9SAndroid Build Coastguard Worker        [optim for optim in optim_db if "foreach" in optim.supported_impls],
880*da0073e9SAndroid Build Coastguard Worker        dtypes=[torch.float64],
881*da0073e9SAndroid Build Coastguard Worker    )
882*da0073e9SAndroid Build Coastguard Worker    def test_set_default_dtype_works_with_foreach(self, device, dtype, optim_info):
883*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/110940
884*da0073e9SAndroid Build Coastguard Worker        # We coerce step to always be float32 unless the
885*da0073e9SAndroid Build Coastguard Worker        # default dtype is the higher-precision float64
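        # (e.g. with a torch.float16 default, state["step"] should still be float32; with a
        # torch.float64 default, step follows it to float64. assert_step_dtype below checks this.)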
886*da0073e9SAndroid Build Coastguard Worker        old_default_dtype = torch.get_default_dtype()
887*da0073e9SAndroid Build Coastguard Worker        for default_dtype in [torch.float64, torch.float16]:
888*da0073e9SAndroid Build Coastguard Worker            try:
889*da0073e9SAndroid Build Coastguard Worker                torch.set_default_dtype(default_dtype)
890*da0073e9SAndroid Build Coastguard Worker                self._test_derived_optimizers(
891*da0073e9SAndroid Build Coastguard Worker                    device,
892*da0073e9SAndroid Build Coastguard Worker                    dtype,
893*da0073e9SAndroid Build Coastguard Worker                    optim_info,
894*da0073e9SAndroid Build Coastguard Worker                    "foreach",
895*da0073e9SAndroid Build Coastguard Worker                    reduced_precision=default_dtype == torch.float16,
896*da0073e9SAndroid Build Coastguard Worker                    assert_step_dtype=(
897*da0073e9SAndroid Build Coastguard Worker                        torch.float64
898*da0073e9SAndroid Build Coastguard Worker                        if default_dtype == torch.float64
899*da0073e9SAndroid Build Coastguard Worker                        else torch.float32
900*da0073e9SAndroid Build Coastguard Worker                    ),
901*da0073e9SAndroid Build Coastguard Worker                )
902*da0073e9SAndroid Build Coastguard Worker            finally:
903*da0073e9SAndroid Build Coastguard Worker                torch.set_default_dtype(old_default_dtype)
904*da0073e9SAndroid Build Coastguard Worker
905*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
906*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("72GB", "cuda")
907*da0073e9SAndroid Build Coastguard Worker    @optims(
908*da0073e9SAndroid Build Coastguard Worker        [optim for optim in optim_db if "foreach" in optim.supported_impls],
909*da0073e9SAndroid Build Coastguard Worker        dtypes=[torch.float16],
910*da0073e9SAndroid Build Coastguard Worker    )
911*da0073e9SAndroid Build Coastguard Worker    def test_foreach_large_tensor(self, device, dtype, optim_info):
912*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
913*da0073e9SAndroid Build Coastguard Worker        optim_inputs = optim_info.optim_inputs_func(device=device)
914*da0073e9SAndroid Build Coastguard Worker        for optim_input in optim_inputs:
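            # A single tensor with 2**32 elements (8 GiB in float16), whose numel does not fit
            # in a signed 32-bit int, to exercise the foreach path on very large tensors.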
915*da0073e9SAndroid Build Coastguard Worker            params = [torch.ones(2**32, device=device, dtype=dtype)]
916*da0073e9SAndroid Build Coastguard Worker            params[0].grad = torch.zeros_like(params[0])
917*da0073e9SAndroid Build Coastguard Worker            optimizer = optim_cls(params, foreach=True, **optim_input.kwargs)
918*da0073e9SAndroid Build Coastguard Worker            optimizer.step()
919*da0073e9SAndroid Build Coastguard Worker
920*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
921*da0073e9SAndroid Build Coastguard Worker    @optims(
922*da0073e9SAndroid Build Coastguard Worker        [optim for optim in optim_db if "foreach" in optim.supported_impls],
923*da0073e9SAndroid Build Coastguard Worker        dtypes=[torch.float32],
924*da0073e9SAndroid Build Coastguard Worker    )
925*da0073e9SAndroid Build Coastguard Worker    def test_peak_memory_foreach(self, device, dtype, optim_info):
926*da0073e9SAndroid Build Coastguard Worker        nparams = 10
927*da0073e9SAndroid Build Coastguard Worker        optim_inputs = optim_info.optim_inputs_func(device=device)
928*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
929*da0073e9SAndroid Build Coastguard Worker        for optim_input in optim_inputs:
930*da0073e9SAndroid Build Coastguard Worker            kwargs = deepcopy(optim_input.kwargs)
931*da0073e9SAndroid Build Coastguard Worker            max_mems = []
932*da0073e9SAndroid Build Coastguard Worker            for flag_value in (False, True):
933*da0073e9SAndroid Build Coastguard Worker                kwargs["foreach"] = flag_value
934*da0073e9SAndroid Build Coastguard Worker                # The 16 * 8 = 128 is critical here! Our CUDACachingAllocator allocates in blocks
935*da0073e9SAndroid Build Coastguard Worker                # of 512 bytes, meaning any tensor that occupies <512 bytes of memory will allocate a
936*da0073e9SAndroid Build Coastguard Worker                # whole 512 bytes anyway. We use 128 elements (4 bytes each for float32) so that each
937*da0073e9SAndroid Build Coastguard Worker                # param is exactly 512 bytes, making our later calculations for intermediate_size easy.
938*da0073e9SAndroid Build Coastguard Worker                param = torch.rand(16, 8, device=device, dtype=dtype)
939*da0073e9SAndroid Build Coastguard Worker                params = [torch.rand_like(param) for _ in range(nparams)]
940*da0073e9SAndroid Build Coastguard Worker
941*da0073e9SAndroid Build Coastguard Worker                optimizer = optim_cls(params, **kwargs)
942*da0073e9SAndroid Build Coastguard Worker
943*da0073e9SAndroid Build Coastguard Worker                for p in params:
944*da0073e9SAndroid Build Coastguard Worker                    p.grad = torch.rand_like(p)
945*da0073e9SAndroid Build Coastguard Worker
946*da0073e9SAndroid Build Coastguard Worker                optimizer.step()
947*da0073e9SAndroid Build Coastguard Worker                import gc
948*da0073e9SAndroid Build Coastguard Worker
949*da0073e9SAndroid Build Coastguard Worker                gc.collect()
950*da0073e9SAndroid Build Coastguard Worker                torch.cuda.reset_peak_memory_stats()
951*da0073e9SAndroid Build Coastguard Worker                optimizer.step()
952*da0073e9SAndroid Build Coastguard Worker                gc.collect()
953*da0073e9SAndroid Build Coastguard Worker                max_mems.append(torch.cuda.max_memory_allocated())
954*da0073e9SAndroid Build Coastguard Worker
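            # max_mems[0] is the single-tensor (foreach=False) peak, max_mems[1] the multi-tensor one.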
955*da0073e9SAndroid Build Coastguard Worker            st_max_mem, mt_max_mem = max_mems
956*da0073e9SAndroid Build Coastguard Worker            intermediate_size = nparams * param.nelement() * param.element_size()
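            # = 10 params * 128 elements * 4 bytes (float32) = 5120 bytes: one fully
            # materialized temporary covering all the params at once.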
957*da0073e9SAndroid Build Coastguard Worker            nintermediates = 1  # we expect a budget of 1 intermediate most of the time
958*da0073e9SAndroid Build Coastguard Worker
959*da0073e9SAndroid Build Coastguard Worker            # Check the param group directly to handle the case where the compiler set capturable
960*da0073e9SAndroid Build Coastguard Worker            if optimizer.param_groups[0].get(
961*da0073e9SAndroid Build Coastguard Worker                "capturable", False
962*da0073e9SAndroid Build Coastguard Worker            ) or optim_cls.__name__ in ["Adadelta", "ASGD", "RAdam"]:
963*da0073e9SAndroid Build Coastguard Worker                # with capturable in Adam(W), we have 2 extra intermediates for the bias_corrections
964*da0073e9SAndroid Build Coastguard Worker                # with Adadelta, we have 2 extra for (acc_delta + eps) and (square_avg + eps)
965*da0073e9SAndroid Build Coastguard Worker                # ASGD allocates axs, 2x mus, 2x etas, and grads at the same time
966*da0073e9SAndroid Build Coastguard Worker                nintermediates = 3
967*da0073e9SAndroid Build Coastguard Worker                if optim_cls.__name__ == "NAdam":
968*da0073e9SAndroid Build Coastguard Worker                    # with capturable in NAdam, we have 3 extra intermediates for the
969*da0073e9SAndroid Build Coastguard Worker                    # bias_correction, mus, and mu_nexts
970*da0073e9SAndroid Build Coastguard Worker                    if TEST_WITH_TORCHDYNAMO:
971*da0073e9SAndroid Build Coastguard Worker                        # With dynamo, the eager/FX backend appears to hold memory longer than
972*da0073e9SAndroid Build Coastguard Worker                        # vanilla eager: https://github.com/pytorch/pytorch/issues/125511
973*da0073e9SAndroid Build Coastguard Worker                        nintermediates = 8
974*da0073e9SAndroid Build Coastguard Worker                    else:
975*da0073e9SAndroid Build Coastguard Worker                        nintermediates = 5
976*da0073e9SAndroid Build Coastguard Worker
977*da0073e9SAndroid Build Coastguard Worker                if optim_cls.__name__ == "RAdam":
978*da0073e9SAndroid Build Coastguard Worker                    # RAdam has four intermediates with capturable
979*da0073e9SAndroid Build Coastguard Worker                    # num, unrect_step_size, buffer, grouped_grads
980*da0073e9SAndroid Build Coastguard Worker                    if TEST_WITH_TORCHDYNAMO:
981*da0073e9SAndroid Build Coastguard Worker                        # With dynamo, the eager/FX backend appears to hold memory longer than
982*da0073e9SAndroid Build Coastguard Worker                        # vanilla eager: https://github.com/pytorch/pytorch/issues/125511
983*da0073e9SAndroid Build Coastguard Worker                        nintermediates = 6
984*da0073e9SAndroid Build Coastguard Worker                    else:
985*da0073e9SAndroid Build Coastguard Worker                        nintermediates = 4
986*da0073e9SAndroid Build Coastguard Worker
987*da0073e9SAndroid Build Coastguard Worker            elif optim_cls.__name__ in ["NAdam", "Adagrad", "RMSprop", "Adafactor"]:
988*da0073e9SAndroid Build Coastguard Worker                # NAdam uses two intermediates at the same time (grads & exp_avg_sq_sqrt)
989*da0073e9SAndroid Build Coastguard Worker                # Adagrad uses std and grads at the same time
990*da0073e9SAndroid Build Coastguard Worker                # RMSprop uses avg and grads
991*da0073e9SAndroid Build Coastguard Worker                # Adafactor uses row/col var and its mean
992*da0073e9SAndroid Build Coastguard Worker                nintermediates = 2
993*da0073e9SAndroid Build Coastguard Worker
994*da0073e9SAndroid Build Coastguard Worker                if optim_cls.__name__ == "Adafactor" and kwargs.get("maximize", False):
995*da0073e9SAndroid Build Coastguard Worker                    # When maximize is True, Adafactor also tracks device_grad
996*da0073e9SAndroid Build Coastguard Worker                    nintermediates = 3
997*da0073e9SAndroid Build Coastguard Worker
998*da0073e9SAndroid Build Coastguard Worker            # Dynamo single-tensor (ST) uses less memory than eager for Adam/Adagrad/NAdam/RAdam,
999*da0073e9SAndroid Build Coastguard Worker            # which makes the foreach memory check fail.
1000*da0073e9SAndroid Build Coastguard Worker            if TEST_WITH_TORCHDYNAMO:
1001*da0073e9SAndroid Build Coastguard Worker                st_max_mem += 6000
1002*da0073e9SAndroid Build Coastguard Worker
1003*da0073e9SAndroid Build Coastguard Worker            expected_max_mem = st_max_mem + intermediate_size * nintermediates
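            # i.e. the multi-tensor (foreach) path is allowed at most `nintermediates` extra
            # temporaries, each the size of all params combined, on top of the single-tensor peak.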
1004*da0073e9SAndroid Build Coastguard Worker            # hipcc currently can't generate efficient code for the small buffer optimization
1005*da0073e9SAndroid Build Coastguard Worker            # code path (see Note [small buffer optimization] for details), thus we always
1006*da0073e9SAndroid Build Coastguard Worker            # dynamically allocate the tensor metadata for ROCm, so we adjust the expected max
1007*da0073e9SAndroid Build Coastguard Worker            # memory usage upward to account for this.
1008*da0073e9SAndroid Build Coastguard Worker            if TEST_WITH_ROCM:
1009*da0073e9SAndroid Build Coastguard Worker                expected_max_mem *= 1.02
1010*da0073e9SAndroid Build Coastguard Worker
1011*da0073e9SAndroid Build Coastguard Worker            self.assertLessEqual(mt_max_mem, expected_max_mem)
1012*da0073e9SAndroid Build Coastguard Worker
1013*da0073e9SAndroid Build Coastguard Worker    @optims(
1014*da0073e9SAndroid Build Coastguard Worker        [optim for optim in optim_db if "fused" in optim.supported_impls],
1015*da0073e9SAndroid Build Coastguard Worker        dtypes=floating_types_and(
1016*da0073e9SAndroid Build Coastguard Worker            torch.bfloat16,
1017*da0073e9SAndroid Build Coastguard Worker            torch.float16,
1018*da0073e9SAndroid Build Coastguard Worker        ),
1019*da0073e9SAndroid Build Coastguard Worker    )
1020*da0073e9SAndroid Build Coastguard Worker    def test_fused_matches_forloop(self, device, dtype, optim_info):
1021*da0073e9SAndroid Build Coastguard Worker        if _get_device_type(device) not in optim_info.supports_fused_on:
1022*da0073e9SAndroid Build Coastguard Worker            self.skipTest(
1023*da0073e9SAndroid Build Coastguard Worker                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
1024*da0073e9SAndroid Build Coastguard Worker            )
1025*da0073e9SAndroid Build Coastguard Worker        if _get_device_type(device) == "mps" and dtype not in (
1026*da0073e9SAndroid Build Coastguard Worker            torch.float16,
1027*da0073e9SAndroid Build Coastguard Worker            torch.float32,
1028*da0073e9SAndroid Build Coastguard Worker        ):
1029*da0073e9SAndroid Build Coastguard Worker            self.skipTest("MPS supports only torch.float16 and torch.float32")
1030*da0073e9SAndroid Build Coastguard Worker        self._test_derived_optimizers(device, dtype, optim_info, "fused")
1031*da0073e9SAndroid Build Coastguard Worker
1032*da0073e9SAndroid Build Coastguard Worker    @optims(
1033*da0073e9SAndroid Build Coastguard Worker        [optim for optim in optim_db if "fused" in optim.supported_impls],
1034*da0073e9SAndroid Build Coastguard Worker        dtypes=(torch.float32,),
1035*da0073e9SAndroid Build Coastguard Worker    )
1036*da0073e9SAndroid Build Coastguard Worker    def test_fused_error_on_params_on_meta(self, device, dtype, optim_info):
1037*da0073e9SAndroid Build Coastguard Worker        if _get_device_type(device) not in optim_info.supports_fused_on:
1038*da0073e9SAndroid Build Coastguard Worker            self.skipTest(
1039*da0073e9SAndroid Build Coastguard Worker                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
1040*da0073e9SAndroid Build Coastguard Worker            )
1041*da0073e9SAndroid Build Coastguard Worker
1042*da0073e9SAndroid Build Coastguard Worker        with torch.device("meta"):
1043*da0073e9SAndroid Build Coastguard Worker            model = torch.nn.Sequential(
1044*da0073e9SAndroid Build Coastguard Worker                torch.nn.Linear(2, 3),
1045*da0073e9SAndroid Build Coastguard Worker                torch.nn.Sigmoid(),
1046*da0073e9SAndroid Build Coastguard Worker                torch.nn.Linear(3, 1),
1047*da0073e9SAndroid Build Coastguard Worker                torch.nn.Sigmoid(),
1048*da0073e9SAndroid Build Coastguard Worker            ).to(dtype)
1049*da0073e9SAndroid Build Coastguard Worker
1050*da0073e9SAndroid Build Coastguard Worker        optimizer = optim_info.optim_cls(model.parameters(), fused=True)
1051*da0073e9SAndroid Build Coastguard Worker        with torch.device("meta"):
1052*da0073e9SAndroid Build Coastguard Worker            for p in model.parameters():
1053*da0073e9SAndroid Build Coastguard Worker                p.grad = torch.rand_like(p)
1054*da0073e9SAndroid Build Coastguard Worker
1055*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1056*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1057*da0073e9SAndroid Build Coastguard Worker            "`fused=True` requires all the params to be floating point Tensors",
1058*da0073e9SAndroid Build Coastguard Worker        ):
1059*da0073e9SAndroid Build Coastguard Worker            optimizer.step()
1060*da0073e9SAndroid Build Coastguard Worker
1061*da0073e9SAndroid Build Coastguard Worker        optimizer.zero_grad(set_to_none=True)
1062*da0073e9SAndroid Build Coastguard Worker        model.to_empty(device=device)
1063*da0073e9SAndroid Build Coastguard Worker        for p in model.parameters():
1064*da0073e9SAndroid Build Coastguard Worker            p.grad = torch.rand_like(p)
1065*da0073e9SAndroid Build Coastguard Worker        optimizer.step()
1066*da0073e9SAndroid Build Coastguard Worker
1067*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1068*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("64GB")
1069*da0073e9SAndroid Build Coastguard Worker    @optims(
1070*da0073e9SAndroid Build Coastguard Worker        [optim for optim in optim_db if "fused" in optim.supported_impls],
1071*da0073e9SAndroid Build Coastguard Worker        dtypes=[torch.float16],
1072*da0073e9SAndroid Build Coastguard Worker    )
1073*da0073e9SAndroid Build Coastguard Worker    def test_fused_large_tensor(self, device, dtype, optim_info):
1074*da0073e9SAndroid Build Coastguard Worker        if device not in optim_info.supports_fused_on:
1075*da0073e9SAndroid Build Coastguard Worker            self.skipTest(
1076*da0073e9SAndroid Build Coastguard Worker                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
1077*da0073e9SAndroid Build Coastguard Worker            )
1078*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1079*da0073e9SAndroid Build Coastguard Worker        optim_inputs = optim_info.optim_inputs_func(device=device)
1080*da0073e9SAndroid Build Coastguard Worker        for optim_input in optim_inputs:
1081*da0073e9SAndroid Build Coastguard Worker            params = [torch.ones(2**32, device=device, dtype=dtype)]
1082*da0073e9SAndroid Build Coastguard Worker            params[0].grad = torch.zeros_like(params[0])
1083*da0073e9SAndroid Build Coastguard Worker            optimizer = optim_cls(params, fused=True, **optim_input.kwargs)
1084*da0073e9SAndroid Build Coastguard Worker            optimizer.step()
1085*da0073e9SAndroid Build Coastguard Worker
1086*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1087*da0073e9SAndroid Build Coastguard Worker    @optims(
1088*da0073e9SAndroid Build Coastguard Worker        [optim for optim in optim_db if "fused" in optim.supported_impls],
1089*da0073e9SAndroid Build Coastguard Worker        dtypes=[torch.float32],
1090*da0073e9SAndroid Build Coastguard Worker    )
1091*da0073e9SAndroid Build Coastguard Worker    def test_fused_does_not_step_if_foundinf(self, device, dtype, optim_info):
1092*da0073e9SAndroid Build Coastguard Worker        if device not in optim_info.supports_fused_on:
1093*da0073e9SAndroid Build Coastguard Worker            self.skipTest(
1094*da0073e9SAndroid Build Coastguard Worker                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
1095*da0073e9SAndroid Build Coastguard Worker            )
1096*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1097*da0073e9SAndroid Build Coastguard Worker        optim_inputs = optim_info.optim_inputs_func(device=device)
1098*da0073e9SAndroid Build Coastguard Worker        num_params = 5
1099*da0073e9SAndroid Build Coastguard Worker        for optim_input in optim_inputs:
1100*da0073e9SAndroid Build Coastguard Worker            for no_grad_scale in (False, True):
1101*da0073e9SAndroid Build Coastguard Worker                params = [
1102*da0073e9SAndroid Build Coastguard Worker                    torch.ones((1,), device=device, dtype=dtype)
1103*da0073e9SAndroid Build Coastguard Worker                    for _ in range(num_params)
1104*da0073e9SAndroid Build Coastguard Worker                ]
1105*da0073e9SAndroid Build Coastguard Worker                params_c = [param.clone().detach() for param in params]
1106*da0073e9SAndroid Build Coastguard Worker                for p in params:
1107*da0073e9SAndroid Build Coastguard Worker                    p.grad = torch.ones_like(p)
1108*da0073e9SAndroid Build Coastguard Worker                optimizer = optim_cls(params, fused=True, **optim_input.kwargs)
1109*da0073e9SAndroid Build Coastguard Worker                optimizer.grad_scale = (
1110*da0073e9SAndroid Build Coastguard Worker                    None
1111*da0073e9SAndroid Build Coastguard Worker                    if no_grad_scale
1112*da0073e9SAndroid Build Coastguard Worker                    else torch.ones((1,), dtype=dtype, device=device)
1113*da0073e9SAndroid Build Coastguard Worker                )
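                # Simulate AMP having found inf/NaN grads: a nonzero found_inf should make the
                # fused step a no-op (params unchanged and `step` left at 0, as asserted below).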
1114*da0073e9SAndroid Build Coastguard Worker                optimizer.found_inf = torch.ones((), dtype=dtype, device=device)
1115*da0073e9SAndroid Build Coastguard Worker                optimizer.step()
1116*da0073e9SAndroid Build Coastguard Worker                for p in params:
1117*da0073e9SAndroid Build Coastguard Worker                    if "step" in optimizer.state[p]:
1118*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(
1119*da0073e9SAndroid Build Coastguard Worker                            torch.zeros((), dtype=dtype, device=device),
1120*da0073e9SAndroid Build Coastguard Worker                            optimizer.state[p]["step"],
1121*da0073e9SAndroid Build Coastguard Worker                        )
1122*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(params, params_c)
1123*da0073e9SAndroid Build Coastguard Worker
1124*da0073e9SAndroid Build Coastguard Worker    @parametrize("impl", ["fused", "capturable"])
1125*da0073e9SAndroid Build Coastguard Worker    @optims(
1126*da0073e9SAndroid Build Coastguard Worker        [optim for optim in optim_db if "fused" in optim.supported_impls],
1127*da0073e9SAndroid Build Coastguard Worker        dtypes=[torch.float32],
1128*da0073e9SAndroid Build Coastguard Worker    )
1129*da0073e9SAndroid Build Coastguard Worker    def test_cpu_load_state_dict(self, device, dtype, impl, optim_info):
1130*da0073e9SAndroid Build Coastguard Worker        # NOTE: This SIMULATES a fused/capturable optimizer with state moved to CPU, issue 103256
1131*da0073e9SAndroid Build Coastguard Worker        # How do we get there? Users typically train CUDA models with fused optimizers and then,
1132*da0073e9SAndroid Build Coastguard Worker        # since CUDA memory is limited, load their checkpoints back on CPU with torch.load(..., map_location="cpu").
1133*da0073e9SAndroid Build Coastguard Worker        # Since this is a unit test, it is more expedient to simulate what the state_dict
1134*da0073e9SAndroid Build Coastguard Worker        # would look like, which is basically CPU tensors with fused/capturable flag = True.
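        # A rough sketch of the user flow being simulated (file and key names are illustrative):
        #   optimizer = optim_cls(model.parameters(), fused=True)       # trained on CUDA
        #   ckpt = torch.load("checkpoint.pt", map_location="cpu")      # reloaded later on CPU
        #   optimizer.load_state_dict(ckpt["optimizer"])                # CPU tensors, fused=True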
1135*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1136*da0073e9SAndroid Build Coastguard Worker        opt_name = optim_cls.__name__
1137*da0073e9SAndroid Build Coastguard Worker        if opt_name in ("SGD", "Adagrad") and impl == "capturable":
1138*da0073e9SAndroid Build Coastguard Worker            # Capturable SGD/Adagrad does not exist
1139*da0073e9SAndroid Build Coastguard Worker            self.skipTest(f"{opt_name} does not currently support capturable")
1140*da0073e9SAndroid Build Coastguard Worker        if _get_device_type(device) == "cpu":
1141*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Test is only for non-cpu devices")
1142*da0073e9SAndroid Build Coastguard Worker        elif (
1143*da0073e9SAndroid Build Coastguard Worker            impl == "fused"
1144*da0073e9SAndroid Build Coastguard Worker            and _get_device_type(device) not in optim_info.supports_fused_on
1145*da0073e9SAndroid Build Coastguard Worker        ):
1146*da0073e9SAndroid Build Coastguard Worker            self.skipTest(f"{device} is not supported for fused on {opt_name}")
1147*da0073e9SAndroid Build Coastguard Worker        elif impl == "capturable" and _get_device_type(device) == "mps":
1148*da0073e9SAndroid Build Coastguard Worker            self.skipTest("MPS does not support capturable")
1149*da0073e9SAndroid Build Coastguard Worker
1150*da0073e9SAndroid Build Coastguard Worker        cpu_optim_inputs = optim_info.optim_inputs_func(device="cpu")
1151*da0073e9SAndroid Build Coastguard Worker        for optim_input in cpu_optim_inputs:
1152*da0073e9SAndroid Build Coastguard Worker            param = torch.tensor([0.1, 0.2], dtype=dtype, device="cpu")
1153*da0073e9SAndroid Build Coastguard Worker            optimizer = optim_cls([param], **optim_input.kwargs)
1154*da0073e9SAndroid Build Coastguard Worker            param.grad = torch.rand_like(param)
1155*da0073e9SAndroid Build Coastguard Worker            optimizer.step()
1156*da0073e9SAndroid Build Coastguard Worker            optim_state_dict_cpu = deepcopy(optimizer.state_dict())
1157*da0073e9SAndroid Build Coastguard Worker            optim_state_dict_cpu["param_groups"][0][impl] = True
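            # Flip the fused/capturable flag in the CPU state dict to mimic the checkpoint
            # described in the NOTE above.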
1158*da0073e9SAndroid Build Coastguard Worker
1159*da0073e9SAndroid Build Coastguard Worker            # load
1160*da0073e9SAndroid Build Coastguard Worker            optim_input.kwargs[impl] = True
1161*da0073e9SAndroid Build Coastguard Worker            param_device = param.clone().detach().to(device=device)
1162*da0073e9SAndroid Build Coastguard Worker            optimizer_device = optim_cls([param_device], **optim_input.kwargs)
1163*da0073e9SAndroid Build Coastguard Worker            optimizer_device.load_state_dict(optim_state_dict_cpu)
1164*da0073e9SAndroid Build Coastguard Worker            optimizer_device.zero_grad()
1165*da0073e9SAndroid Build Coastguard Worker            param_device.grad = torch.rand_like(param_device)
1166*da0073e9SAndroid Build Coastguard Worker            optimizer_device.step()
1167*da0073e9SAndroid Build Coastguard Worker
1168*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1169*da0073e9SAndroid Build Coastguard Worker    def test_param_groups_weight_decay(self, device, dtype, optim_info):
1170*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1171*da0073e9SAndroid Build Coastguard Worker        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
1172*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1173*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info, skip=("differentiable",)
1174*da0073e9SAndroid Build Coastguard Worker        )
1175*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
1176*da0073e9SAndroid Build Coastguard Worker            weight_kwargs = optim_input.kwargs
1177*da0073e9SAndroid Build Coastguard Worker            bias_kwargs = deepcopy(optim_input.kwargs)
1178*da0073e9SAndroid Build Coastguard Worker            bias_kwargs["weight_decay"] = 0.0
1179*da0073e9SAndroid Build Coastguard Worker
1180*da0073e9SAndroid Build Coastguard Worker            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
1181*da0073e9SAndroid Build Coastguard Worker            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
1182*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(5, device=device, dtype=dtype)
1183*da0073e9SAndroid Build Coastguard Worker
1184*da0073e9SAndroid Build Coastguard Worker            optimizer = optim_cls(
1185*da0073e9SAndroid Build Coastguard Worker                [
1186*da0073e9SAndroid Build Coastguard Worker                    dict(params=[weight], **weight_kwargs),
1187*da0073e9SAndroid Build Coastguard Worker                    dict(params=[bias], **bias_kwargs),
1188*da0073e9SAndroid Build Coastguard Worker                ]
1189*da0073e9SAndroid Build Coastguard Worker            )
1190*da0073e9SAndroid Build Coastguard Worker
1191*da0073e9SAndroid Build Coastguard Worker            loss = (weight.mv(input) + bias).pow(2).sum()
1192*da0073e9SAndroid Build Coastguard Worker            initial_value = loss.item()
1193*da0073e9SAndroid Build Coastguard Worker            for _ in range(20):
1194*da0073e9SAndroid Build Coastguard Worker                optimizer.zero_grad()
1195*da0073e9SAndroid Build Coastguard Worker                loss = (weight.mv(input) + bias).pow(2).sum()
1196*da0073e9SAndroid Build Coastguard Worker                loss.backward()
1197*da0073e9SAndroid Build Coastguard Worker                if optim_info.only_supports_sparse_grads:
1198*da0073e9SAndroid Build Coastguard Worker                    # For this test, we naively convert the Tensor layout, which we know does
1199*da0073e9SAndroid Build Coastguard Worker                    # NOT represent the expected use case for optims like SparseAdam!
1200*da0073e9SAndroid Build Coastguard Worker                    weight.grad = weight.grad.to_sparse()
1201*da0073e9SAndroid Build Coastguard Worker                    bias.grad = bias.grad.to_sparse()
1202*da0073e9SAndroid Build Coastguard Worker                optimizer.step()
1203*da0073e9SAndroid Build Coastguard Worker
1204*da0073e9SAndroid Build Coastguard Worker            # Test that the direction of loss moved appropriately
1205*da0073e9SAndroid Build Coastguard Worker            if optim_input.kwargs.get("maximize", False):
1206*da0073e9SAndroid Build Coastguard Worker                self.assertGreater(loss.item(), initial_value)
1207*da0073e9SAndroid Build Coastguard Worker            else:
1208*da0073e9SAndroid Build Coastguard Worker                self.assertLess(loss.item(), initial_value)
1209*da0073e9SAndroid Build Coastguard Worker
1210*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1211*da0073e9SAndroid Build Coastguard Worker    def test_param_groups_lr(self, device, dtype, optim_info):
1212*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1213*da0073e9SAndroid Build Coastguard Worker        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
1214*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1215*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info, skip=("differentiable",)
1216*da0073e9SAndroid Build Coastguard Worker        )
1217*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
1218*da0073e9SAndroid Build Coastguard Worker            # optim_input.kwargs will be the param group kwargs, which should have >0 lr
1219*da0073e9SAndroid Build Coastguard Worker            if "lr" not in optim_input.kwargs or optim_input.kwargs["lr"] == 0:
1220*da0073e9SAndroid Build Coastguard Worker                optim_input.kwargs["lr"] = 1e-3
1221*da0073e9SAndroid Build Coastguard Worker            outer_kwargs = {"lr": 1e-28}
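            # The outer (default) lr is vanishingly small so params in the second, default-only
            # group should not move; this is checked against irrelevant_clone at the end.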
1222*da0073e9SAndroid Build Coastguard Worker            if optim_cls.__name__ == "Rprop":
1223*da0073e9SAndroid Build Coastguard Worker                # Allow min step size to be 0
1224*da0073e9SAndroid Build Coastguard Worker                outer_kwargs["step_sizes"] = (0, 50)
1225*da0073e9SAndroid Build Coastguard Worker
1226*da0073e9SAndroid Build Coastguard Worker            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
1227*da0073e9SAndroid Build Coastguard Worker            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
1228*da0073e9SAndroid Build Coastguard Worker            irrelevant = Parameter(torch.randn(2, device=device, dtype=dtype))
1229*da0073e9SAndroid Build Coastguard Worker            irrelevant_clone = irrelevant.clone()
1230*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(5, device=device, dtype=dtype)
1231*da0073e9SAndroid Build Coastguard Worker            optimizer = optim_cls(
1232*da0073e9SAndroid Build Coastguard Worker                [
1233*da0073e9SAndroid Build Coastguard Worker                    dict(params=[weight, bias], **optim_input.kwargs),
1234*da0073e9SAndroid Build Coastguard Worker                    dict(params=[irrelevant]),
1235*da0073e9SAndroid Build Coastguard Worker                ],
1236*da0073e9SAndroid Build Coastguard Worker                **outer_kwargs,
1237*da0073e9SAndroid Build Coastguard Worker            )
1238*da0073e9SAndroid Build Coastguard Worker
1239*da0073e9SAndroid Build Coastguard Worker            loss = (weight.mv(input) + bias).pow(2).sum()
1240*da0073e9SAndroid Build Coastguard Worker            initial_value = loss.item()
1241*da0073e9SAndroid Build Coastguard Worker            for _ in range(20):
1242*da0073e9SAndroid Build Coastguard Worker                optimizer.zero_grad()
1243*da0073e9SAndroid Build Coastguard Worker                loss = (weight.mv(input) + bias).pow(2).sum()
1244*da0073e9SAndroid Build Coastguard Worker                loss.backward()
1245*da0073e9SAndroid Build Coastguard Worker                irrelevant.grad = torch.rand_like(irrelevant)
1246*da0073e9SAndroid Build Coastguard Worker                if optim_info.only_supports_sparse_grads:
1247*da0073e9SAndroid Build Coastguard Worker                    # For this test, we naively convert the Tensor layout, which we know does
1248*da0073e9SAndroid Build Coastguard Worker                    # NOT represent the expected use case for optims like SparseAdam!
1249*da0073e9SAndroid Build Coastguard Worker                    weight.grad = weight.grad.to_sparse()
1250*da0073e9SAndroid Build Coastguard Worker                    bias.grad = bias.grad.to_sparse()
1251*da0073e9SAndroid Build Coastguard Worker                    irrelevant.grad = irrelevant.grad.to_sparse()
1252*da0073e9SAndroid Build Coastguard Worker                optimizer.step()
1253*da0073e9SAndroid Build Coastguard Worker
1254*da0073e9SAndroid Build Coastguard Worker            # Test that the direction of loss moved appropriately
1255*da0073e9SAndroid Build Coastguard Worker            if optim_input.kwargs.get("maximize", False):
1256*da0073e9SAndroid Build Coastguard Worker                self.assertGreater(loss.item(), initial_value)
1257*da0073e9SAndroid Build Coastguard Worker            else:
1258*da0073e9SAndroid Build Coastguard Worker                self.assertLess(loss.item(), initial_value)
1259*da0073e9SAndroid Build Coastguard Worker
1260*da0073e9SAndroid Build Coastguard Worker            # Test that irrelevant parameters were not updated since lr was almost 0
1261*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(irrelevant, irrelevant_clone)
1262*da0073e9SAndroid Build Coastguard Worker
1263*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1264*da0073e9SAndroid Build Coastguard Worker    def test_step_is_noop_when_params_have_no_grad(self, device, dtype, optim_info):
1265*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1266*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1267*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info
1268*da0073e9SAndroid Build Coastguard Worker        )
1269*da0073e9SAndroid Build Coastguard Worker        params = [
1270*da0073e9SAndroid Build Coastguard Worker            torch.randn(2, 3, requires_grad=False, device=device, dtype=dtype)
1271*da0073e9SAndroid Build Coastguard Worker            for _ in range(2)
1272*da0073e9SAndroid Build Coastguard Worker        ]
1273*da0073e9SAndroid Build Coastguard Worker        old_params = [p.clone().detach() for p in params]
1274*da0073e9SAndroid Build Coastguard Worker
1275*da0073e9SAndroid Build Coastguard Worker        def closure():
1276*da0073e9SAndroid Build Coastguard Worker            return torch.tensor([1], device=device, dtype=dtype)
1277*da0073e9SAndroid Build Coastguard Worker
1278*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
1279*da0073e9SAndroid Build Coastguard Worker            optimizer = optim_cls(params, **optim_input.kwargs)
1280*da0073e9SAndroid Build Coastguard Worker            optimizer.step(closure)
            # step() must be a no-op: params that never had grads should be untouched
            self.assertEqual(old_params, params)
1281*da0073e9SAndroid Build Coastguard Worker
1282*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1283*da0073e9SAndroid Build Coastguard Worker    def test_step_is_noop_for_zero_grads(self, device, dtype, optim_info):
1284*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1285*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1286*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info
1287*da0073e9SAndroid Build Coastguard Worker        )
1288*da0073e9SAndroid Build Coastguard Worker        param = torch.randn((5, 1), device=device, dtype=dtype, requires_grad=True)
1289*da0073e9SAndroid Build Coastguard Worker        old_param = param.clone().detach()
1290*da0073e9SAndroid Build Coastguard Worker
1291*da0073e9SAndroid Build Coastguard Worker        def closure():
1292*da0073e9SAndroid Build Coastguard Worker            return torch.tensor([1], device=device, dtype=dtype)
1293*da0073e9SAndroid Build Coastguard Worker
1294*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
1295*da0073e9SAndroid Build Coastguard Worker            kwargs = optim_input.kwargs
1296*da0073e9SAndroid Build Coastguard Worker
1297*da0073e9SAndroid Build Coastguard Worker            # Params will decay even when grads are empty if weight_decay != 0
1298*da0073e9SAndroid Build Coastguard Worker            # (and capturable doesn't work for CPU tensors), so skip those configs
1299*da0073e9SAndroid Build Coastguard Worker            if kwargs.get("weight_decay", 0) != 0:
1300*da0073e9SAndroid Build Coastguard Worker                continue
1301*da0073e9SAndroid Build Coastguard Worker
1302*da0073e9SAndroid Build Coastguard Worker            # AdamW decays params by lr * weight_decay (default 1e-2) even with zero grads, so make lr smaller
1303*da0073e9SAndroid Build Coastguard Worker            if optim_cls.__name__ == "AdamW":
1304*da0073e9SAndroid Build Coastguard Worker                kwargs["lr"] = (
1305*da0073e9SAndroid Build Coastguard Worker                    torch.tensor(1e-5)
1306*da0073e9SAndroid Build Coastguard Worker                    if isinstance(kwargs.get("lr", 1e-5), torch.Tensor)
1307*da0073e9SAndroid Build Coastguard Worker                    else 1e-5
1308*da0073e9SAndroid Build Coastguard Worker                )
1309*da0073e9SAndroid Build Coastguard Worker
1310*da0073e9SAndroid Build Coastguard Worker            if kwargs.get("differentiable", False):
1311*da0073e9SAndroid Build Coastguard Worker                params = [param.clone()]
1312*da0073e9SAndroid Build Coastguard Worker            else:
1313*da0073e9SAndroid Build Coastguard Worker                params = [param]
1314*da0073e9SAndroid Build Coastguard Worker
1315*da0073e9SAndroid Build Coastguard Worker            optimizer = optim_cls(params, **kwargs)
1316*da0073e9SAndroid Build Coastguard Worker            if optim_info.only_supports_sparse_grads:
1317*da0073e9SAndroid Build Coastguard Worker                # Intentionally construct a multidimensional empty v for the sparse grad.
1318*da0073e9SAndroid Build Coastguard Worker                # A single-dim v passes the test, while a multidim v correctly repros the issue:
1319*da0073e9SAndroid Build Coastguard Worker                # https://github.com/pytorch/pytorch/issues/82486
1320*da0073e9SAndroid Build Coastguard Worker                i = torch.empty((1, 0), device=device, dtype=dtype)
1321*da0073e9SAndroid Build Coastguard Worker                v = torch.empty((0, 1), device=device, dtype=dtype)
1322*da0073e9SAndroid Build Coastguard Worker                params[0].grad = torch.sparse_coo_tensor(
1323*da0073e9SAndroid Build Coastguard Worker                    i, v, (5, 1), device=device, dtype=dtype
1324*da0073e9SAndroid Build Coastguard Worker                )
1325*da0073e9SAndroid Build Coastguard Worker            else:
1326*da0073e9SAndroid Build Coastguard Worker                params[0].grad = torch.zeros_like(params[0])
1327*da0073e9SAndroid Build Coastguard Worker            optimizer.step(closure)
1328*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(old_param, params[0])
1329*da0073e9SAndroid Build Coastguard Worker
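    # --- Illustrative sketch (hypothetical helper, not an upstream test) ---
    # A minimal, self-contained example of the property exercised above: with a
    # zero grad and no weight decay, a plain SGD step should leave the param
    # unchanged. The helper name is an addition for illustration only.
    @staticmethod
    def _sketch_zero_grad_step_is_noop() -> None:
        param = torch.randn(5, 1, requires_grad=True)
        ref = param.detach().clone()
        opt = SGD([param], lr=0.1)  # weight_decay defaults to 0
        param.grad = torch.zeros_like(param)
        opt.step()
        assert torch.equal(param.detach(), ref)
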
1330*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1331*da0073e9SAndroid Build Coastguard Worker    def test_optimizer_can_be_printed(self, device, dtype, optim_info):
1332*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1333*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1334*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info
1335*da0073e9SAndroid Build Coastguard Worker        )
1336*da0073e9SAndroid Build Coastguard Worker        params = [
1337*da0073e9SAndroid Build Coastguard Worker            Parameter(torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype))
1338*da0073e9SAndroid Build Coastguard Worker            for _ in range(2)
1339*da0073e9SAndroid Build Coastguard Worker        ]
1340*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
1341*da0073e9SAndroid Build Coastguard Worker            optimizer = optim_cls(params, **optim_input.kwargs)
1342*da0073e9SAndroid Build Coastguard Worker            repr(optimizer)  # smoke test: should not raise
1343*da0073e9SAndroid Build Coastguard Worker
1344*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1345*da0073e9SAndroid Build Coastguard Worker    def test_state_dict_deterministic(self, device, dtype, optim_info):
1346*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1347*da0073e9SAndroid Build Coastguard Worker
1348*da0073e9SAndroid Build Coastguard Worker        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
1349*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1350*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info, skip=("differentiable",)
1351*da0073e9SAndroid Build Coastguard Worker        )
1352*da0073e9SAndroid Build Coastguard Worker        weight = Parameter(
1353*da0073e9SAndroid Build Coastguard Worker            torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)
1354*da0073e9SAndroid Build Coastguard Worker        )
1355*da0073e9SAndroid Build Coastguard Worker        bias = Parameter(torch.randn(2, requires_grad=True, device=device, dtype=dtype))
1356*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
1357*da0073e9SAndroid Build Coastguard Worker        params = [weight, bias]
1358*da0073e9SAndroid Build Coastguard Worker
1359*da0073e9SAndroid Build Coastguard Worker        def fwd_bwd(optim, w, b, i):
1360*da0073e9SAndroid Build Coastguard Worker            optim.zero_grad()
1361*da0073e9SAndroid Build Coastguard Worker            loss = (w.mv(i) + b).pow(2).sum()
1362*da0073e9SAndroid Build Coastguard Worker            loss.backward()
1363*da0073e9SAndroid Build Coastguard Worker            if optim_info.only_supports_sparse_grads:
1364*da0073e9SAndroid Build Coastguard Worker                if w.grad is not None:
1365*da0073e9SAndroid Build Coastguard Worker                    w.grad = w.grad.to_sparse()
1366*da0073e9SAndroid Build Coastguard Worker                if b.grad is not None:
1367*da0073e9SAndroid Build Coastguard Worker                    b.grad = b.grad.to_sparse()
1368*da0073e9SAndroid Build Coastguard Worker            return loss
1369*da0073e9SAndroid Build Coastguard Worker
1370*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
1371*da0073e9SAndroid Build Coastguard Worker            optimizer = optim_cls(params, **optim_input.kwargs)
1372*da0073e9SAndroid Build Coastguard Worker            closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)
1373*da0073e9SAndroid Build Coastguard Worker
1374*da0073e9SAndroid Build Coastguard Worker            # Prime the optimizer
1375*da0073e9SAndroid Build Coastguard Worker            for _ in range(10):
1376*da0073e9SAndroid Build Coastguard Worker                if optim_info.step_requires_closure:
1377*da0073e9SAndroid Build Coastguard Worker                    optimizer.step(closure)
1378*da0073e9SAndroid Build Coastguard Worker                else:
1379*da0073e9SAndroid Build Coastguard Worker                    closure()
1380*da0073e9SAndroid Build Coastguard Worker                    optimizer.step()
1381*da0073e9SAndroid Build Coastguard Worker
1382*da0073e9SAndroid Build Coastguard Worker            # Clone the weights and construct a new optimizer for them
1383*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
1384*da0073e9SAndroid Build Coastguard Worker                weight_c = Parameter(weight.clone())
1385*da0073e9SAndroid Build Coastguard Worker                bias_c = Parameter(bias.clone())
1386*da0073e9SAndroid Build Coastguard Worker
1387*da0073e9SAndroid Build Coastguard Worker            optimizer_c = optim_cls([weight_c, bias_c], **optim_input.kwargs)
1388*da0073e9SAndroid Build Coastguard Worker            closure_c = functools.partial(fwd_bwd, optimizer_c, weight_c, bias_c, input)
1389*da0073e9SAndroid Build Coastguard Worker
1390*da0073e9SAndroid Build Coastguard Worker            # Load the state dict from the original optimizer into the new one
1391*da0073e9SAndroid Build Coastguard Worker            optimizer_c.load_state_dict(deepcopy(optimizer.state_dict()))
1392*da0073e9SAndroid Build Coastguard Worker
1393*da0073e9SAndroid Build Coastguard Worker            # Run both optimizers in parallel
1394*da0073e9SAndroid Build Coastguard Worker            for _ in range(10):
1395*da0073e9SAndroid Build Coastguard Worker                if optim_info.step_requires_closure:
1396*da0073e9SAndroid Build Coastguard Worker                    optimizer.step(closure)
1397*da0073e9SAndroid Build Coastguard Worker                    optimizer_c.step(closure_c)
1398*da0073e9SAndroid Build Coastguard Worker                else:
1399*da0073e9SAndroid Build Coastguard Worker                    closure()
1400*da0073e9SAndroid Build Coastguard Worker                    closure_c()
1401*da0073e9SAndroid Build Coastguard Worker                    optimizer.step()
1402*da0073e9SAndroid Build Coastguard Worker                    optimizer_c.step()
1403*da0073e9SAndroid Build Coastguard Worker
1404*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(weight, weight_c)
1405*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(bias, bias_c)
1406*da0073e9SAndroid Build Coastguard Worker
1407*da0073e9SAndroid Build Coastguard Worker            # Make sure state dict is deterministic with equal (not identical) parameters
1408*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
1409*da0073e9SAndroid Build Coastguard Worker
1410*da0073e9SAndroid Build Coastguard Worker            # Make sure repeated parameters have identical representation (see #36831)
1411*da0073e9SAndroid Build Coastguard Worker            optimizer_c.param_groups.extend(optimizer_c.param_groups)
1412*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1413*da0073e9SAndroid Build Coastguard Worker                optimizer.state_dict()["param_groups"][-1],
1414*da0073e9SAndroid Build Coastguard Worker                optimizer_c.state_dict()["param_groups"][-1],
1415*da0073e9SAndroid Build Coastguard Worker            )
1416*da0073e9SAndroid Build Coastguard Worker
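    # --- Illustrative sketch (hypothetical helper, not an upstream test) ---
    # The round-trip pattern the test above relies on: loading a deep-copied
    # state_dict into a second optimizer over cloned params should keep both
    # optimizers in lockstep, assuming they see identical grads each step.
    @staticmethod
    def _sketch_state_dict_round_trip() -> None:
        w = Parameter(torch.randn(2, 3))
        w_c = Parameter(w.detach().clone())
        opt = SGD([w], lr=0.1, momentum=0.9)
        opt_c = SGD([w_c], lr=0.1, momentum=0.9)
        for _ in range(3):
            opt.zero_grad()
            w.sum().backward()
            opt.step()
        opt_c.load_state_dict(deepcopy(opt.state_dict()))
        for _ in range(3):
            for o, p in ((opt, w), (opt_c, w_c)):
                o.zero_grad()
                p.sum().backward()
                o.step()
        assert torch.allclose(w, w_c)
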
1417*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1418*da0073e9SAndroid Build Coastguard Worker    def test_can_load_older_state_dict(self, device, dtype, optim_info):
1419*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1420*da0073e9SAndroid Build Coastguard Worker
1421*da0073e9SAndroid Build Coastguard Worker        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
1422*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1423*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info, skip=("differentiable",)
1424*da0073e9SAndroid Build Coastguard Worker        )
1425*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
1426*da0073e9SAndroid Build Coastguard Worker            torch.manual_seed(1)
1427*da0073e9SAndroid Build Coastguard Worker            model = torch.nn.Sequential(
1428*da0073e9SAndroid Build Coastguard Worker                torch.nn.Conv2d(4, 2, 1, stride=2),
1429*da0073e9SAndroid Build Coastguard Worker                torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1),
1430*da0073e9SAndroid Build Coastguard Worker            )
1431*da0073e9SAndroid Build Coastguard Worker            model.to(dtype=dtype, device=device)
1432*da0073e9SAndroid Build Coastguard Worker            input = torch.rand(1, 4, 16, 16, device=device, dtype=dtype)
1433*da0073e9SAndroid Build Coastguard Worker            optimizer = optim_cls(model.parameters(), **optim_input.kwargs)
1434*da0073e9SAndroid Build Coastguard Worker
1435*da0073e9SAndroid Build Coastguard Worker            def fwd_bwd(optim, mod, i):
1436*da0073e9SAndroid Build Coastguard Worker                optim.zero_grad()
1437*da0073e9SAndroid Build Coastguard Worker                loss = mod(i).sum()
1438*da0073e9SAndroid Build Coastguard Worker                loss.backward()
1439*da0073e9SAndroid Build Coastguard Worker                return loss
1440*da0073e9SAndroid Build Coastguard Worker
1441*da0073e9SAndroid Build Coastguard Worker            for _ in range(3):
1442*da0073e9SAndroid Build Coastguard Worker                if optim_info.step_requires_closure:
1443*da0073e9SAndroid Build Coastguard Worker                    optimizer.step(functools.partial(fwd_bwd, optimizer, model, input))
1444*da0073e9SAndroid Build Coastguard Worker                else:
1445*da0073e9SAndroid Build Coastguard Worker                    fwd_bwd(optimizer, model, input)
1446*da0073e9SAndroid Build Coastguard Worker                    optimizer.step()
1447*da0073e9SAndroid Build Coastguard Worker
1448*da0073e9SAndroid Build Coastguard Worker            # Simulate an older state_dict by deleting all of the newer flags from its param_groups
1449*da0073e9SAndroid Build Coastguard Worker            old_state_dict = deepcopy(optimizer.state_dict())
1450*da0073e9SAndroid Build Coastguard Worker            old_state_dict_pg = old_state_dict["param_groups"]
1451*da0073e9SAndroid Build Coastguard Worker            for group in old_state_dict_pg:
1452*da0073e9SAndroid Build Coastguard Worker                for flag in optim_info.not_og_supported_flags:
1453*da0073e9SAndroid Build Coastguard Worker                    if flag in group:
1454*da0073e9SAndroid Build Coastguard Worker                        del group[flag]
1455*da0073e9SAndroid Build Coastguard Worker
1456*da0073e9SAndroid Build Coastguard Worker            optimizer.load_state_dict(old_state_dict)
1457*da0073e9SAndroid Build Coastguard Worker
1458*da0073e9SAndroid Build Coastguard Worker            # Make sure we can still step
1459*da0073e9SAndroid Build Coastguard Worker            if optim_info.step_requires_closure:
1460*da0073e9SAndroid Build Coastguard Worker                optimizer.step(functools.partial(fwd_bwd, optimizer, model, input))
1461*da0073e9SAndroid Build Coastguard Worker            else:
1462*da0073e9SAndroid Build Coastguard Worker                fwd_bwd(optimizer, model, input)
1463*da0073e9SAndroid Build Coastguard Worker                optimizer.step()
1464*da0073e9SAndroid Build Coastguard Worker
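    # --- Illustrative sketch (hypothetical helper, not an upstream test) ---
    # Simulating an "older" checkpoint for a single optimizer: deleting a newer
    # param-group flag (here "maximize") and then loading should still work,
    # assuming the optimizer's __setstate__ backfills a default for the missing
    # flag, which is the behavior the test above depends on.
    @staticmethod
    def _sketch_load_older_state_dict() -> None:
        param = Parameter(torch.randn(2, 3))
        opt = SGD([param], lr=0.1)
        old_sd = deepcopy(opt.state_dict())
        old_sd["param_groups"][0].pop("maximize", None)  # pretend an old checkpoint
        opt.load_state_dict(old_sd)
        param.grad = torch.ones_like(param)
        opt.step()  # should not raise
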
1465*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1466*da0073e9SAndroid Build Coastguard Worker    def test_save_load_equality_with_weights_only(self, device, dtype, optim_info):
1467*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1468*da0073e9SAndroid Build Coastguard Worker
1469*da0073e9SAndroid Build Coastguard Worker        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
1470*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1471*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info, skip=("differentiable",)
1472*da0073e9SAndroid Build Coastguard Worker        )
1473*da0073e9SAndroid Build Coastguard Worker        weight = Parameter(
1474*da0073e9SAndroid Build Coastguard Worker            torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)
1475*da0073e9SAndroid Build Coastguard Worker        )
1476*da0073e9SAndroid Build Coastguard Worker        bias = Parameter(torch.randn(2, requires_grad=True, device=device, dtype=dtype))
1477*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
1478*da0073e9SAndroid Build Coastguard Worker        params = [weight, bias]
1479*da0073e9SAndroid Build Coastguard Worker
1480*da0073e9SAndroid Build Coastguard Worker        def fwd_bwd(optim, w, b, i):
1481*da0073e9SAndroid Build Coastguard Worker            optim.zero_grad()
1482*da0073e9SAndroid Build Coastguard Worker            loss = (w.mv(i) + b).pow(2).sum()
1483*da0073e9SAndroid Build Coastguard Worker            loss.backward()
1484*da0073e9SAndroid Build Coastguard Worker            if optim_info.only_supports_sparse_grads:
1485*da0073e9SAndroid Build Coastguard Worker                weight.grad = weight.grad.to_sparse()
1486*da0073e9SAndroid Build Coastguard Worker                bias.grad = bias.grad.to_sparse()
1487*da0073e9SAndroid Build Coastguard Worker            return loss
1488*da0073e9SAndroid Build Coastguard Worker
1489*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
1490*da0073e9SAndroid Build Coastguard Worker            optimizer = optim_cls(params, **optim_input.kwargs)
1491*da0073e9SAndroid Build Coastguard Worker            closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)
1492*da0073e9SAndroid Build Coastguard Worker
1493*da0073e9SAndroid Build Coastguard Worker            # Prime the optimizer
1494*da0073e9SAndroid Build Coastguard Worker            for _ in range(3):
1495*da0073e9SAndroid Build Coastguard Worker                optimizer.step(closure)
1496*da0073e9SAndroid Build Coastguard Worker
1497*da0073e9SAndroid Build Coastguard Worker            sd = optimizer.state_dict()
1498*da0073e9SAndroid Build Coastguard Worker
1499*da0073e9SAndroid Build Coastguard Worker            # === Check saved/loaded state_dict are the same (including weights_only load). ===
1500*da0073e9SAndroid Build Coastguard Worker            with tempfile.TemporaryFile() as f:
1501*da0073e9SAndroid Build Coastguard Worker                torch.save(sd, f)
1502*da0073e9SAndroid Build Coastguard Worker                f.seek(0)
1503*da0073e9SAndroid Build Coastguard Worker                sd_copy = torch.load(f)
1504*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sd_copy, sd)
1505*da0073e9SAndroid Build Coastguard Worker                del sd_copy
1506*da0073e9SAndroid Build Coastguard Worker                f.seek(0)
1507*da0073e9SAndroid Build Coastguard Worker                sd_copy_wo = torch.load(f, weights_only=True)
1508*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sd_copy_wo, sd)
1509*da0073e9SAndroid Build Coastguard Worker
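    # --- Illustrative sketch (hypothetical helper, not an upstream test) ---
    # The serialization pattern checked above: an optimizer state_dict is made
    # of plain containers, numbers, and tensors, so it round-trips through
    # torch.save / torch.load even with weights_only=True.
    @staticmethod
    def _sketch_save_load_weights_only() -> None:
        param = Parameter(torch.randn(2, 3))
        opt = SGD([param], lr=0.1, momentum=0.9)
        param.grad = torch.ones_like(param)
        opt.step()
        sd = opt.state_dict()
        with tempfile.TemporaryFile() as f:
            torch.save(sd, f)
            f.seek(0)
            loaded = torch.load(f, weights_only=True)
        assert loaded["param_groups"][0]["lr"] == 0.1
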
1510*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1511*da0073e9SAndroid Build Coastguard Worker    def test_load_nontensor_step(self, device, dtype, optim_info):
1512*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1513*da0073e9SAndroid Build Coastguard Worker
1514*da0073e9SAndroid Build Coastguard Worker        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
1515*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1516*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info, skip=("differentiable",)
1517*da0073e9SAndroid Build Coastguard Worker        )
1518*da0073e9SAndroid Build Coastguard Worker        params = [
1519*da0073e9SAndroid Build Coastguard Worker            Parameter(torch.randn(2, 3, device=device, dtype=dtype)) for _ in range(2)
1520*da0073e9SAndroid Build Coastguard Worker        ]
1521*da0073e9SAndroid Build Coastguard Worker        for p in params:
1522*da0073e9SAndroid Build Coastguard Worker            p.grad = torch.rand_like(p)
1523*da0073e9SAndroid Build Coastguard Worker            if optim_info.only_supports_sparse_grads:
1524*da0073e9SAndroid Build Coastguard Worker                # For this test, we naively convert the Tensor layout, which we know does
1525*da0073e9SAndroid Build Coastguard Worker                # NOT represent the expected use case for optims like SparseAdam!
1526*da0073e9SAndroid Build Coastguard Worker                p.grad = p.grad.to_sparse()
1527*da0073e9SAndroid Build Coastguard Worker
1528*da0073e9SAndroid Build Coastguard Worker        # Needed for second order optims like LBFGS
1529*da0073e9SAndroid Build Coastguard Worker        closure_loss = torch.rand(1, device=device, dtype=dtype)
1530*da0073e9SAndroid Build Coastguard Worker
1531*da0073e9SAndroid Build Coastguard Worker        def closure():
1532*da0073e9SAndroid Build Coastguard Worker            return closure_loss if optim_info.step_requires_closure else None
1533*da0073e9SAndroid Build Coastguard Worker
1534*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
1535*da0073e9SAndroid Build Coastguard Worker            kwargs = optim_input.kwargs
1536*da0073e9SAndroid Build Coastguard Worker            optimizer = optim_cls(params, **optim_input.kwargs)
1537*da0073e9SAndroid Build Coastguard Worker            for _ in range(3):
1538*da0073e9SAndroid Build Coastguard Worker                optimizer.step(closure)
1539*da0073e9SAndroid Build Coastguard Worker            state_dict = deepcopy(optimizer.state_dict())
1540*da0073e9SAndroid Build Coastguard Worker            for p_state in state_dict["state"].values():
1541*da0073e9SAndroid Build Coastguard Worker                if "step" in p_state and torch.is_tensor(p_state["step"]):
1542*da0073e9SAndroid Build Coastguard Worker                    p_state["step"] = p_state["step"].item()
1543*da0073e9SAndroid Build Coastguard Worker            optimizer.load_state_dict(state_dict)
1544*da0073e9SAndroid Build Coastguard Worker            optimizer.step(closure)
1545*da0073e9SAndroid Build Coastguard Worker
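    # --- Illustrative sketch (hypothetical helper, not an upstream test) ---
    # What the test above exercises, written out for a single optimizer: older
    # checkpoints stored "step" as a Python number rather than a tensor, and
    # load_state_dict is expected to accept either form.
    @staticmethod
    def _sketch_load_nontensor_step() -> None:
        param = Parameter(torch.randn(2, 3))
        opt = torch.optim.Adam([param], lr=1e-3)
        param.grad = torch.ones_like(param)
        opt.step()
        sd = deepcopy(opt.state_dict())
        for p_state in sd["state"].values():
            if torch.is_tensor(p_state.get("step", None)):
                p_state["step"] = p_state["step"].item()  # back to a plain number
        opt.load_state_dict(sd)
        opt.step()  # should not raise
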
1546*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1547*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1548*da0073e9SAndroid Build Coastguard Worker    def test_state_dict_with_cuda_params(self, device, dtype, optim_info):
1549*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1550*da0073e9SAndroid Build Coastguard Worker
1551*da0073e9SAndroid Build Coastguard Worker        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
1552*da0073e9SAndroid Build Coastguard Worker        # We limit our configs to CPU only, because we will be moving them to CUDA later
1553*da0073e9SAndroid Build Coastguard Worker        cpu_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1554*da0073e9SAndroid Build Coastguard Worker            "cpu", dtype, optim_info, skip=("differentiable",)
1555*da0073e9SAndroid Build Coastguard Worker        )
1556*da0073e9SAndroid Build Coastguard Worker
1557*da0073e9SAndroid Build Coastguard Worker        # Needed for second order optims like LBFGS
1558*da0073e9SAndroid Build Coastguard Worker        closure_loss = torch.rand(1, device=device, dtype=dtype)
1559*da0073e9SAndroid Build Coastguard Worker
1560*da0073e9SAndroid Build Coastguard Worker        def closure():
1561*da0073e9SAndroid Build Coastguard Worker            return closure_loss if optim_info.step_requires_closure else None
1562*da0073e9SAndroid Build Coastguard Worker
1563*da0073e9SAndroid Build Coastguard Worker        for optim_input in cpu_optim_inputs:
1564*da0073e9SAndroid Build Coastguard Worker            if (
1565*da0073e9SAndroid Build Coastguard Worker                "fused" in optim_input.kwargs
1566*da0073e9SAndroid Build Coastguard Worker                and "cuda" not in optim_info.supports_fused_on
1567*da0073e9SAndroid Build Coastguard Worker            ):
1568*da0073e9SAndroid Build Coastguard Worker                self.skipTest(
1569*da0073e9SAndroid Build Coastguard Worker                    f"fused is not supported on cuda for {optim_cls.__name__}"
1570*da0073e9SAndroid Build Coastguard Worker                )
1571*da0073e9SAndroid Build Coastguard Worker            params = [
1572*da0073e9SAndroid Build Coastguard Worker                Parameter(torch.randn(2, 3, device="cpu", dtype=dtype))
1573*da0073e9SAndroid Build Coastguard Worker                for _ in range(2)
1574*da0073e9SAndroid Build Coastguard Worker            ]
1575*da0073e9SAndroid Build Coastguard Worker            for p in params:
1576*da0073e9SAndroid Build Coastguard Worker                p.grad = torch.randn_like(p)
1577*da0073e9SAndroid Build Coastguard Worker                if optim_info.only_supports_sparse_grads:
1578*da0073e9SAndroid Build Coastguard Worker                    # For this test, we naively convert the Tensor layout, which we know does
1579*da0073e9SAndroid Build Coastguard Worker                    # NOT represent the expected use case for optims like SparseAdam!
1580*da0073e9SAndroid Build Coastguard Worker                    p.grad = p.grad.to_sparse()
1581*da0073e9SAndroid Build Coastguard Worker
1582*da0073e9SAndroid Build Coastguard Worker            optimizer = optim_cls(params, **optim_input.kwargs)
1583*da0073e9SAndroid Build Coastguard Worker
1584*da0073e9SAndroid Build Coastguard Worker            for _ in range(3):
1585*da0073e9SAndroid Build Coastguard Worker                optimizer.step(closure)
1586*da0073e9SAndroid Build Coastguard Worker
1587*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
1588*da0073e9SAndroid Build Coastguard Worker                params_cuda = [p.to(device="cuda") for p in params]
1589*da0073e9SAndroid Build Coastguard Worker                for i, p in enumerate(params_cuda):
1590*da0073e9SAndroid Build Coastguard Worker                    p.grad = params[i].grad.to(device="cuda")
1591*da0073e9SAndroid Build Coastguard Worker            optimizer_cuda = optim_cls(params_cuda, **optim_input.kwargs)
1592*da0073e9SAndroid Build Coastguard Worker
1593*da0073e9SAndroid Build Coastguard Worker            state_dict_cpu = deepcopy(optimizer.state_dict())
1594*da0073e9SAndroid Build Coastguard Worker            state_dict_cuda = deepcopy(optimizer.state_dict())
1595*da0073e9SAndroid Build Coastguard Worker            optimizer_cuda.load_state_dict(state_dict_cuda)
1596*da0073e9SAndroid Build Coastguard Worker
1597*da0073e9SAndroid Build Coastguard Worker            # Make sure state_dict_cuda isn't modified by merely calling load_state_dict
1598*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(state_dict_cpu, state_dict_cuda)
1599*da0073e9SAndroid Build Coastguard Worker
1600*da0073e9SAndroid Build Coastguard Worker            # Make sure the device of state['step'] is still CPU _unless_ capturable or fused
            # is set (e.g. torch.compile() may turn capturable on)
1601*da0073e9SAndroid Build Coastguard Worker            capturable = state_dict_cpu["param_groups"][0].get("capturable", False)
1602*da0073e9SAndroid Build Coastguard Worker            fused = state_dict_cpu["param_groups"][0].get("fused", False)
1603*da0073e9SAndroid Build Coastguard Worker            new_state_dict = optimizer_cuda.state_dict()
1604*da0073e9SAndroid Build Coastguard Worker            for state_cpu, state_cuda in zip(
1605*da0073e9SAndroid Build Coastguard Worker                state_dict_cpu["state"].values(), new_state_dict["state"].values()
1606*da0073e9SAndroid Build Coastguard Worker            ):
1607*da0073e9SAndroid Build Coastguard Worker                if "step" in state_cpu and torch.is_tensor(state_cpu["step"]):
1608*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
1609*da0073e9SAndroid Build Coastguard Worker                        state_cuda["step"].device.type,
1610*da0073e9SAndroid Build Coastguard Worker                        "cuda" if capturable or fused else "cpu",
1611*da0073e9SAndroid Build Coastguard Worker                    )
1612*da0073e9SAndroid Build Coastguard Worker
1613*da0073e9SAndroid Build Coastguard Worker            for _ in range(5):
1614*da0073e9SAndroid Build Coastguard Worker                optimizer.step(closure)
1615*da0073e9SAndroid Build Coastguard Worker                optimizer_cuda.step(closure)
1616*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(params, params_cuda)
1617*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(optimizer.state_dict(), optimizer_cuda.state_dict())
1618*da0073e9SAndroid Build Coastguard Worker
1619*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1620*da0073e9SAndroid Build Coastguard Worker    def _state_dict_pre_hook(optimizer: Optimizer) -> None:
1621*da0073e9SAndroid Build Coastguard Worker        optimizer.state["test"] = 1
1622*da0073e9SAndroid Build Coastguard Worker
1623*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1624*da0073e9SAndroid Build Coastguard Worker    def _state_dict_post_hook(
1625*da0073e9SAndroid Build Coastguard Worker        optimizer: Optimizer, state_dict: Dict[str, Any]
1626*da0073e9SAndroid Build Coastguard Worker    ) -> Dict[str, Any]:
1627*da0073e9SAndroid Build Coastguard Worker        if "test" in state_dict["state"]:
1628*da0073e9SAndroid Build Coastguard Worker            state_dict["state"].pop("test")
1629*da0073e9SAndroid Build Coastguard Worker            state_dict["ran_state_dict_pre_hook"] = True
1630*da0073e9SAndroid Build Coastguard Worker        else:
1631*da0073e9SAndroid Build Coastguard Worker            state_dict["ran_state_dict_pre_hook"] = False
1632*da0073e9SAndroid Build Coastguard Worker        return state_dict
1633*da0073e9SAndroid Build Coastguard Worker
1634*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1635*da0073e9SAndroid Build Coastguard Worker    def test_state_dict_pre_hook(self, device, dtype, optim_info):
1636*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1637*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1638*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info
1639*da0073e9SAndroid Build Coastguard Worker        )
1640*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
1641*da0073e9SAndroid Build Coastguard Worker            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
1642*da0073e9SAndroid Build Coastguard Worker            optim = optim_cls([param], **optim_input.kwargs)
1643*da0073e9SAndroid Build Coastguard Worker            optim.register_state_dict_pre_hook(self.__class__._state_dict_pre_hook)
1644*da0073e9SAndroid Build Coastguard Worker            state_dict = optim.state_dict()
1645*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(state_dict["state"]["test"], 1)
1646*da0073e9SAndroid Build Coastguard Worker
1647*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1648*da0073e9SAndroid Build Coastguard Worker    def test_state_dict_post_hook(self, device, dtype, optim_info):
1649*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1650*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1651*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info
1652*da0073e9SAndroid Build Coastguard Worker        )
1653*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
1654*da0073e9SAndroid Build Coastguard Worker            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
1655*da0073e9SAndroid Build Coastguard Worker            optim = optim_cls([param], **optim_input.kwargs)
1656*da0073e9SAndroid Build Coastguard Worker            optim.register_state_dict_post_hook(self.__class__._state_dict_post_hook)
1657*da0073e9SAndroid Build Coastguard Worker            state_dict = optim.state_dict()
1658*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(state_dict["ran_state_dict_pre_hook"])
1659*da0073e9SAndroid Build Coastguard Worker
1660*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1661*da0073e9SAndroid Build Coastguard Worker    def test_state_dict_pre_post_hook(self, device, dtype, optim_info):
1662*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1663*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1664*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info
1665*da0073e9SAndroid Build Coastguard Worker        )
1666*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
1667*da0073e9SAndroid Build Coastguard Worker            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
1668*da0073e9SAndroid Build Coastguard Worker            optim = optim_cls([param], **optim_input.kwargs)
1669*da0073e9SAndroid Build Coastguard Worker            optim.register_state_dict_pre_hook(self.__class__._state_dict_pre_hook)
1670*da0073e9SAndroid Build Coastguard Worker            optim.register_state_dict_post_hook(self.__class__._state_dict_post_hook)
1671*da0073e9SAndroid Build Coastguard Worker            state_dict = optim.state_dict()
1672*da0073e9SAndroid Build Coastguard Worker            self.assertFalse("test" in state_dict["state"])
1673*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(state_dict["ran_state_dict_pre_hook"])
1674*da0073e9SAndroid Build Coastguard Worker
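    # --- Illustrative sketch (hypothetical helper, not an upstream test) ---
    # Free-standing usage of the state_dict hooks exercised above: the pre-hook
    # runs before the state_dict is assembled, and the post-hook may rewrite and
    # return the resulting dict.
    @staticmethod
    def _sketch_state_dict_hooks() -> None:
        param = Parameter(torch.randn(2, 3))
        opt = SGD([param], lr=0.1)

        def pre_hook(optimizer: Optimizer) -> None:
            optimizer.state["marker"] = 1

        def post_hook(
            optimizer: Optimizer, state_dict: Dict[str, Any]
        ) -> Dict[str, Any]:
            state_dict["state"].pop("marker", None)
            state_dict["saw_marker"] = True
            return state_dict

        opt.register_state_dict_pre_hook(pre_hook)
        opt.register_state_dict_post_hook(post_hook)
        sd = opt.state_dict()
        assert "marker" not in sd["state"] and sd["saw_marker"]
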
1675*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1676*da0073e9SAndroid Build Coastguard Worker    def _load_state_dict_pre_hook1(
1677*da0073e9SAndroid Build Coastguard Worker        optimizer: Optimizer, state_dict: Dict[str, Any]
1678*da0073e9SAndroid Build Coastguard Worker    ) -> None:
1679*da0073e9SAndroid Build Coastguard Worker        state_dict["param_groups"][0]["lr"] = 0.002
1680*da0073e9SAndroid Build Coastguard Worker
1681*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1682*da0073e9SAndroid Build Coastguard Worker    def _load_state_dict_pre_hook2(
1683*da0073e9SAndroid Build Coastguard Worker        optimizer: Optimizer, state_dict: Dict[str, Any]
1684*da0073e9SAndroid Build Coastguard Worker    ) -> Dict[str, Any]:
1685*da0073e9SAndroid Build Coastguard Worker        # The typical use case for returning a state_dict is to drastically modify it.
1686*da0073e9SAndroid Build Coastguard Worker        # We simulate that by making a deep copy and verifying that my_state_dict is what gets used
1687*da0073e9SAndroid Build Coastguard Worker        my_state_dict = deepcopy(state_dict)
1688*da0073e9SAndroid Build Coastguard Worker        my_state_dict["param_groups"][0]["lr"] = 0.003
1689*da0073e9SAndroid Build Coastguard Worker        return my_state_dict
1690*da0073e9SAndroid Build Coastguard Worker
1691*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1692*da0073e9SAndroid Build Coastguard Worker    def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
1693*da0073e9SAndroid Build Coastguard Worker        optimizer.state["ran_load_state_dict_pre_hook2"] = (
1694*da0073e9SAndroid Build Coastguard Worker            optimizer.param_groups[0]["lr"] == 0.003
1695*da0073e9SAndroid Build Coastguard Worker        )
1696*da0073e9SAndroid Build Coastguard Worker        optimizer.state["ran_load_state_dict_post_hook"] = True
1697*da0073e9SAndroid Build Coastguard Worker
1698*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1699*da0073e9SAndroid Build Coastguard Worker    def test_load_state_dict_pre_hook_and_prepend(self, device, dtype, optim_info):
1700*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1701*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1702*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info
1703*da0073e9SAndroid Build Coastguard Worker        )
1704*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
1705*da0073e9SAndroid Build Coastguard Worker            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
1706*da0073e9SAndroid Build Coastguard Worker            optim = optim_cls([param], **optim_input.kwargs)
1707*da0073e9SAndroid Build Coastguard Worker            state_dict = optim.state_dict()
1708*da0073e9SAndroid Build Coastguard Worker
1709*da0073e9SAndroid Build Coastguard Worker            # Usually one would load into a fresh optim instance, but reusing the same one works for this test
1710*da0073e9SAndroid Build Coastguard Worker            optim.register_load_state_dict_pre_hook(
1711*da0073e9SAndroid Build Coastguard Worker                self.__class__._load_state_dict_pre_hook1
1712*da0073e9SAndroid Build Coastguard Worker            )
1713*da0073e9SAndroid Build Coastguard Worker            optim.load_state_dict(state_dict)
1714*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(optim.param_groups[0]["lr"], 0.002)
1715*da0073e9SAndroid Build Coastguard Worker
1716*da0073e9SAndroid Build Coastguard Worker            optim.register_load_state_dict_pre_hook(
1717*da0073e9SAndroid Build Coastguard Worker                self.__class__._load_state_dict_pre_hook2, prepend=True
1718*da0073e9SAndroid Build Coastguard Worker            )
1719*da0073e9SAndroid Build Coastguard Worker            optim.load_state_dict(state_dict)
1720*da0073e9SAndroid Build Coastguard Worker            # If prepend were False, lr would be 0.003; since prepend is True, hook2 runs
            # first and the later hook1 overrides it back to 0.002
1721*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(optim.param_groups[0]["lr"], 0.002)
1722*da0073e9SAndroid Build Coastguard Worker
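    # --- Illustrative sketch (hypothetical helper, not an upstream test) ---
    # The ordering behavior verified above, in isolation: load_state_dict
    # pre-hooks run in registration order, and prepend=True pushes a hook to
    # the front, so the non-prepended hook gets the last word on "lr".
    @staticmethod
    def _sketch_load_state_dict_pre_hook_order() -> None:
        param = Parameter(torch.randn(2, 3))
        opt = SGD([param], lr=0.1)
        sd = opt.state_dict()

        def set_lr(value):
            def hook(optimizer: Optimizer, state_dict: Dict[str, Any]) -> None:
                state_dict["param_groups"][0]["lr"] = value

            return hook

        opt.register_load_state_dict_pre_hook(set_lr(0.002))
        opt.register_load_state_dict_pre_hook(set_lr(0.003), prepend=True)
        opt.load_state_dict(sd)
        assert opt.param_groups[0]["lr"] == 0.002  # prepended hook ran first
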
1723*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1724*da0073e9SAndroid Build Coastguard Worker    def test_load_state_dict_post_hook(self, device, dtype, optim_info):
1725*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1726*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1727*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info
1728*da0073e9SAndroid Build Coastguard Worker        )
1729*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
1730*da0073e9SAndroid Build Coastguard Worker            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
1731*da0073e9SAndroid Build Coastguard Worker            optim = optim_cls([param], **optim_input.kwargs)
1732*da0073e9SAndroid Build Coastguard Worker
1733*da0073e9SAndroid Build Coastguard Worker            optim.register_load_state_dict_post_hook(
1734*da0073e9SAndroid Build Coastguard Worker                self.__class__._load_state_dict_post_hook
1735*da0073e9SAndroid Build Coastguard Worker            )
1736*da0073e9SAndroid Build Coastguard Worker            optim.load_state_dict(optim.state_dict())
1737*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(optim.state["ran_load_state_dict_pre_hook2"])
1738*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(optim.state["ran_load_state_dict_post_hook"])
1739*da0073e9SAndroid Build Coastguard Worker
1740*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1741*da0073e9SAndroid Build Coastguard Worker    def test_load_state_dict_pre_post_hook(self, device, dtype, optim_info):
1742*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1743*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1744*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info
1745*da0073e9SAndroid Build Coastguard Worker        )
1746*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
1747*da0073e9SAndroid Build Coastguard Worker            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
1748*da0073e9SAndroid Build Coastguard Worker            optim = optim_cls([param], **optim_input.kwargs)
1749*da0073e9SAndroid Build Coastguard Worker
1750*da0073e9SAndroid Build Coastguard Worker            optim.register_load_state_dict_pre_hook(
1751*da0073e9SAndroid Build Coastguard Worker                self.__class__._load_state_dict_pre_hook2
1752*da0073e9SAndroid Build Coastguard Worker            )
1753*da0073e9SAndroid Build Coastguard Worker            optim.register_load_state_dict_post_hook(
1754*da0073e9SAndroid Build Coastguard Worker                self.__class__._load_state_dict_post_hook
1755*da0073e9SAndroid Build Coastguard Worker            )
1756*da0073e9SAndroid Build Coastguard Worker            optim.load_state_dict(optim.state_dict())
1757*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(optim.state["ran_load_state_dict_pre_hook2"])
1758*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(optim.state["ran_load_state_dict_post_hook"])
1759*da0073e9SAndroid Build Coastguard Worker
1760*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1761*da0073e9SAndroid Build Coastguard Worker    def test_step_post_hook(self, device, dtype, optim_info):
1762*da0073e9SAndroid Build Coastguard Worker        def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
1763*da0073e9SAndroid Build Coastguard Worker            nonlocal data
1764*da0073e9SAndroid Build Coastguard Worker            data += 2
1765*da0073e9SAndroid Build Coastguard Worker
1766*da0073e9SAndroid Build Coastguard Worker        params = [torch.tensor([1, 1], device=device, dtype=dtype)]
1767*da0073e9SAndroid Build Coastguard Worker
1768*da0073e9SAndroid Build Coastguard Worker        def dummy_closure():
1769*da0073e9SAndroid Build Coastguard Worker            return 1
1770*da0073e9SAndroid Build Coastguard Worker
1771*da0073e9SAndroid Build Coastguard Worker        closure = dummy_closure if optim_info.step_requires_closure else None
1772*da0073e9SAndroid Build Coastguard Worker
1773*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1774*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info
1775*da0073e9SAndroid Build Coastguard Worker        )
1776*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
1777*da0073e9SAndroid Build Coastguard Worker            optim = optim_info.optim_cls(params, **optim_input.kwargs)
1778*da0073e9SAndroid Build Coastguard Worker            data = 2
1779*da0073e9SAndroid Build Coastguard Worker            hook_handle = optim.register_step_post_hook(post_hook)
1780*da0073e9SAndroid Build Coastguard Worker
1781*da0073e9SAndroid Build Coastguard Worker            optim.step(closure)
1782*da0073e9SAndroid Build Coastguard Worker            optim.step(closure)
1783*da0073e9SAndroid Build Coastguard Worker            # check that the post hook ran on both steps
1784*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(data, 6)
1785*da0073e9SAndroid Build Coastguard Worker
1786*da0073e9SAndroid Build Coastguard Worker            # remove the handle, take a step, and verify that the hook no longer fires
1787*da0073e9SAndroid Build Coastguard Worker            hook_handle.remove()
1788*da0073e9SAndroid Build Coastguard Worker
1789*da0073e9SAndroid Build Coastguard Worker            optim.step(closure)
1790*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(data, 6)
1791*da0073e9SAndroid Build Coastguard Worker
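    # --- Illustrative sketch (hypothetical helper, not an upstream test) ---
    # Per-optimizer step post-hook usage, matching the test above: the hook
    # fires on every step until its handle is removed.
    @staticmethod
    def _sketch_step_post_hook() -> None:
        counter = {"steps": 0}

        def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            counter["steps"] += 1

        param = Parameter(torch.randn(2, 3))
        opt = SGD([param], lr=0.1)
        param.grad = torch.ones_like(param)
        handle = opt.register_step_post_hook(post_hook)
        opt.step()
        opt.step()
        handle.remove()
        opt.step()
        assert counter["steps"] == 2
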
1792*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1793*da0073e9SAndroid Build Coastguard Worker    def test_step_pre_hook(self, device, dtype, optim_info):
1794*da0073e9SAndroid Build Coastguard Worker        def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
1795*da0073e9SAndroid Build Coastguard Worker            nonlocal data
1796*da0073e9SAndroid Build Coastguard Worker            data += 2
1797*da0073e9SAndroid Build Coastguard Worker
1798*da0073e9SAndroid Build Coastguard Worker        params = [torch.tensor([1, 1], device=device, dtype=dtype)]
1799*da0073e9SAndroid Build Coastguard Worker
1800*da0073e9SAndroid Build Coastguard Worker        def dummy_closure():
1801*da0073e9SAndroid Build Coastguard Worker            return 1
1802*da0073e9SAndroid Build Coastguard Worker
1803*da0073e9SAndroid Build Coastguard Worker        closure = dummy_closure if optim_info.step_requires_closure else None
1804*da0073e9SAndroid Build Coastguard Worker
1805*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1806*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info
1807*da0073e9SAndroid Build Coastguard Worker        )
1808*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
1809*da0073e9SAndroid Build Coastguard Worker            optim = optim_info.optim_cls(params, **optim_input.kwargs)
1810*da0073e9SAndroid Build Coastguard Worker            data = 5
1811*da0073e9SAndroid Build Coastguard Worker            hook_handle = optim.register_step_pre_hook(pre_hook)
1812*da0073e9SAndroid Build Coastguard Worker
1813*da0073e9SAndroid Build Coastguard Worker            optim.step(closure)
1814*da0073e9SAndroid Build Coastguard Worker            optim.step(closure)
1815*da0073e9SAndroid Build Coastguard Worker            # check that the pre hook ran on both steps
1816*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(data, 9)
1817*da0073e9SAndroid Build Coastguard Worker
1818*da0073e9SAndroid Build Coastguard Worker            # remove the handle, take a step, and verify that the hook no longer fires
1819*da0073e9SAndroid Build Coastguard Worker            hook_handle.remove()
1820*da0073e9SAndroid Build Coastguard Worker
1821*da0073e9SAndroid Build Coastguard Worker            optim.step(closure)
1822*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(data, 9)
1823*da0073e9SAndroid Build Coastguard Worker
1824*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1825*da0073e9SAndroid Build Coastguard Worker    def test_step_all_hooks(self, device, dtype, optim_info):
1826*da0073e9SAndroid Build Coastguard Worker        def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
1827*da0073e9SAndroid Build Coastguard Worker            nonlocal data
1828*da0073e9SAndroid Build Coastguard Worker            data.append(0)
1829*da0073e9SAndroid Build Coastguard Worker
1830*da0073e9SAndroid Build Coastguard Worker        def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
1831*da0073e9SAndroid Build Coastguard Worker            nonlocal data
1832*da0073e9SAndroid Build Coastguard Worker            data.append(5)
1833*da0073e9SAndroid Build Coastguard Worker
1834*da0073e9SAndroid Build Coastguard Worker        def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
1835*da0073e9SAndroid Build Coastguard Worker            nonlocal data
1836*da0073e9SAndroid Build Coastguard Worker            data.append(1)
1837*da0073e9SAndroid Build Coastguard Worker
1838*da0073e9SAndroid Build Coastguard Worker        def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
1839*da0073e9SAndroid Build Coastguard Worker            nonlocal data
1840*da0073e9SAndroid Build Coastguard Worker            data.append(2)
1841*da0073e9SAndroid Build Coastguard Worker
1842*da0073e9SAndroid Build Coastguard Worker        params = [torch.tensor([1, 1], device=device, dtype=dtype)]
1843*da0073e9SAndroid Build Coastguard Worker
1844*da0073e9SAndroid Build Coastguard Worker        def dummy_closure():
1845*da0073e9SAndroid Build Coastguard Worker            return 1
1846*da0073e9SAndroid Build Coastguard Worker
1847*da0073e9SAndroid Build Coastguard Worker        closure = dummy_closure if optim_info.step_requires_closure else None
1848*da0073e9SAndroid Build Coastguard Worker
1849*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1850*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info
1851*da0073e9SAndroid Build Coastguard Worker        )
1852*da0073e9SAndroid Build Coastguard Worker        for optim_input in all_optim_inputs:
1853*da0073e9SAndroid Build Coastguard Worker            optim = optim_info.optim_cls(params, **optim_input.kwargs)
1854*da0073e9SAndroid Build Coastguard Worker            optim2 = SGD(params)
1855*da0073e9SAndroid Build Coastguard Worker            data = []
1856*da0073e9SAndroid Build Coastguard Worker
1857*da0073e9SAndroid Build Coastguard Worker            # register global hooks to both optimizers
1858*da0073e9SAndroid Build Coastguard Worker            global_pre_handle = register_optimizer_step_pre_hook(global_pre_hook)
1859*da0073e9SAndroid Build Coastguard Worker            global_post_handle = register_optimizer_step_post_hook(global_post_hook)
1860*da0073e9SAndroid Build Coastguard Worker
1861*da0073e9SAndroid Build Coastguard Worker            # register local hooks
1862*da0073e9SAndroid Build Coastguard Worker            first_pre_handle = optim.register_step_pre_hook(local_pre_hook)
1863*da0073e9SAndroid Build Coastguard Worker            first_post_handle = optim.register_step_post_hook(local_post_hook)
1864*da0073e9SAndroid Build Coastguard Worker            second_pre_handle = optim2.register_step_pre_hook(local_pre_hook)
1865*da0073e9SAndroid Build Coastguard Worker            second_post_handle = optim2.register_step_post_hook(local_post_hook)
1866*da0073e9SAndroid Build Coastguard Worker
1867*da0073e9SAndroid Build Coastguard Worker            optim.step(closure)
1868*da0073e9SAndroid Build Coastguard Worker            self.assertListEqual(data, [0, 1, 2, 5])
1869*da0073e9SAndroid Build Coastguard Worker            optim2.step(closure)
1870*da0073e9SAndroid Build Coastguard Worker            self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5])
1871*da0073e9SAndroid Build Coastguard Worker            optim.step(closure)
1872*da0073e9SAndroid Build Coastguard Worker            self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
1873*da0073e9SAndroid Build Coastguard Worker
1874*da0073e9SAndroid Build Coastguard Worker            # remove all hooks
1875*da0073e9SAndroid Build Coastguard Worker            global_pre_handle.remove()
1876*da0073e9SAndroid Build Coastguard Worker            global_post_handle.remove()
1877*da0073e9SAndroid Build Coastguard Worker            first_pre_handle.remove()
1878*da0073e9SAndroid Build Coastguard Worker            first_post_handle.remove()
1879*da0073e9SAndroid Build Coastguard Worker            second_pre_handle.remove()
1880*da0073e9SAndroid Build Coastguard Worker            second_post_handle.remove()
1881*da0073e9SAndroid Build Coastguard Worker
1882*da0073e9SAndroid Build Coastguard Worker            optim.step(closure)
1883*da0073e9SAndroid Build Coastguard Worker            optim2.step(closure)
1884*da0073e9SAndroid Build Coastguard Worker            self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
1885*da0073e9SAndroid Build Coastguard Worker
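    # --- Illustrative sketch (hypothetical helper, not an upstream test) ---
    # Global step hooks in isolation, mirroring the ordering checked above:
    # hooks registered via register_optimizer_step_pre_hook / _post_hook fire
    # for every optimizer and wrap around any per-optimizer hooks; removing
    # their handles restores the previous behavior.
    @staticmethod
    def _sketch_global_step_hooks() -> None:
        order = []
        pre_handle = register_optimizer_step_pre_hook(
            lambda opt, args, kwargs: order.append("pre")
        )
        post_handle = register_optimizer_step_post_hook(
            lambda opt, args, kwargs: order.append("post")
        )
        try:
            param = Parameter(torch.randn(2, 3))
            opt = SGD([param], lr=0.1)
            param.grad = torch.ones_like(param)
            opt.step()
            assert order == ["pre", "post"]
        finally:
            pre_handle.remove()
            post_handle.remove()
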
1886*da0073e9SAndroid Build Coastguard Worker    @optims(optim_db, dtypes=[torch.float32])
1887*da0073e9SAndroid Build Coastguard Worker    def test_deepcopy_copies_all_public_attrs(self, device, dtype, optim_info):
1888*da0073e9SAndroid Build Coastguard Worker        optim_cls = optim_info.optim_cls
1889*da0073e9SAndroid Build Coastguard Worker
1890*da0073e9SAndroid Build Coastguard Worker        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
1891*da0073e9SAndroid Build Coastguard Worker        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
1892*da0073e9SAndroid Build Coastguard Worker            device, dtype, optim_info, skip=("differentiable",)
1893*da0073e9SAndroid Build Coastguard Worker        )
1894*da0073e9SAndroid Build Coastguard Worker
1895*da0073e9SAndroid Build Coastguard Worker        params = [
1896*da0073e9SAndroid Build Coastguard Worker            Parameter(torch.randn(2, 3, device=device, dtype=dtype)) for _ in range(2)
1897*da0073e9SAndroid Build Coastguard Worker        ]
1898*da0073e9SAndroid Build Coastguard Worker        for p in params:
1899*da0073e9SAndroid Build Coastguard Worker            p.grad = torch.rand_like(p)
1900*da0073e9SAndroid Build Coastguard Worker            if optim_info.only_supports_sparse_grads:
1901*da0073e9SAndroid Build Coastguard Worker                # For this test, we naively convert the Tensor layout, which we know does
1902*da0073e9SAndroid Build Coastguard Worker                # NOT represent the expected use case for optims like SparseAdam!
1903*da0073e9SAndroid Build Coastguard Worker                p.grad = p.grad.to_sparse()
1904*da0073e9SAndroid Build Coastguard Worker
1905*da0073e9SAndroid Build Coastguard Worker        # Needed for second order optims like LBFGS
1906*da0073e9SAndroid Build Coastguard Worker        def closure():
1907*da0073e9SAndroid Build Coastguard Worker            return 1 if optim_info.step_requires_closure else None
1908*da0073e9SAndroid Build Coastguard Worker
        def getPublicAttrs(obj):
            return {k for k in obj.__dict__ if not k.startswith("_")}

        for optim_input in all_optim_inputs:
            optimizer = optim_cls(params, **optim_input.kwargs)

            # Make some state
            for _ in range(3):
                if optim_info.step_requires_closure:
                    optimizer.step(closure)
                else:
                    closure()
                    optimizer.step()

            self.assertEqual(
                getPublicAttrs(optimizer), getPublicAttrs(deepcopy(optimizer))
            )

    @optims(
        [optim for optim in optim_db if optim.step_requires_closure],
        dtypes=[torch.float32],
    )
    def test_second_order_optims_return_consistent_types(
        self, device, dtype, optim_info
    ):
        # Motivated by #7586
        optim_cls = optim_info.optim_cls
        params = [
            torch.randn(10, 5, device=device, dtype=dtype),
            torch.randn(10, device=device, dtype=dtype),
        ]

        def closure():
            return torch.tensor([10], device=device, dtype=dtype)

        for optim_input in optim_info.optim_inputs_func(device=device):
            # Currently, the only second order optim is LBFGS, so we just go ahead and modify
            # "tolerance_grad", but this may not scale if we add second order optims in the future
            kwargs = optim_input.kwargs
            kwargs["tolerance_grad"] = math.inf
            optim_inf = optim_cls(params, **kwargs)
            kwargs["tolerance_grad"] = -math.inf
            optim_neg_inf = optim_cls(params, **kwargs)

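            # With tolerance_grad=inf, LBFGS should take its early-exit path right
            # after the first closure evaluation; with -inf it runs the full inner
            # loop. Either way, step() should return the same type of loss value.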
            res1 = optim_inf.step(closure)
            res2 = optim_neg_inf.step(closure)
            self.assertEqual(type(res1), type(res2))

    @onlyCUDA
    @optims(
        [
            optim
            for optim in optim_db
            if "cpu" in optim.supports_fused_on and "cuda" in optim.supports_fused_on
        ],
        dtypes=floating_types_and(
            torch.bfloat16,
            torch.float16,
        ),
    )
    def test_fused_cpu_matches_cuda(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        optim_inputs = optim_info.optim_inputs_func(device="cpu")
        for optim_input in optim_inputs:
            inpts, models, optimizers = [], [], []
            for dev in ("cpu", "cuda"):
                kwargs = optim_input.kwargs
                kwargs["fused"] = True
                inpt = torch.tensor(
                    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=dev
                ).reshape(3, 2)

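                # Re-seed before constructing each model so the CPU and CUDA
                # replicas start from identical weights.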
                torch.manual_seed(1)
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Sigmoid(),
                    torch.nn.Linear(3, 1),
                    torch.nn.Sigmoid(),
                )
                model.to(dtype=dtype, device=dev)

                # Foreach/fused optimizers should be tested with a
                # zero-size tensor as their last param.
                # ref: https://github.com/pytorch/pytorch/issues/100701
                empty_param = torch.empty(
                    (), device=dev, dtype=dtype, requires_grad=True
                )
                empty_param.grad = torch.rand_like(empty_param)
                params = list(model.parameters()) + [empty_param]

                optimizer = optim_cls(params, **kwargs)
                inpts.append(inpt)
                models.append(model)
                optimizers.append(optimizer)
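            # _compare_between (a helper defined earlier in this class) steps each
            # (input, model, optimizer) pair and checks that the CPU and CUDA runs
            # stay in sync.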
            self._compare_between(inpts, models, optimizers)

    @onlyCUDA
    @optims(
        [
            o
            for o in optim_db
            if ("foreach" in o.supported_impls and o.optim_cls.__name__ != "Adafactor")
        ],
        dtypes=[torch.float32],
    )
    def test_defaults_changed_to_foreach(self, device, dtype, optim_info):
        # Test that optimizers default to the foreach implementation, except
        # Adafactor, which defaults to the single-tensor impl for memory efficiency.
        optim_cls = optim_info.optim_cls
        model = torch.nn.Linear(5, 5)
        model.to(dtype=dtype, device=device)
        inpt = torch.rand(2, 5, dtype=dtype, device=device)

        import inspect

        module = inspect.getmodule(optim_cls)

        for optim_input in optim_info.optim_inputs_func(device=device):
            optim = optim_cls(model.parameters(), **optim_input.kwargs)
            optim.zero_grad()
            output = model(inpt)
            loss = output.sum()
            loss.backward()
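            # Each native optimizer module exposes a _multi_tensor_<name> function
            # implementing the foreach path (e.g. torch.optim.adam._multi_tensor_adam).
            # Patching it lets us observe whether step() dispatches to it by default.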
            with patch.object(
                module, f"_multi_tensor_{optim_cls.__name__.lower()}"
            ) as mocked_foreach_impl:
                optim.step()
                self.assertTrue(mocked_foreach_impl.called)

    @optims(optim_db, dtypes=[torch.float32])
    def test_non_empty_state(self, device, dtype, optim_info):
        # There are internal tests that check that the state is not empty
        optim_cls = optim_info.optim_cls
        model = torch.nn.Linear(5, 5)
        model.to(dtype=dtype, device=device)
        inpt = torch.rand(2, 5, dtype=dtype, device=device)

        for optim_input in optim_info.optim_inputs_func(device=device):
            optim = optim_cls(model.parameters(), **optim_input.kwargs)
            optim.zero_grad()
            output = model(inpt)
            loss = output.sum()
            loss.backward()

            if optim_info.only_supports_sparse_grads:
                for param in model.parameters():
                    if param.grad is not None:
                        param.grad = param.grad.to_sparse()

            if optim_info.step_requires_closure:
                optim.step(lambda: 1.0)
            else:
                optim.step()

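            # Every populated entry in optim.state should be non-empty after a step
            # (e.g. step counters, momentum/exp_avg buffers).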
            for state in optim.state.values():
                self.assertGreater(len(state), 0)


instantiate_device_type_tests(TestOptimRenewed, globals(), allow_mps=True)


if __name__ == "__main__":
    run_tests()