1# Owner(s): ["module: optimizer"] 2 3import torch 4from torch.optim import ( 5 Adadelta, 6 Adagrad, 7 Adam, 8 Adamax, 9 AdamW, 10 ASGD, 11 NAdam, 12 RAdam, 13 RMSprop, 14 Rprop, 15 SGD, 16) 17from torch.testing._internal.common_utils import ( 18 gradcheck, 19 load_tests, 20 skipIfTorchDynamo, 21 TestCase, 22) 23 24 25# load_tests from common_utils is used to automatically filter tests for 26# sharding on sandcastle. This line silences flake warnings 27load_tests = load_tests 28 29 30def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored): 31 # Ignored is the list of values in `opt_differentiable_state`, we do this 32 # for `gradcheck` to correctly track the state tensors as function inputs 33 # because otherwise it can't unpack the values in the `opt_differentiable_state` 34 # dict 35 p = p.clone() 36 p.grad = grad 37 opt_differentiable_state = { 38 k: v.clone() if isinstance(v, torch.Tensor) else v 39 for k, v in opt_differentiable_state.items() 40 } 41 opt = opt_class([p], **kwargs) 42 opt.state[p].update(opt_differentiable_state) 43 opt.step() 44 return (p,) + tuple( 45 v 46 for v in opt.state[p].values() 47 if isinstance(v, torch.Tensor) and v.requires_grad 48 ) 49 50 51@skipIfTorchDynamo("Differentiable optimizers not supported") 52class TestDifferentiableOptimizer(TestCase): 53 def test_sgd(self): 54 p = torch.rand(10, requires_grad=True, dtype=torch.float64) 55 grad = torch.rand(10, requires_grad=True, dtype=torch.float64) 56 mbuff = torch.rand(10, requires_grad=True, dtype=torch.float64) 57 state = {"momentum_buffer": mbuff} 58 gradcheck( 59 _diff_fn, 60 ( 61 p, 62 grad, 63 state, 64 SGD, 65 {"lr": 0.9, "differentiable": True}, 66 *state.values(), 67 ), 68 ) 69 70 def test_adam(self): 71 state = {} 72 p = torch.rand(10, requires_grad=True, dtype=torch.float64) 73 grad = torch.rand(10, requires_grad=True, dtype=torch.float64) 74 # `step` is not a continuous variable (even though we define it as a float) 75 # and so it shouldn't require gradients. 
76 state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) 77 state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) 78 state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64) 79 state["max_exp_avg_sq"] = torch.rand( 80 10, requires_grad=True, dtype=torch.float64 81 ) 82 83 gradcheck( 84 _diff_fn, 85 ( 86 p, 87 grad, 88 state, 89 Adam, 90 {"lr": 0.9, "differentiable": True, "amsgrad": True}, 91 *state.values(), 92 ), 93 ) 94 95 def test_rmsprop(self): 96 state = {} 97 p = torch.rand(10, requires_grad=True, dtype=torch.float64) 98 grad = torch.rand(10, requires_grad=True, dtype=torch.float64) 99 state["step"] = torch.zeros((), dtype=torch.float64) 100 state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) 101 state["momentum_buffer"] = torch.rand( 102 10, requires_grad=True, dtype=torch.float64 103 ) 104 # This can cause issues with large values and nan due to sqrt ops 105 state["grad_avg"] = 1e-2 * torch.rand( 106 10, requires_grad=True, dtype=torch.float64 107 ) 108 gradcheck( 109 _diff_fn, 110 ( 111 p, 112 grad, 113 state, 114 RMSprop, 115 { 116 "lr": 0.9, 117 "maximize": True, 118 "momentum": 0.9, 119 "differentiable": True, 120 "centered": True, 121 "weight_decay": 0.1, 122 }, 123 *state.values(), 124 ), 125 ) 126 127 def test_adadelta(self): 128 state = {} 129 p = torch.rand(10, requires_grad=True, dtype=torch.float64) 130 grad = torch.rand(10, requires_grad=True, dtype=torch.float64) 131 # `step` is not a continuous variable (even though we define it as a float) 132 # and so it shouldn't require gradients. 133 state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) 134 state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) 135 state["acc_delta"] = torch.rand(10, requires_grad=True, dtype=torch.float64) 136 gradcheck( 137 _diff_fn, 138 ( 139 p, 140 grad, 141 state, 142 Adadelta, 143 {"lr": 0.9, "weight_decay": 0.1, "differentiable": True}, 144 *state.values(), 145 ), 146 ) 147 148 def test_adagrad(self): 149 state = {} 150 p = torch.rand(10, requires_grad=True, dtype=torch.float64) 151 grad = torch.rand(10, requires_grad=True, dtype=torch.float64) 152 # `step` is not a continuous variable (even though we define it as a float) 153 # and so it shouldn't require gradients. 154 state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) 155 state["sum"] = torch.rand(10, requires_grad=True, dtype=torch.float64) 156 gradcheck( 157 _diff_fn, 158 ( 159 p, 160 grad, 161 state, 162 Adagrad, 163 {"lr": 0.9, "weight_decay": 0.1, "differentiable": True}, 164 *state.values(), 165 ), 166 ) 167 168 def test_adamax(self): 169 state = {} 170 p = torch.rand(10, requires_grad=True, dtype=torch.float64) 171 grad = torch.rand(10, requires_grad=True, dtype=torch.float64) 172 # `step` is not a continuous variable (even though we define it as a float) 173 # and so it shouldn't require gradients. 
174 state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) 175 state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) 176 state["exp_inf"] = torch.rand(10, requires_grad=True, dtype=torch.float64) 177 gradcheck( 178 _diff_fn, 179 ( 180 p, 181 grad, 182 state, 183 Adamax, 184 {"lr": 0.9, "weight_decay": 0.1, "differentiable": True}, 185 *state.values(), 186 ), 187 ) 188 189 @skipIfTorchDynamo( 190 "The inplace mu update fails with dynamo, " 191 "since this is only happening when differentiable is enabled, skipping for now" 192 ) 193 def test_asgd(self): 194 state = {} 195 p = torch.rand(10, requires_grad=True, dtype=torch.float64) 196 grad = torch.rand(10, requires_grad=True, dtype=torch.float64) 197 # `step` `eta` & `mu` are not continuous variables (even though we define them as floats) 198 # and so they shouldn't require gradients. 199 state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) 200 state["eta"] = torch.tensor(0.9, requires_grad=False, dtype=torch.float64) 201 state["mu"] = torch.tensor(1.0, requires_grad=False, dtype=torch.float64) 202 state["ax"] = torch.rand(10, requires_grad=True, dtype=torch.float64) 203 204 gradcheck( 205 _diff_fn, 206 ( 207 p, 208 grad, 209 state, 210 ASGD, 211 {"lr": 0.9, "differentiable": True}, 212 *state.values(), 213 ), 214 ) 215 216 def test_rprop(self): 217 state = {} 218 p = torch.rand(10, requires_grad=True, dtype=torch.float64) 219 grad = torch.rand(10, requires_grad=True, dtype=torch.float64) 220 # `step` is not a continuous variable (even though we define it as a float) 221 # and so it shouldn't require gradients. 222 state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) 223 state["prev"] = torch.rand(10, requires_grad=True, dtype=torch.float64) 224 state["step_size"] = torch.rand(10, requires_grad=True, dtype=torch.float64) 225 226 gradcheck( 227 _diff_fn, 228 ( 229 p, 230 grad, 231 state, 232 Rprop, 233 {"lr": 0.9, "differentiable": True}, 234 *state.values(), 235 ), 236 ) 237 238 def test_adamw(self): 239 state = {} 240 p = torch.rand(10, requires_grad=True, dtype=torch.float64) 241 grad = torch.rand(10, requires_grad=True, dtype=torch.float64) 242 # `step` is not a continuous variable (even though we define it as a float) 243 # and so it shouldn't require gradients. 244 state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) 245 state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) 246 state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64) 247 state["max_exp_avg_sq"] = torch.rand( 248 10, requires_grad=True, dtype=torch.float64 249 ) 250 251 gradcheck( 252 _diff_fn, 253 ( 254 p, 255 grad, 256 state, 257 AdamW, 258 {"lr": 0.9, "differentiable": True, "amsgrad": True}, 259 *state.values(), 260 ), 261 ) 262 263 def test_nadam(self): 264 state = {} 265 p = torch.rand(10, requires_grad=True, dtype=torch.float64) 266 grad = torch.rand(10, requires_grad=True, dtype=torch.float64) 267 # `step` is not a continuous variable (even though we define it as a float) 268 # and so it shouldn't require gradients. 
269 state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) 270 state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) 271 state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64) 272 state["mu_product"] = torch.tensor(1.0, requires_grad=True, dtype=torch.float64) 273 274 gradcheck( 275 _diff_fn, 276 ( 277 p, 278 grad, 279 state, 280 NAdam, 281 {"lr": 0.9, "differentiable": True}, 282 *state.values(), 283 ), 284 ) 285 286 gradcheck( 287 _diff_fn, 288 ( 289 p, 290 grad, 291 state, 292 NAdam, 293 {"lr": 0.9, "decoupled_weight_decay": True, "differentiable": True}, 294 *state.values(), 295 ), 296 ) 297 298 def test_radam(self): 299 state = {} 300 p = torch.rand(10, requires_grad=True, dtype=torch.float64) 301 grad = torch.rand(10, requires_grad=True, dtype=torch.float64) 302 # `step` is not a continuous variable (even though we define it as a float) 303 # and so it shouldn't require gradients. 304 state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) 305 state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) 306 state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64) 307 308 gradcheck( 309 _diff_fn, 310 ( 311 p, 312 grad, 313 state, 314 RAdam, 315 {"lr": 0.9, "differentiable": True}, 316 *state.values(), 317 ), 318 ) 319 gradcheck( 320 _diff_fn, 321 ( 322 p, 323 grad, 324 state, 325 RAdam, 326 { 327 "lr": 0.9, 328 "weight_decay": 0.1, 329 "decoupled_weight_decay": True, 330 "differentiable": True, 331 }, 332 *state.values(), 333 ), 334 ) 335 336 337if __name__ == "__main__": 338 print("These tests should be run through test/test_optim.py instead") 339