# Owner(s): ["module: optimizer"]

import torch
from torch.optim import (
    Adadelta,
    Adagrad,
    Adam,
    Adamax,
    AdamW,
    ASGD,
    NAdam,
    RAdam,
    RMSprop,
    Rprop,
    SGD,
)
from torch.testing._internal.common_utils import (
    gradcheck,
    load_tests,
    skipIfTorchDynamo,
    TestCase,
)


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings.
load_tests = load_tests


def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
    # `ignored` receives the values of `opt_differentiable_state`, passed as
    # positional arguments so that `gradcheck` tracks the state tensors as
    # function inputs; otherwise it cannot unpack the values inside the
    # `opt_differentiable_state` dict.
    p = p.clone()
    p.grad = grad
    opt_differentiable_state = {
        k: v.clone() if isinstance(v, torch.Tensor) else v
        for k, v in opt_differentiable_state.items()
    }
    opt = opt_class([p], **kwargs)
    opt.state[p].update(opt_differentiable_state)
    opt.step()
    return (p,) + tuple(
        v
        for v in opt.state[p].values()
        if isinstance(v, torch.Tensor) and v.requires_grad
    )
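

# A minimal sketch of the mechanism under test (hypothetical helper, not
# collected by the test runner): with `differentiable=True`, the parameter
# update performed by `opt.step()` stays in the autograd graph, so gradients
# can flow from the updated parameter back to the tensors it was built from.
def _sgd_backprop_sketch():
    p = torch.rand(10, requires_grad=True, dtype=torch.float64)
    q = p.clone()  # non-leaf copy; the optimizer updates this tensor in place
    q.grad = torch.rand(10, dtype=torch.float64)
    opt = SGD([q], lr=0.9, differentiable=True)
    opt.step()  # q <- q - lr * q.grad, recorded by autograd
    q.sum().backward()  # the update is affine in p, so p.grad is all ones
    assert torch.equal(p.grad, torch.ones_like(p))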


@skipIfTorchDynamo("Differentiable optimizers not supported")
class TestDifferentiableOptimizer(TestCase):
    def test_sgd(self):
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        mbuff = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state = {"momentum_buffer": mbuff}
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                SGD,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adam,
                {"lr": 0.9, "differentiable": True, "amsgrad": True},
                *state.values(),
            ),
        )
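
    # For orientation, the Adam state entries above follow the standard
    # (textbook) bookkeeping; hedged summary:
    #   exp_avg:        m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    #   exp_avg_sq:     v_t = beta2 * v_{t-1} + (1 - beta2) * g_t ** 2
    #   max_exp_avg_sq: max(v_hat_{t-1}, v_t), used when amsgrad=True
    # Each is a smooth function of the parameter and gradient inputs, which
    # is what lets gradcheck differentiate through a full optimizer step.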

    def test_rmsprop(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["step"] = torch.zeros((), dtype=torch.float64)
        state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["momentum_buffer"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        # Scale `grad_avg` down: large values can otherwise produce NaNs or
        # blow-ups through the sqrt in the centered update (see note below).
        state["grad_avg"] = 1e-2 * torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                RMSprop,
                {
                    "lr": 0.9,
                    "maximize": True,
                    "momentum": 0.9,
                    "differentiable": True,
                    "centered": True,
                    "weight_decay": 0.1,
                },
                *state.values(),
            ),
        )
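
    # Note (hedged summary of the standard centered-RMSprop update): the
    # denominator is sqrt(square_avg - grad_avg**2) plus eps, so when
    # grad_avg**2 approaches square_avg the square root's argument approaches
    # zero and its derivative blows up; hence the 1e-2 scaling of `grad_avg`
    # above.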

    def test_adadelta(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["acc_delta"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adadelta,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adagrad(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["sum"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adagrad,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adamax(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_inf"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adamax,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )
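
    # Adamax replaces Adam's second moment with an infinity-norm accumulator
    # (`exp_inf`), updated in the standard formulation as roughly
    # max(beta2 * u_{t-1}, |g_t| + eps). The max is differentiable almost
    # everywhere, which suffices for gradcheck at random inputs.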

    @skipIfTorchDynamo(
        "The in-place mu update fails with dynamo; "
        "since this only happens when differentiable is enabled, skipping for now"
    )
    def test_asgd(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step`, `eta`, and `mu` are not continuous variables (even though we
        # define them as floats) and so they shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["eta"] = torch.tensor(0.9, requires_grad=False, dtype=torch.float64)
        state["mu"] = torch.tensor(1.0, requires_grad=False, dtype=torch.float64)
        state["ax"] = torch.rand(10, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                ASGD,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

    def test_rprop(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["prev"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["step_size"] = torch.rand(10, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Rprop,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )
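
    # Rprop adapts per-element step sizes from the sign of successive
    # gradients: `prev` holds the previous gradient and `step_size` the
    # current per-element step (hedged description of the standard scheme).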

    def test_adamw(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                AdamW,
                {"lr": 0.9, "differentiable": True, "amsgrad": True},
                *state.values(),
            ),
        )

    def test_nadam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["mu_product"] = torch.tensor(1.0, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                NAdam,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                NAdam,
                {"lr": 0.9, "decoupled_weight_decay": True, "differentiable": True},
                *state.values(),
            ),
        )
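
    # The second gradcheck above exercises NAdam's decoupled weight-decay
    # variant (`decoupled_weight_decay=True`, the NAdamW-style update per the
    # torch.optim.NAdam docs).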

    def test_radam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                RAdam,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                RAdam,
                {
                    "lr": 0.9,
                    "weight_decay": 0.1,
                    "decoupled_weight_decay": True,
                    "differentiable": True,
                },
                *state.values(),
            ),
        )


if __name__ == "__main__":
    print("These tests should be run through test/test_optim.py instead")