# Owner(s): ["module: optimizer"]

import torch
from torch.optim import (
    Adadelta,
    Adagrad,
    Adam,
    Adamax,
    AdamW,
    ASGD,
    NAdam,
    RAdam,
    RMSprop,
    Rprop,
    SGD,
)
from torch.testing._internal.common_utils import (
    gradcheck,
    load_tests,
    skipIfTorchDynamo,
    TestCase,
)


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests


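# Helper shared by every test in this file: it clones the parameter, attaches the
# given gradient, rebuilds the optimizer with the given kwargs, seeds the optimizer
# state, runs a single step(), and returns every tensor that should participate in
# differentiation so that gradcheck can verify its gradients.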
def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
    # `ignored` holds the values of `opt_differentiable_state`, passed again as
    # positional arguments so that `gradcheck` tracks the state tensors as
    # function inputs; otherwise it cannot unpack the values inside the
    # `opt_differentiable_state` dict.
    p = p.clone()
    p.grad = grad
    opt_differentiable_state = {
        k: v.clone() if isinstance(v, torch.Tensor) else v
        for k, v in opt_differentiable_state.items()
    }
    opt = opt_class([p], **kwargs)
    opt.state[p].update(opt_differentiable_state)
    opt.step()
    return (p,) + tuple(
        v
        for v in opt.state[p].values()
        if isinstance(v, torch.Tensor) and v.requires_grad
    )

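# Note: every tensor below is created in float64 because gradcheck compares
# analytical gradients against finite-difference estimates, which need double
# precision to stay within tolerance. Integer-like state such as `step` is
# created with requires_grad=False since it is not a continuous input.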
@skipIfTorchDynamo("Differentiable optimizers not supported")
class TestDifferentiableOptimizer(TestCase):
    def test_sgd(self):
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        mbuff = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state = {"momentum_buffer": mbuff}
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                SGD,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adam,
                {"lr": 0.9, "differentiable": True, "amsgrad": True},
                *state.values(),
            ),
        )

    def test_rmsprop(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["step"] = torch.zeros((), dtype=torch.float64)
        state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["momentum_buffer"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        # Keep `grad_avg` small: the centered variant takes the sqrt of
        # `square_avg - grad_avg**2`, so large values here can make that
        # argument negative and produce NaNs.
        state["grad_avg"] = 1e-2 * torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                RMSprop,
                {
                    "lr": 0.9,
                    "maximize": True,
                    "momentum": 0.9,
                    "differentiable": True,
                    "centered": True,
                    "weight_decay": 0.1,
                },
                *state.values(),
            ),
        )

    def test_adadelta(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["acc_delta"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adadelta,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adagrad(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["sum"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adagrad,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adamax(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_inf"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adamax,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    @skipIfTorchDynamo(
        "The in-place mu update fails with dynamo; since this only happens "
        "when differentiable is enabled, skipping for now"
    )
    def test_asgd(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step`, `eta`, and `mu` are not continuous variables (even though we
        # define them as floats) and so they shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["eta"] = torch.tensor(0.9, requires_grad=False, dtype=torch.float64)
        state["mu"] = torch.tensor(1.0, requires_grad=False, dtype=torch.float64)
        state["ax"] = torch.rand(10, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                ASGD,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

    def test_rprop(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["prev"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["step_size"] = torch.rand(10, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Rprop,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adamw(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                AdamW,
                {"lr": 0.9, "differentiable": True, "amsgrad": True},
                *state.values(),
            ),
        )

    def test_nadam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["mu_product"] = torch.tensor(1.0, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                NAdam,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                NAdam,
                {"lr": 0.9, "decoupled_weight_decay": True, "differentiable": True},
                *state.values(),
            ),
        )

    def test_radam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                RAdam,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                RAdam,
                {
                    "lr": 0.9,
                    "weight_decay": 0.1,
                    "decoupled_weight_decay": True,
                    "differentiable": True,
                },
                *state.values(),
            ),
        )


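# These cases are meant to be collected via test/test_optim.py (see the message
# below). A -k filter on the class name, e.g.
# `pytest test/test_optim.py -k TestDifferentiableOptimizer`, is one way to run
# only these tests; the exact invocation depends on your setup and is only an
# example, not something this file prescribes.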
if __name__ == "__main__":
    print("These tests should be run through test/test_optim.py instead")