# Owner(s): ["module: optimizer"]

import torch
from torch.optim import (
    Adadelta,
    Adagrad,
    Adam,
    Adamax,
    AdamW,
    ASGD,
    NAdam,
    RAdam,
    RMSprop,
    Rprop,
    SGD,
)
from torch.testing._internal.common_utils import (
    gradcheck,
    load_tests,
    skipIfTorchDynamo,
    TestCase,
)


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings.
load_tests = load_tests


def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
    # `ignored` holds the values of `opt_differentiable_state`. Passing them as
    # positional inputs lets `gradcheck` track the state tensors as function
    # inputs; otherwise it cannot unpack the values inside the
    # `opt_differentiable_state` dict.
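    # For example (mirroring the tests below; illustrative only):
    #   gradcheck(
    #       _diff_fn,
    #       (p, grad, state, SGD, {"lr": 0.9, "differentiable": True}, *state.values()),
    #   )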
    p = p.clone()
    p.grad = grad
    opt_differentiable_state = {
        k: v.clone() if isinstance(v, torch.Tensor) else v
        for k, v in opt_differentiable_state.items()
    }
    opt = opt_class([p], **kwargs)
    opt.state[p].update(opt_differentiable_state)
    opt.step()
    return (p,) + tuple(
        v
        for v in opt.state[p].values()
        if isinstance(v, torch.Tensor) and v.requires_grad
    )


@skipIfTorchDynamo("Differentiable optimizers not supported")
class TestDifferentiableOptimizer(TestCase):
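    """Gradcheck the `differentiable=True` code path of each optimizer.

    Each test builds a single optimizer step via `_diff_fn` and runs `gradcheck`
    over the parameter, its gradient, and the optimizer's differentiable state
    tensors, all in double precision for gradcheck's numerical accuracy.
    """
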
    def test_sgd(self):
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        mbuff = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state = {"momentum_buffer": mbuff}
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                SGD,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
79        state["max_exp_avg_sq"] = torch.rand(
80            10, requires_grad=True, dtype=torch.float64
81        )
82
83        gradcheck(
84            _diff_fn,
85            (
86                p,
87                grad,
88                state,
89                Adam,
90                {"lr": 0.9, "differentiable": True, "amsgrad": True},
91                *state.values(),
92            ),
93        )
94
95    def test_rmsprop(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["step"] = torch.zeros((), dtype=torch.float64)
        state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["momentum_buffer"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        # Keep `grad_avg` small: with `centered=True` the denominator involves
        # sqrt(square_avg - grad_avg**2), and large values here can drive that
        # argument negative and produce NaNs.
        state["grad_avg"] = 1e-2 * torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                RMSprop,
                {
                    "lr": 0.9,
                    "maximize": True,
                    "momentum": 0.9,
                    "differentiable": True,
                    "centered": True,
                    "weight_decay": 0.1,
                },
                *state.values(),
            ),
        )

    def test_adadelta(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["acc_delta"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adadelta,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adagrad(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["sum"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adagrad,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adamax(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_inf"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adamax,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    @skipIfTorchDynamo(
190        "The inplace mu update fails with dynamo, "
191        "since this is only happening when differentiable is enabled, skipping for now"
192    )
193    def test_asgd(self):
194        state = {}
195        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
196        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
197        # `step` `eta` & `mu` are not continuous variables (even though we define them as floats)
198        # and so they shouldn't require gradients.
199        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
200        state["eta"] = torch.tensor(0.9, requires_grad=False, dtype=torch.float64)
201        state["mu"] = torch.tensor(1.0, requires_grad=False, dtype=torch.float64)
202        state["ax"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
203
204        gradcheck(
205            _diff_fn,
206            (
207                p,
208                grad,
209                state,
210                ASGD,
211                {"lr": 0.9, "differentiable": True},
212                *state.values(),
213            ),
214        )
215
216    def test_rprop(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["prev"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["step_size"] = torch.rand(10, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Rprop,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adamw(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                AdamW,
                {"lr": 0.9, "differentiable": True, "amsgrad": True},
                *state.values(),
            ),
        )

    def test_nadam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["mu_product"] = torch.tensor(1.0, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                NAdam,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                NAdam,
                {"lr": 0.9, "decoupled_weight_decay": True, "differentiable": True},
                *state.values(),
            ),
        )

    def test_radam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                RAdam,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )
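        # Also exercise the decoupled (AdamW-style) weight decay path.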
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                RAdam,
                {
                    "lr": 0.9,
                    "weight_decay": 0.1,
                    "decoupled_weight_decay": True,
                    "differentiable": True,
                },
                *state.values(),
            ),
        )


if __name__ == "__main__":
    print("These tests should be run through test/test_optim.py instead")