# Owner(s): ["module: optimizer"]
import functools
import math
import tempfile
import unittest
from copy import deepcopy
from typing import Any, Dict, Tuple
from unittest.mock import patch

from optim.test_lrscheduler import TestLRScheduler  # noqa: F401
from optim.test_optim import TestDifferentiableOptimizer  # noqa: F401
from optim.test_swa_utils import TestSWAUtils  # noqa: F401

import torch
from torch.nn import Parameter
from torch.optim import Optimizer, SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.optimizer import (
    register_optimizer_step_post_hook,
    register_optimizer_step_pre_hook,
)
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    largeTensorTest,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
    skipMPS,
    TEST_WITH_ROCM,
)
from torch.testing._internal.common_dtype import floating_types_and
from torch.testing._internal.common_optimizers import (
    _get_device_type,
    _get_optim_inputs_including_global_cliquey_kwargs,
    optim_db,
    OptimizerErrorEnum,
    optims,
    TensorTracker,
)
from torch.testing._internal.common_utils import (
    markDynamoStrictTest,
    parametrize,
    run_tests,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
)


FP16_REDUCED_PRECISION = {"atol": 1e-5, "rtol": 1e-4}

def rosenbrock(tensor):
    assert tensor.size() == torch.Size(
        [2]
    ), f"Requires tensor with 2 scalars but got {tensor.size()}"
    x, y = tensor
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


def drosenbrock(tensor):
    assert tensor.size() == torch.Size(
        [2]
    ), f"Requires tensor with 2 scalars but got {tensor.size()}"
    x, y = tensor
    return torch.stack((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2)))


@markDynamoStrictTest
class TestOptimRenewed(TestCase):
    """
    This test class validates the core optimizers. It is structured around checking the correctness of:
    - The update algorithms (forloop implementation)
        * Every optimizer's algorithm is most readably implemented through a big for-loop
          over all the parameters, which is what we refer to as the forloop or single tensor
          implementation. These algorithms are manually validated by comparing to the paper
          and systematically validated by ensuring that the loss goes in the right direction
          when the optimizer has been applied.
        * This implementation should compose well with optimizer hyperparameters, such as
          supporting Tensor LRs, the capturable API, and sparse and complex parameters.
    - Each varying implementation
        * We then have implementations that improve upon the performance of the forloop
          implementation by leveraging fusion, namely our foreach (multi_tensor) and fused
          implementations.
        * These variations are validated numerically by comparing with the forloop version
          of the optimizer. In fact, we test most variations this way--we see the forloop
          implementation as the ground truth and expect that improvements to it in any way
          should be just as correct. (A minimal sketch of this parity check appears right
          after this docstring.)
        * Both params and optimizer states should be validated numerically.
    - state_dict APIs
        * The optimizer instance should be serializable
        * Calling save and load should be deterministic
        * Moving between devices should be seamless
        * BC - load_state_dict should be able to handle older optimizer states
    - Hook APIs (everything should fire in the right order)
    - LR Scheduler integration (composing should not error + should go the right direction)
    - Parameter groups (should be equivalent to having multiple optimizers)
    - Erroring (what should error should error)

    We also cover different ways of generating parameters and grads:
    - With parameters, we either generate them randomly given specific shapes or we take
      them from a sample NN module.
        * Variety is important here because NN modules have type Parameter and randomly
          generated tensors have type Tensor.
        * Parameters can be sparse for a subset of the optimizers (check out OptimizerInfo)
        * Complex parameters should be handled using view_as_real
        * Parameters can be spread across different devices and different dtypes for any
          given optimizer
        * Parameters can be contiguous and noncontiguous
    - With grads, we follow suit from the parameters.
        * Grads can also be None, empty, or zero-valued, and this should not disrupt training.
    """
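
    # A minimal illustrative sketch (assuming plain SGD) of the parity idea described in the
    # docstring above: the forloop (single tensor) implementation is treated as ground truth,
    # and the foreach (multi_tensor) implementation should produce the same updates. The real
    # coverage lives in test_foreach_matches_forloop / _test_derived_optimizers further below;
    # this helper is illustrative only and is not collected as a test.
    def _sketch_foreach_matches_forloop_sgd(self):
        params_ref = [torch.randn(3, requires_grad=True) for _ in range(2)]
        params_foreach = [p.detach().clone().requires_grad_(True) for p in params_ref]
        opt_ref = SGD(params_ref, lr=0.1, foreach=False)  # forloop (single tensor)
        opt_foreach = SGD(params_foreach, lr=0.1, foreach=True)  # multi_tensor
        for _ in range(3):
            # Feed both optimizers identical gradients.
            for p_r, p_f in zip(params_ref, params_foreach):
                grad = torch.randn_like(p_r)
                p_r.grad, p_f.grad = grad, grad.clone()
            opt_ref.step()
            opt_foreach.step()
        # The resulting parameters should match numerically.
        for p_r, p_f in zip(params_ref, params_foreach):
            self.assertEqual(p_r, p_f)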

    @onlyCPU
    @optims(optim_db)
    def test_optim_infos_do_not_specify_global_cliquey_kwargs(
        self, device, dtype, optim_info
    ):
        global_cliquey_flags = ["foreach", "fused", "differentiable"]
        for optim_input in optim_info.optim_inputs_func(device=device):
            self.assertFalse(
                any(f for f in global_cliquey_flags if f in optim_input.kwargs)
            )

    @optims([optim for optim in optim_db if optim.optim_error_inputs_func is not None])
    def test_errors(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        error_inputs = optim_info.optim_error_inputs_func(device=device, dtype=dtype)

        for error_input in error_inputs:
            optim_input = error_input.optimizer_error_input
            params, kwargs = optim_input.params, optim_input.kwargs
            if error_input.error_on == OptimizerErrorEnum.CONSTRUCTION_ERROR:
                if issubclass(error_input.error_type, Warning):
                    with self.assertWarnsRegex(
                        error_input.error_type, error_input.error_regex
                    ):
                        optim_cls(params, **kwargs)
                else:
                    with self.assertRaisesRegex(
                        error_input.error_type, error_input.error_regex
                    ):
                        optim_cls(params, **kwargs)
            elif error_input.error_on == OptimizerErrorEnum.STEP_ERROR:
                optim = optim_cls(params, **kwargs)
                if issubclass(error_input.error_type, Warning):
                    with self.assertWarnsRegex(
                        error_input.error_type, error_input.error_regex
                    ):
                        optim.step()
                else:
                    with self.assertRaisesRegex(
                        error_input.error_type, error_input.error_regex
                    ):
                        optim.step()
            else:
                raise NotImplementedError(f"Unknown error type {error_input.error_on}")

    @parametrize("contiguous", [True, False])
    @parametrize("with_lrsched", [True, False])
    @optims(optim_db, dtypes=[torch.float32])
    def test_forloop_goes_right_direction(
        self, device, dtype, optim_info, contiguous, with_lrsched
    ):
        optim_cls = optim_info.optim_cls
        schedulers_constructors = (
            optim_info.scheduler_inputs if with_lrsched else [None]
        )

        for schedulers_constructor in schedulers_constructors:
            # With a tensor LR we need fresh inputs for each scheduler;
            # otherwise mutating the LR will carry across iterations.
            optim_inputs = optim_info.optim_inputs_func(device=device)
            for optim_input in optim_inputs:
                if "foreach" in optim_info.supported_impls:
                    optim_input.kwargs["foreach"] = False  # force forloop
                if contiguous:
                    weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
                    bias = Parameter(torch.randn((10), device=device, dtype=dtype))
                else:
                    weight = Parameter(
                        torch.randn((10, 5, 2), device=device, dtype=dtype)[..., 0]
                    )
                    bias = Parameter(
                        torch.randn((10, 2), device=device, dtype=dtype)[..., 0]
                    )
                input = torch.randn(5, device=device, dtype=dtype)

                optimizer = optim_cls([weight, bias], **optim_input.kwargs)
                schedulers = [
                    s(optimizer)
                    for s in (schedulers_constructor if schedulers_constructor else [])
                ]

                def closure():
                    optimizer.zero_grad()
                    loss = (weight.mv(input) + bias).pow(2).sum()
                    loss.backward()
                    if optim_info.only_supports_sparse_grads:
                        # For this test, we naively convert the Tensor layout, which we know does
                        # NOT represent the expected use case for optims like SparseAdam!
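                        # to_sparse() merely re-encodes the dense grad in sparse COO layout;
                        # numerically it is the same gradient, it just lets sparse-only
                        # optims accept it.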
                        weight.grad = weight.grad.to_sparse()
                        bias.grad = bias.grad.to_sparse()
                    return loss

                initial_value = closure().item()
                for _ in range(20):
                    if optim_info.step_requires_closure:
                        loss = optimizer.step(closure)
                    else:
                        loss = closure()
                        optimizer.step()

                    for scheduler in schedulers:
                        if isinstance(scheduler, ReduceLROnPlateau):
                            scheduler.step(loss)
                        else:
                            scheduler.step()

                if optim_input.kwargs.get("maximize", False):
                    self.assertGreater(closure().item(), initial_value)
                else:
                    self.assertLess(closure().item(), initial_value)

    @onlyCUDA
    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    @parametrize("with_lrsched", [True, False])
    @optims(optim_db, dtypes=[torch.float32])
    def test_forloop_goes_right_direction_multigpu(
        self, device, dtype, optim_info, with_lrsched
    ):
        optim_cls = optim_info.optim_cls
        schedulers_constructors = (
            optim_info.scheduler_inputs if with_lrsched else [None]
        )
        for schedulers_constructor in schedulers_constructors:
            # We need a fresh set of inputs if we have a tensor LR
            # to not carry mutations across iterations.
            optim_inputs = optim_info.optim_inputs_func(device=device)
            for optim_input in optim_inputs:
                if "foreach" in optim_info.supported_impls:
                    optim_input.kwargs["foreach"] = False  # force forloop

                weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))
                bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype))
                inpt = torch.randn(5, device="cuda:0", dtype=dtype)

                optimizer = optim_cls([weight, bias], **optim_input.kwargs)
                schedulers = [
                    s(optimizer)
                    for s in (schedulers_constructor if schedulers_constructor else [])
                ]

                def closure():
                    optimizer.zero_grad()
                    loss = (weight.mv(inpt).cuda(1) + bias).pow(2).sum()
                    loss.backward()
                    if optim_info.only_supports_sparse_grads:
                        # For this test, we naively convert the Tensor layout, which we know does
                        # NOT represent the expected use case for optims like SparseAdam!
                        weight.grad = weight.grad.to_sparse()
                        bias.grad = bias.grad.to_sparse()
                    return loss

                initial_value = closure().item()
                for _ in range(20):
                    loss = optimizer.step(closure)
                    for scheduler in schedulers:
                        if isinstance(scheduler, ReduceLROnPlateau):
                            scheduler.step(loss)
                        else:
                            scheduler.step()

                if optim_input.kwargs.get("maximize", False):
                    self.assertGreater(closure().item(), initial_value)
                else:
                    self.assertLess(closure().item(), initial_value)

    @optims(optim_db, dtypes=[torch.float32])
    def test_param_group_with_lrscheduler_goes_right_direction(
        self, device, dtype, optim_info
    ):
        optim_cls = optim_info.optim_cls

        for schedulers_c in optim_info.scheduler_inputs:
            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
            inpt = torch.randn(5, device=device, dtype=dtype)

            # avoid endless recompiles by wrapping LR in a tensor if we're compiling
            lr = torch.tensor(0.01) if torch._utils.is_compiling() else 0.01
            optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": lr}])
            schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c]

            def closure():
                optimizer.zero_grad()
                loss = (weight.mv(inpt) + bias).pow(2).sum()
                loss.backward()
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    weight.grad = weight.grad.to_sparse()
                    bias.grad = bias.grad.to_sparse()
                return loss

            initial_value = closure().item()
            for _ in range(20):
                loss = optimizer.step(closure)
                for scheduler in schedulers:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(loss)
                    else:
                        scheduler.step()

            self.assertLess(closure().item(), initial_value)

    @optims(optim_db, dtypes=[torch.float32])
    def test_tensor_lr(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        for optim_input in all_optim_inputs:
            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
            weight_c = weight.clone().detach().requires_grad_(True)
            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
            bias_c = bias.clone().detach().requires_grad_(True)
            inpt = torch.randn(5, device=device, dtype=dtype)

            kwargs = optim_input.kwargs
            if "lr" in kwargs:
                del kwargs["lr"]

            kwargs["lr"] = 1.0 if optim_info.step_requires_closure else 1e-3
            optimizer_r = optim_cls([weight, bias], **kwargs)

            try:
                kwargs["lr"] = torch.tensor(kwargs["lr"])
                optimizer = optim_cls([weight_c, bias_c], **kwargs)
            except ValueError as e:
                self.assertRegex(str(e), ".*lr as a Tensor is not supported.*")
                continue

            def closure(optim, w, b, i):
                optim.zero_grad()
                loss = (w.mv(i) + b).pow(2).sum()
                loss.backward()
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    w.grad = w.grad.to_sparse()
                    b.grad = b.grad.to_sparse()
                return loss

            for _ in range(5):
                if optim_info.step_requires_closure:
                    optimizer_r.step(
                        functools.partial(closure, optimizer_r, weight, bias, inpt)
                    )
                    optimizer.step(
                        functools.partial(closure, optimizer, weight_c, bias_c, inpt)
                    )
                else:
                    closure(optimizer_r, weight, bias, inpt)
                    closure(optimizer, weight_c, bias_c, inpt)
                    optimizer_r.step()
                    optimizer.step()

                self.assertEqual(weight, weight_c)
                self.assertEqual(bias, bias_c)

    @parametrize("with_lrsched", [True, False])
    @optims(
        [o for o in optim_db if o.supports_sparse or o.only_supports_sparse_grads],
        dtypes=[torch.float64],
    )
    def test_rosenbrock_sparse(self, device, dtype, optim_info, with_lrsched):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        # Fused impls do not support sparse gradients
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable", "fused")
        )
        kwarg_updates, schedulers_constructors = optim_info.metadata_for_sparse

        if with_lrsched and len(schedulers_constructors) == 0:
            return

        supported_inputs = []
        if len(kwarg_updates) != 0:
            seen = set()
            for i in all_optim_inputs:
                for k in kwarg_updates:
                    if k in i.kwargs:
                        del i.kwargs[k]
                hashable_kwargs = tuple(sorted(i.kwargs.items()))
                if len(i.kwargs) > 0 and hashable_kwargs not in seen:
                    supported_inputs.append(i)
                    seen.add(hashable_kwargs)
                    if "lr" in kwarg_updates:
                        i.kwargs["lr"] = kwarg_updates["lr"]
        else:
            supported_inputs = all_optim_inputs

        for optim_input in supported_inputs:
            kwargs = optim_input.kwargs
            multi_tensor = kwargs.get("foreach", False)

            # For rosenbrock tests, it is mandated that the param is a tensor with 2 numbers
            if multi_tensor:
                params_t = [
                    torch.tensor([1.5, 1.5]),
                    torch.tensor([1.5, 1.5], dtype=dtype),
                ]
            else:
                params_t = [torch.tensor([1.5, 1.5])]

            params = [Parameter(param_t) for param_t in params_t]
            optimizer = optim_cls(params, **kwargs)
            schedulers = [
                s(optimizer) for s in (schedulers_constructors if with_lrsched else [])
            ]

            if not optim_info.only_supports_sparse_grads:
                params_c = [Parameter(param_t.clone()) for param_t in params_t]
                optimizer_c = optim_cls(params_c, **kwargs)
                schedulers_c = [
                    s(optimizer_c)
                    for s in (schedulers_constructors if with_lrsched else [])
                ]

            solution = torch.tensor([1, 1])
            with torch.no_grad():
                initial_dist = sum(param.dist(solution) for param in params)

            def get_grad(param, sparse_grad, w):
                grad = drosenbrock(param)
                # NB: We torture test the optimizer by returning an
                # uncoalesced sparse tensor

                # Depending on w, provide only the x or y gradient
                if sparse_grad:
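                    # Both sparse values below share a single index, so once coalesced they
                    # sum back to the full x (or y) component of the dense gradient.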
                    if w:
                        i = torch.tensor([[0, 0]], dtype=torch.int64)
                        x = grad[0]
                        v = torch.tensor([x / 4.0, x - x / 4.0])
                    else:
                        i = torch.tensor([[1, 1]], dtype=torch.int64)
                        y = grad[1]
                        v = torch.tensor([y - y / 4.0, y / 4.0])
                    grad_out = torch.sparse_coo_tensor(i, v, (2,), dtype=v.dtype)
                else:
                    if w:
                        grad_out = torch.tensor([grad[0], 0], dtype=param.dtype)
                    else:
                        grad_out = torch.tensor([0, grad[1]], dtype=param.dtype)
                return grad_out

            def eval(params, sparse_grad, w):
                optimizer.zero_grad()
                if multi_tensor:
                    loss = sum(rosenbrock(param) for param in params)
                else:
                    loss = rosenbrock(params[0])
                loss.backward()

                grads_out = [get_grad(param, sparse_grad, w) for param in params]
                with torch.no_grad():
                    params[0].grad = grads_out[0]
                    if multi_tensor:
                        params[1].grad = grads_out[1].to(dtype=dtype)
                return loss

            for i in range(1800):
                # Do cyclic coordinate descent
                w = i % 2
                optimizer.step(functools.partial(eval, params, True, w))
                for scheduler in schedulers:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(rosenbrock(params[0]))
                    else:
                        scheduler.step()
                if not optim_info.only_supports_sparse_grads:
                    optimizer_c.step(functools.partial(eval, params_c, False, w))
                    for scheduler in schedulers_c:
                        if isinstance(scheduler, ReduceLROnPlateau):
                            scheduler.step(rosenbrock(params_c[0]))
                        else:
                            scheduler.step()
                    # Tolerance is increased due to floating point error from different
                    # code path for dense case: x vs. x - x / 4.0 + x / 4.0
                    self.assertEqual(params, params_c, atol=5e-6, rtol=5e-6)

            if not kwargs.get("maximize", False):
                self.assertLessEqual(
                    sum(param.dist(solution) for param in params), initial_dist
                )
            else:
                self.assertGreaterEqual(
                    sum(rosenbrock(param) for param in params),
                    sum(rosenbrock(param_t) for param_t in params_t),
                )

    @skipMPS
    @optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
    def test_complex(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        # Also skip fused, since our fused kernels do not support complex
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable", "fused")
        )
        for optim_input in all_optim_inputs:
            # Last param is intentionally real to test that we can mix real and complex
            complex_params = [
                torch.randn(10, 5, device=device, dtype=dtype, requires_grad=True),
                torch.randn(10, device=device, dtype=dtype, requires_grad=True),
                torch.randn(
                    10, 5, device=device, dtype=torch.float32, requires_grad=True
                ),
            ]
            real_params = [
                (
                    torch.view_as_real(param).detach().clone().requires_grad_()
                    if param.is_complex()
                    else param.detach().clone().requires_grad_()
                )
                for param in complex_params
            ]

            complex_optimizer = optim_cls(complex_params, **optim_input.kwargs)
            real_optimizer = optim_cls(real_params, **optim_input.kwargs)
            real_steps = []
            complex_steps = []
            grads_losses = []

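            # real_closure draws a random grad (and loss) for each real-view param and records
            # them in grads_losses; complex_closure then replays exactly the same grads and
            # loss on the complex params, so both optimizers see equivalent inputs and their
            # trajectories can be compared step for step.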
            def real_closure():
                for param in real_params:
                    grad = torch.randn_like(param)
                    param.grad = grad
                    real_steps.append(param.detach().clone())
                    grads_losses.append(grad.clone())
                loss = torch.randn(1)
                grads_losses.append(loss.clone())
                return loss

            def complex_closure():
                for param in complex_params:
                    if torch.is_complex(param):
                        grad = torch.view_as_complex(grads_losses.pop(0))
                        complex_steps.append(torch.view_as_real_copy(param.detach()))
                    else:
                        grad = grads_losses.pop(0)
                        complex_steps.append(param.detach().clone())
                    param.grad = grad
                return grads_losses.pop(0)

            for _ in range(3):
                if optim_info.step_requires_closure:
                    # LBFGS, for example, requires closure and calls it internally
                    real_optimizer.step(real_closure)
                    complex_optimizer.step(complex_closure)
                else:
                    # For other optimizers, we call closure explicitly to set the gradients
                    real_closure()
                    complex_closure()
                    real_optimizer.step()
                    complex_optimizer.step()

            # Final Parameters should be the same
            complex_params_asreal = [
                torch.view_as_real(param) if param.is_complex() else param
                for param in complex_params
            ]
            self.assertEqual(real_params, complex_params_asreal)

            # All intermediate steps should also be the same;
            # this also checks steps taken within, for example, a line search
            self.assertEqual(complex_steps, real_steps)

    @skipMPS
    @optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
    def test_complex_2d(self, device, dtype, optim_info):
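        # Optimize a complex parameter directly and, in parallel, its real and imaginary
        # parts as two separate real parameters; both runs should follow the same trajectory.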
        optim_cls = optim_info.optim_cls
        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        # Also skip fused, since our fused kernels do not support complex
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable", "fused")
        )
        for optim_input in all_optim_inputs:
            if optim_info.step_requires_closure:
                # Why? The way we implement complex is by turning complex params into view_as_real
                # alternatives. For example, a size (M,N) tensor will become (M,N,2). In this test,
                # we break apart a tensor into its real and imaginary parts, which would be 2x(M,N).
                # For other pointwise optimizers, this distinction is trivial, but for LBFGS where
                # there are reductions across all parameters (and all the grads get flattened into
                # one long Tensor), this ordering matters. Why? Reductions are not deterministic
                # because addition between floating point numbers is not associative, i.e.,
                # (a + b) + c != a + (b + c). Thus, we add a seed here to control the discrepancy
                # that will happen with LBFGS. Note that in test_complex above, there is no need
                # for a seed nor for increased tolerance, because results should be bitwise
                # equivalent.
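                # For instance, in double precision (0.1 + 1e16) - 1e16 == 0.0 while
                # 0.1 + (1e16 - 1e16) == 0.1, so summing the flattened grads in a
                # different order can change the result slightly.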
                torch.manual_seed(2024)

            a1 = torch.randn(2, device=device, dtype=dtype, requires_grad=True)
            a1_real = a1.real.clone().detach()
            a1_imag = a1.imag.clone().detach()
            a1_real.requires_grad_()
            a1_imag.requires_grad_()
            optim1 = optim_cls([a1], **optim_input.kwargs)
            optim2 = optim_cls([a1_real, a1_imag], **optim_input.kwargs)

            a1_reals = TensorTracker()
            a1_imags = TensorTracker()
            a1_grad_reals = TensorTracker()
            a1_grad_imags = TensorTracker()
            losses = TensorTracker()

            def closure1():
                optim1.zero_grad()
                loss = rosenbrock(a1).abs()
                loss.backward()

                # Track clones to best test accuracy
                a1_reals.add(a1.real)
                a1_imags.add(a1.imag)
                a1_grad_reals.add(a1.grad.real)
                a1_grad_imags.add(a1.grad.imag)

                losses.add(loss)

                return loss

            def closure2():
                optim2.zero_grad()
                a1_reals.pop_check_set(a1_real, self)
                a1_imags.pop_check_set(a1_imag, self)
                a2 = torch.complex(a1_real, a1_imag)
                loss = rosenbrock(a2).abs()
                losses.pop_check_set(loss, self)
                loss.backward()
                a1_grad_reals.pop_check_set(a1_real.grad, self)
                a1_grad_imags.pop_check_set(a1_imag.grad, self)
                return loss

            for _ in range(3):
                if optim_info.step_requires_closure:
                    # LBFGS, for example, requires closure and calls it internally
                    optim1.step(closure1)
                    optim2.step(closure2)
                else:
                    closure1()
                    closure2()
                    optim1.step()
                    optim2.step()

            self.assertEqual(a1.real, a1_real)
            self.assertEqual(a1.imag, a1_imag)

            self.assertTrue(a1_reals.all_popped())
            self.assertTrue(a1_imags.all_popped())
            self.assertTrue(a1_grad_reals.all_popped())
            self.assertTrue(a1_grad_imags.all_popped())
            self.assertTrue(losses.all_popped())

    def _compare_between(
        self, inputs, models, optimizers, assert_eq_kwargs=None, assert_step_dtype=None
    ):
        # why 7? iteration 7 is where we start to see differences for RAdam
        # params interacting with the small eps value, because that's right
        # after rho_t becomes greater than 5 in step 6.
        if assert_eq_kwargs is None:
            assert_eq_kwargs = {}
        kIterations = 7
        tracker = TensorTracker(assert_eq_kwargs)
        for i in range(kIterations):
            state, updated_params = [], []
            if not isinstance(inputs, list):
                inputs = [inputs, inputs]
            for input, model, optimizer in zip(inputs, models, optimizers):
                optimizer.zero_grad()

                if i == 3:
                    # Freeze a layer to test if the step of this layer in 'fused' or 'foreach'
                    # is the same as the step in 'forloop'.
                    model[2].requires_grad_(False)
                if i == 5:
                    # Unfreeze the layer after 2 iters.
                    model[2].requires_grad_(True)

                # Test that step behaves as expected (a no-op) when grads are set to None
                if i != 2:
                    output = model(input)
                    loss = output.sum()
                    loss.backward()

                optimizer.step()
                state.append(optimizer.state)
                updated_params.append(model.parameters())

            og_state, new_state = state
            for og_p, new_p in zip(updated_params[0], updated_params[1]):
                tracker.add(og_p)
                tracker.pop_check_set(new_p, self)

                # check that optimizer states are the same
                og_p_state = og_state[og_p]
                new_p_state = new_state[new_p]
                if assert_step_dtype is not None:
                    if torch.is_tensor(og_p_state.get("step", None)):
                        self.assertEqual(og_p_state["step"].dtype, assert_step_dtype)
                    if torch.is_tensor(new_p_state.get("step", None)):
                        self.assertEqual(new_p_state["step"].dtype, assert_step_dtype)
                for k in og_p_state:
                    tracker.add(og_p_state[k])
                    tracker.pop_check_set(new_p_state[k], self)

        self.assertTrue(tracker.all_popped())

    def _test_derived_optimizers(
        self,
        device,
        dtype,
        optim_info,
        flag,
        reduced_precision=False,
        assert_step_dtype=None,
    ):
        """
        Given a flag 'fused' or 'foreach', test for parity of optimizer state
        and updated parameters between when the flag is set to True and False
        for provided optimizer configurations.
        """
        assert flag in ("foreach", "fused")
        assert_eq_kwargs = {} if not reduced_precision else FP16_REDUCED_PRECISION

        optim_inputs = optim_info.optim_inputs_func(device=device, dtype=dtype)
        optim_cls = optim_info.optim_cls
        for optim_input in optim_inputs:
            models, optimizers = [], []
            kwargs = deepcopy(optim_input.kwargs)
            if kwargs.get("capturable", False) and _get_device_type(device) == "cpu":
                # capturable is not supported on CPU
                continue
            for flag_value in (False, True):
                kwargs[flag] = flag_value
                input = torch.tensor(
                    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=device
                ).reshape(3, 2)

                torch.manual_seed(1)
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Sigmoid(),
                    torch.nn.Linear(3, 1),
                    torch.nn.Sigmoid(),
                )
                model.to(dtype=dtype, device=device)

                # foreach/fused optimizers should be tested with a
                # zero_size tensor as their last param.
762*da0073e9SAndroid Build Coastguard Worker # ref: https://github.com/pytorch/pytorch/issues/100701 763*da0073e9SAndroid Build Coastguard Worker empty_param = torch.empty( 764*da0073e9SAndroid Build Coastguard Worker (), device=device, dtype=dtype, requires_grad=True 765*da0073e9SAndroid Build Coastguard Worker ) 766*da0073e9SAndroid Build Coastguard Worker empty_param.grad = torch.rand_like(empty_param) 767*da0073e9SAndroid Build Coastguard Worker params = list(model.parameters()) + [empty_param] 768*da0073e9SAndroid Build Coastguard Worker 769*da0073e9SAndroid Build Coastguard Worker optimizer = optim_cls(params, **kwargs) 770*da0073e9SAndroid Build Coastguard Worker models.append(model) 771*da0073e9SAndroid Build Coastguard Worker optimizers.append(optimizer) 772*da0073e9SAndroid Build Coastguard Worker 773*da0073e9SAndroid Build Coastguard Worker self._compare_between( 774*da0073e9SAndroid Build Coastguard Worker input, models, optimizers, assert_eq_kwargs, assert_step_dtype 775*da0073e9SAndroid Build Coastguard Worker ) 776*da0073e9SAndroid Build Coastguard Worker 777*da0073e9SAndroid Build Coastguard Worker @skipMPS # MPS doesn't support torch.float64, see https://github.com/pytorch/pytorch/issues/115350 778*da0073e9SAndroid Build Coastguard Worker @optims( 779*da0073e9SAndroid Build Coastguard Worker [optim for optim in optim_db if "foreach" in optim.supported_impls], 780*da0073e9SAndroid Build Coastguard Worker dtypes=[torch.float64], 781*da0073e9SAndroid Build Coastguard Worker ) 782*da0073e9SAndroid Build Coastguard Worker def test_foreach_matches_forloop(self, device, dtype, optim_info): 783*da0073e9SAndroid Build Coastguard Worker self._test_derived_optimizers(device, dtype, optim_info, "foreach") 784*da0073e9SAndroid Build Coastguard Worker 785*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 786*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") 787*da0073e9SAndroid Build Coastguard Worker @parametrize("impl", ["foreach", "fused"]) 788*da0073e9SAndroid Build Coastguard Worker @optims( 789*da0073e9SAndroid Build Coastguard Worker [ 790*da0073e9SAndroid Build Coastguard Worker optim 791*da0073e9SAndroid Build Coastguard Worker for optim in optim_db 792*da0073e9SAndroid Build Coastguard Worker if "foreach" in optim.supported_impls or "fused" in optim.supported_impls 793*da0073e9SAndroid Build Coastguard Worker ] 794*da0073e9SAndroid Build Coastguard Worker ) 795*da0073e9SAndroid Build Coastguard Worker def test_mixed_device_dtype(self, device, dtype, optim_info, impl): 796*da0073e9SAndroid Build Coastguard Worker """ 797*da0073e9SAndroid Build Coastguard Worker Similar in essence to _test_derived_optimizers above. The main difference is that 798*da0073e9SAndroid Build Coastguard Worker _test_derived_optimizers uses model parameters whereas we randomly pass in 799*da0073e9SAndroid Build Coastguard Worker parameters of different dtypes and devices here. We need multiple GPUs (vs just a 800*da0073e9SAndroid Build Coastguard Worker CPU and GPU) because fused adam only works on GPUs. (Thus we only run the tests 801*da0073e9SAndroid Build Coastguard Worker that call into this helper when TEST_MULTIGPU.) 
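
        A rough standalone sketch of the same comparison (illustrative only, not
        executed here), restricted to a single device and the foreach flag so it
        does not need multiple GPUs; the parameter mix is arbitrary:

            import torch
            from torch.optim import Adam

            params = [
                torch.rand(2, 3, dtype=torch.float64, requires_grad=True),
                torch.rand(2, 3, dtype=torch.float32, requires_grad=True),
            ]
            clones = [p.detach().clone().requires_grad_(True) for p in params]
            for p, c in zip(params, clones):
                p.grad = torch.rand_like(p)
                c.grad = p.grad.clone()
            Adam(params, foreach=False).step()
            Adam(clones, foreach=True).step()
            for p, c in zip(params, clones):
                torch.testing.assert_close(p, c)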
802*da0073e9SAndroid Build Coastguard Worker """ 803*da0073e9SAndroid Build Coastguard Worker assert impl in ("foreach", "fused") 804*da0073e9SAndroid Build Coastguard Worker if impl == "foreach" and "foreach" not in optim_info.supported_impls: 805*da0073e9SAndroid Build Coastguard Worker return unittest.skip( 806*da0073e9SAndroid Build Coastguard Worker f"foreach not supported for {optim_info.optim_cls.__name__}" 807*da0073e9SAndroid Build Coastguard Worker ) 808*da0073e9SAndroid Build Coastguard Worker elif impl == "fused" and "cuda" not in optim_info.supports_fused_on: 809*da0073e9SAndroid Build Coastguard Worker return unittest.skip( 810*da0073e9SAndroid Build Coastguard Worker f"fused not supported for {optim_info.optim_cls.__name__} on cuda" 811*da0073e9SAndroid Build Coastguard Worker ) 812*da0073e9SAndroid Build Coastguard Worker 813*da0073e9SAndroid Build Coastguard Worker params = [ 814*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 3, dtype=torch.float64, device="cuda:0", requires_grad=True), 815*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 3, dtype=torch.float32, device="cuda:0", requires_grad=True), 816*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 3, dtype=torch.float16, device="cuda:0", requires_grad=True), 817*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 3, dtype=torch.bfloat16, device="cuda:0", requires_grad=True), 818*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 3, dtype=torch.float64, device="cuda:1", requires_grad=True), 819*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 3, dtype=torch.float32, device="cuda:1", requires_grad=True), 820*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 3, dtype=torch.float16, device="cuda:1", requires_grad=True), 821*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 3, dtype=torch.bfloat16, device="cuda:1", requires_grad=True), 822*da0073e9SAndroid Build Coastguard Worker torch.randint( 823*da0073e9SAndroid Build Coastguard Worker 1024, (2, 3), dtype=torch.int64, device="cuda:1", requires_grad=False 824*da0073e9SAndroid Build Coastguard Worker ), 825*da0073e9SAndroid Build Coastguard Worker ] 826*da0073e9SAndroid Build Coastguard Worker 827*da0073e9SAndroid Build Coastguard Worker for p in params: 828*da0073e9SAndroid Build Coastguard Worker if p.requires_grad: 829*da0073e9SAndroid Build Coastguard Worker p.grad = torch.rand_like(p, device=p.device, dtype=p.dtype) 830*da0073e9SAndroid Build Coastguard Worker 831*da0073e9SAndroid Build Coastguard Worker kIterations = 7 if impl == "foreach" else 1 832*da0073e9SAndroid Build Coastguard Worker optim_inputs = optim_info.optim_inputs_func(device=device) 833*da0073e9SAndroid Build Coastguard Worker optim_cls = optim_info.optim_cls 834*da0073e9SAndroid Build Coastguard Worker for optim_input in optim_inputs: 835*da0073e9SAndroid Build Coastguard Worker updated_params, state = [], [] 836*da0073e9SAndroid Build Coastguard Worker kwargs = deepcopy(optim_input.kwargs) 837*da0073e9SAndroid Build Coastguard Worker if kwargs.get("capturable", False) and _get_device_type(device) == "cpu": 838*da0073e9SAndroid Build Coastguard Worker # capturable is not supported on CPU 839*da0073e9SAndroid Build Coastguard Worker continue 840*da0073e9SAndroid Build Coastguard Worker for use_impl in (False, True): 841*da0073e9SAndroid Build Coastguard Worker kwargs[impl] = use_impl 842*da0073e9SAndroid Build Coastguard Worker params_clone = [] 843*da0073e9SAndroid Build Coastguard Worker for p in params: 844*da0073e9SAndroid Build Coastguard Worker p_clone 
= p.clone().detach() 845*da0073e9SAndroid Build Coastguard Worker if p.requires_grad: 846*da0073e9SAndroid Build Coastguard Worker p_clone.requires_grad = True 847*da0073e9SAndroid Build Coastguard Worker p_clone.grad = p.grad.clone().detach() 848*da0073e9SAndroid Build Coastguard Worker params_clone.append(p_clone) 849*da0073e9SAndroid Build Coastguard Worker 850*da0073e9SAndroid Build Coastguard Worker optimizer = optim_cls(params_clone, **kwargs) 851*da0073e9SAndroid Build Coastguard Worker for _ in range(kIterations): 852*da0073e9SAndroid Build Coastguard Worker optimizer.step() 853*da0073e9SAndroid Build Coastguard Worker 854*da0073e9SAndroid Build Coastguard Worker state.append(optimizer.state) 855*da0073e9SAndroid Build Coastguard Worker updated_params.append(params_clone) 856*da0073e9SAndroid Build Coastguard Worker 857*da0073e9SAndroid Build Coastguard Worker og_state, new_state = state 858*da0073e9SAndroid Build Coastguard Worker for og_p, new_p in zip(updated_params[0], updated_params[1]): 859*da0073e9SAndroid Build Coastguard Worker # Increasing the tolerance as we are collating lots of ops together for optimizers and 860*da0073e9SAndroid Build Coastguard Worker # the designated tolerances are for single op only. 861*da0073e9SAndroid Build Coastguard Worker single_rtol, single_atol = torch.testing._comparison.get_tolerances( 862*da0073e9SAndroid Build Coastguard Worker new_p.dtype, rtol=None, atol=None 863*da0073e9SAndroid Build Coastguard Worker ) 864*da0073e9SAndroid Build Coastguard Worker rtol = 5 * single_rtol 865*da0073e9SAndroid Build Coastguard Worker atol = 5 * single_atol 866*da0073e9SAndroid Build Coastguard Worker 867*da0073e9SAndroid Build Coastguard Worker self.assertEqual(og_p, new_p, rtol=rtol, atol=atol) 868*da0073e9SAndroid Build Coastguard Worker 869*da0073e9SAndroid Build Coastguard Worker # check that optimizer states are the same 870*da0073e9SAndroid Build Coastguard Worker og_p_state = og_state[og_p] 871*da0073e9SAndroid Build Coastguard Worker new_p_state = new_state[new_p] 872*da0073e9SAndroid Build Coastguard Worker 873*da0073e9SAndroid Build Coastguard Worker for k in og_p_state: 874*da0073e9SAndroid Build Coastguard Worker actual = new_p_state[k] 875*da0073e9SAndroid Build Coastguard Worker self.assertEqual(og_p_state[k], actual, rtol=rtol, atol=atol) 876*da0073e9SAndroid Build Coastguard Worker 877*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 878*da0073e9SAndroid Build Coastguard Worker @optims( 879*da0073e9SAndroid Build Coastguard Worker [optim for optim in optim_db if "foreach" in optim.supported_impls], 880*da0073e9SAndroid Build Coastguard Worker dtypes=[torch.float64], 881*da0073e9SAndroid Build Coastguard Worker ) 882*da0073e9SAndroid Build Coastguard Worker def test_set_default_dtype_works_with_foreach(self, device, dtype, optim_info): 883*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/110940 884*da0073e9SAndroid Build Coastguard Worker # We coerce step to always be float32 unless the 885*da0073e9SAndroid Build Coastguard Worker # default dtype is higher prec float64 886*da0073e9SAndroid Build Coastguard Worker old_default_dtype = torch.get_default_dtype() 887*da0073e9SAndroid Build Coastguard Worker for default_dtype in [torch.float64, torch.float16]: 888*da0073e9SAndroid Build Coastguard Worker try: 889*da0073e9SAndroid Build Coastguard Worker torch.set_default_dtype(default_dtype) 890*da0073e9SAndroid Build Coastguard Worker self._test_derived_optimizers( 891*da0073e9SAndroid Build Coastguard 
Worker device, 892*da0073e9SAndroid Build Coastguard Worker dtype, 893*da0073e9SAndroid Build Coastguard Worker optim_info, 894*da0073e9SAndroid Build Coastguard Worker "foreach", 895*da0073e9SAndroid Build Coastguard Worker reduced_precision=default_dtype == torch.float16, 896*da0073e9SAndroid Build Coastguard Worker assert_step_dtype=( 897*da0073e9SAndroid Build Coastguard Worker torch.float64 898*da0073e9SAndroid Build Coastguard Worker if default_dtype == torch.float64 899*da0073e9SAndroid Build Coastguard Worker else torch.float32 900*da0073e9SAndroid Build Coastguard Worker ), 901*da0073e9SAndroid Build Coastguard Worker ) 902*da0073e9SAndroid Build Coastguard Worker finally: 903*da0073e9SAndroid Build Coastguard Worker torch.set_default_dtype(old_default_dtype) 904*da0073e9SAndroid Build Coastguard Worker 905*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 906*da0073e9SAndroid Build Coastguard Worker @largeTensorTest("72GB", "cuda") 907*da0073e9SAndroid Build Coastguard Worker @optims( 908*da0073e9SAndroid Build Coastguard Worker [optim for optim in optim_db if "foreach" in optim.supported_impls], 909*da0073e9SAndroid Build Coastguard Worker dtypes=[torch.float16], 910*da0073e9SAndroid Build Coastguard Worker ) 911*da0073e9SAndroid Build Coastguard Worker def test_foreach_large_tensor(self, device, dtype, optim_info): 912*da0073e9SAndroid Build Coastguard Worker optim_cls = optim_info.optim_cls 913*da0073e9SAndroid Build Coastguard Worker optim_inputs = optim_info.optim_inputs_func(device=device) 914*da0073e9SAndroid Build Coastguard Worker for optim_input in optim_inputs: 915*da0073e9SAndroid Build Coastguard Worker params = [torch.ones(2**32, device=device, dtype=dtype)] 916*da0073e9SAndroid Build Coastguard Worker params[0].grad = torch.zeros_like(params[0]) 917*da0073e9SAndroid Build Coastguard Worker optimizer = optim_cls(params, foreach=True, **optim_input.kwargs) 918*da0073e9SAndroid Build Coastguard Worker optimizer.step() 919*da0073e9SAndroid Build Coastguard Worker 920*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 921*da0073e9SAndroid Build Coastguard Worker @optims( 922*da0073e9SAndroid Build Coastguard Worker [optim for optim in optim_db if "foreach" in optim.supported_impls], 923*da0073e9SAndroid Build Coastguard Worker dtypes=[torch.float32], 924*da0073e9SAndroid Build Coastguard Worker ) 925*da0073e9SAndroid Build Coastguard Worker def test_peak_memory_foreach(self, device, dtype, optim_info): 926*da0073e9SAndroid Build Coastguard Worker nparams = 10 927*da0073e9SAndroid Build Coastguard Worker optim_inputs = optim_info.optim_inputs_func(device=device) 928*da0073e9SAndroid Build Coastguard Worker optim_cls = optim_info.optim_cls 929*da0073e9SAndroid Build Coastguard Worker for optim_input in optim_inputs: 930*da0073e9SAndroid Build Coastguard Worker kwargs = deepcopy(optim_input.kwargs) 931*da0073e9SAndroid Build Coastguard Worker max_mems = [] 932*da0073e9SAndroid Build Coastguard Worker for flag_value in (False, True): 933*da0073e9SAndroid Build Coastguard Worker kwargs["foreach"] = flag_value 934*da0073e9SAndroid Build Coastguard Worker # The 16 * 8 = 128 is critical here! Our CUDACachingAllocator allocates in blocks 935*da0073e9SAndroid Build Coastguard Worker # of 512, meaning any tensor that occupies <512 bytes of memory will allocate a 936*da0073e9SAndroid Build Coastguard Worker # whole 512 bytes anyway. 
We use 128 (cuz datasize would be 4 bytes) so that param 937*da0073e9SAndroid Build Coastguard Worker # is size 512 exactly, making our later calculations for intermediate_size easy. 938*da0073e9SAndroid Build Coastguard Worker param = torch.rand(16, 8, device=device, dtype=dtype) 939*da0073e9SAndroid Build Coastguard Worker params = [torch.rand_like(param) for _ in range(nparams)] 940*da0073e9SAndroid Build Coastguard Worker 941*da0073e9SAndroid Build Coastguard Worker optimizer = optim_cls(params, **kwargs) 942*da0073e9SAndroid Build Coastguard Worker 943*da0073e9SAndroid Build Coastguard Worker for p in params: 944*da0073e9SAndroid Build Coastguard Worker p.grad = torch.rand_like(p) 945*da0073e9SAndroid Build Coastguard Worker 946*da0073e9SAndroid Build Coastguard Worker optimizer.step() 947*da0073e9SAndroid Build Coastguard Worker import gc 948*da0073e9SAndroid Build Coastguard Worker 949*da0073e9SAndroid Build Coastguard Worker gc.collect() 950*da0073e9SAndroid Build Coastguard Worker torch.cuda.reset_peak_memory_stats() 951*da0073e9SAndroid Build Coastguard Worker optimizer.step() 952*da0073e9SAndroid Build Coastguard Worker gc.collect() 953*da0073e9SAndroid Build Coastguard Worker max_mems.append(torch.cuda.max_memory_allocated()) 954*da0073e9SAndroid Build Coastguard Worker 955*da0073e9SAndroid Build Coastguard Worker st_max_mem, mt_max_mem = max_mems 956*da0073e9SAndroid Build Coastguard Worker intermediate_size = nparams * param.nelement() * param.element_size() 957*da0073e9SAndroid Build Coastguard Worker nintermediates = 1 # we expect a budget of 1 intermediate most of the time 958*da0073e9SAndroid Build Coastguard Worker 959*da0073e9SAndroid Build Coastguard Worker # Check the param group directly to handle if the compiler set capturable 960*da0073e9SAndroid Build Coastguard Worker if optimizer.param_groups[0].get( 961*da0073e9SAndroid Build Coastguard Worker "capturable", False 962*da0073e9SAndroid Build Coastguard Worker ) or optim_cls.__name__ in ["Adadelta", "ASGD", "RAdam"]: 963*da0073e9SAndroid Build Coastguard Worker # with capturable in Adam(W), we have 2 extra intermediates for the bias_corrections 964*da0073e9SAndroid Build Coastguard Worker # with Adadelta, we have 2 extra for (acc_delta + eps) and (square_avg + eps) 965*da0073e9SAndroid Build Coastguard Worker # ASGD allocates axs, 2x mus, 2x etas, and grads at the same time 966*da0073e9SAndroid Build Coastguard Worker nintermediates = 3 967*da0073e9SAndroid Build Coastguard Worker if optim_cls.__name__ == "NAdam": 968*da0073e9SAndroid Build Coastguard Worker # with capturable in NAdam, we have 3 extra intermediates for the 969*da0073e9SAndroid Build Coastguard Worker # bias_correction, mus, and mu_nexts 970*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_TORCHDYNAMO: 971*da0073e9SAndroid Build Coastguard Worker # With dynamo, the eager/FX backend appears to hold memory longer than 972*da0073e9SAndroid Build Coastguard Worker # vanilla eager: https://github.com/pytorch/pytorch/issues/125511 973*da0073e9SAndroid Build Coastguard Worker nintermediates = 8 974*da0073e9SAndroid Build Coastguard Worker else: 975*da0073e9SAndroid Build Coastguard Worker nintermediates = 5 976*da0073e9SAndroid Build Coastguard Worker 977*da0073e9SAndroid Build Coastguard Worker if optim_cls.__name__ == "RAdam": 978*da0073e9SAndroid Build Coastguard Worker # RAdam has four intermediates with capturable 979*da0073e9SAndroid Build Coastguard Worker # num, unrect_step_size, buffer, grouped_grads 980*da0073e9SAndroid Build Coastguard 
Worker if TEST_WITH_TORCHDYNAMO: 981*da0073e9SAndroid Build Coastguard Worker # With dynamo, the eager/FX backend appears to hold memory longer than 982*da0073e9SAndroid Build Coastguard Worker # vanilla eager: https://github.com/pytorch/pytorch/issues/125511 983*da0073e9SAndroid Build Coastguard Worker nintermediates = 6 984*da0073e9SAndroid Build Coastguard Worker else: 985*da0073e9SAndroid Build Coastguard Worker nintermediates = 4 986*da0073e9SAndroid Build Coastguard Worker 987*da0073e9SAndroid Build Coastguard Worker elif optim_cls.__name__ in ["NAdam", "Adagrad", "RMSprop", "Adafactor"]: 988*da0073e9SAndroid Build Coastguard Worker # NAdam uses two intermediates at the same time (grads & exp_avg_sq_sqrt) 989*da0073e9SAndroid Build Coastguard Worker # Adagrad uses std and grads at the same time 990*da0073e9SAndroid Build Coastguard Worker # RMSprop uses avg and grads 991*da0073e9SAndroid Build Coastguard Worker # Adafactor uses row/col var and its mean 992*da0073e9SAndroid Build Coastguard Worker nintermediates = 2 993*da0073e9SAndroid Build Coastguard Worker 994*da0073e9SAndroid Build Coastguard Worker if optim_cls.__name__ == "Adafactor" and kwargs.get("maximize", False): 995*da0073e9SAndroid Build Coastguard Worker # When maximize is True, Adafactor also tracks device_grad 996*da0073e9SAndroid Build Coastguard Worker nintermediates = 3 997*da0073e9SAndroid Build Coastguard Worker 998*da0073e9SAndroid Build Coastguard Worker # Dynamo ST uses less mem than eager in the case of Adam/Adagrad/Nadam/RAdam 999*da0073e9SAndroid Build Coastguard Worker # which makes the foreach memory check fail 1000*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_TORCHDYNAMO: 1001*da0073e9SAndroid Build Coastguard Worker st_max_mem += 6000 1002*da0073e9SAndroid Build Coastguard Worker 1003*da0073e9SAndroid Build Coastguard Worker expected_max_mem = st_max_mem + intermediate_size * nintermediates 1004*da0073e9SAndroid Build Coastguard Worker # hipcc currently can't generate efficient code for the small buffer optimization 1005*da0073e9SAndroid Build Coastguard Worker # code path (see Note [small buffer optimization] for details), thus we always 1006*da0073e9SAndroid Build Coastguard Worker # dynamically allocate the tensor metadata for ROCM. Adjusting the expected max 1007*da0073e9SAndroid Build Coastguard Worker # memory usage to account for this.
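# [Illustrative sketch, not part of the test suite; the ROCm adjustment
# described above is applied immediately after this block.] A standalone
# version of the peak-memory budget check, assuming a CUDA device is
# available; the optimizer and sizes are arbitrary, and the 512-byte
# CUDACachingAllocator block size is the same fact the comment above relies on.
def _sketch_foreach_peak_memory_budget():
    import gc

    import torch
    from torch.optim import Adam

    params = [torch.rand(16, 8, device="cuda") for _ in range(10)]
    for p in params:
        p.grad = torch.rand_like(p)
    opt = Adam(params, foreach=True)
    opt.step()  # allocate optimizer state before measuring
    gc.collect()
    torch.cuda.reset_peak_memory_stats()
    opt.step()
    peak = torch.cuda.max_memory_allocated()
    # With foreach, the expectation is roughly "single-tensor peak plus about
    # one intermediate the size of all params" (10 * 512 bytes here).
    print(f"peak bytes during the measured foreach step: {peak}")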
1008*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_ROCM: 1009*da0073e9SAndroid Build Coastguard Worker expected_max_mem *= 1.02 1010*da0073e9SAndroid Build Coastguard Worker 1011*da0073e9SAndroid Build Coastguard Worker self.assertLessEqual(mt_max_mem, expected_max_mem) 1012*da0073e9SAndroid Build Coastguard Worker 1013*da0073e9SAndroid Build Coastguard Worker @optims( 1014*da0073e9SAndroid Build Coastguard Worker [optim for optim in optim_db if "fused" in optim.supported_impls], 1015*da0073e9SAndroid Build Coastguard Worker dtypes=floating_types_and( 1016*da0073e9SAndroid Build Coastguard Worker torch.bfloat16, 1017*da0073e9SAndroid Build Coastguard Worker torch.float16, 1018*da0073e9SAndroid Build Coastguard Worker ), 1019*da0073e9SAndroid Build Coastguard Worker ) 1020*da0073e9SAndroid Build Coastguard Worker def test_fused_matches_forloop(self, device, dtype, optim_info): 1021*da0073e9SAndroid Build Coastguard Worker if _get_device_type(device) not in optim_info.supports_fused_on: 1022*da0073e9SAndroid Build Coastguard Worker self.skipTest( 1023*da0073e9SAndroid Build Coastguard Worker f"{device} is not supported for fused on {optim_info.optim_cls.__name__}" 1024*da0073e9SAndroid Build Coastguard Worker ) 1025*da0073e9SAndroid Build Coastguard Worker if _get_device_type(device) == "mps" and dtype not in ( 1026*da0073e9SAndroid Build Coastguard Worker torch.float16, 1027*da0073e9SAndroid Build Coastguard Worker torch.float32, 1028*da0073e9SAndroid Build Coastguard Worker ): 1029*da0073e9SAndroid Build Coastguard Worker self.skipTest("MPS supports only torch.float16 and torch.float32") 1030*da0073e9SAndroid Build Coastguard Worker self._test_derived_optimizers(device, dtype, optim_info, "fused") 1031*da0073e9SAndroid Build Coastguard Worker 1032*da0073e9SAndroid Build Coastguard Worker @optims( 1033*da0073e9SAndroid Build Coastguard Worker [optim for optim in optim_db if "fused" in optim.supported_impls], 1034*da0073e9SAndroid Build Coastguard Worker dtypes=(torch.float32,), 1035*da0073e9SAndroid Build Coastguard Worker ) 1036*da0073e9SAndroid Build Coastguard Worker def test_fused_error_on_params_on_meta(self, device, dtype, optim_info): 1037*da0073e9SAndroid Build Coastguard Worker if _get_device_type(device) not in optim_info.supports_fused_on: 1038*da0073e9SAndroid Build Coastguard Worker self.skipTest( 1039*da0073e9SAndroid Build Coastguard Worker f"{device} is not supported for fused on {optim_info.optim_cls.__name__}" 1040*da0073e9SAndroid Build Coastguard Worker ) 1041*da0073e9SAndroid Build Coastguard Worker 1042*da0073e9SAndroid Build Coastguard Worker with torch.device("meta"): 1043*da0073e9SAndroid Build Coastguard Worker model = torch.nn.Sequential( 1044*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(2, 3), 1045*da0073e9SAndroid Build Coastguard Worker torch.nn.Sigmoid(), 1046*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(3, 1), 1047*da0073e9SAndroid Build Coastguard Worker torch.nn.Sigmoid(), 1048*da0073e9SAndroid Build Coastguard Worker ).to(dtype) 1049*da0073e9SAndroid Build Coastguard Worker 1050*da0073e9SAndroid Build Coastguard Worker optimizer = optim_info.optim_cls(model.parameters(), fused=True) 1051*da0073e9SAndroid Build Coastguard Worker with torch.device("meta"): 1052*da0073e9SAndroid Build Coastguard Worker for p in model.parameters(): 1053*da0073e9SAndroid Build Coastguard Worker p.grad = torch.rand_like(p) 1054*da0073e9SAndroid Build Coastguard Worker 1055*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 
1056*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1057*da0073e9SAndroid Build Coastguard Worker "`fused=True` requires all the params to be floating point Tensors", 1058*da0073e9SAndroid Build Coastguard Worker ): 1059*da0073e9SAndroid Build Coastguard Worker optimizer.step() 1060*da0073e9SAndroid Build Coastguard Worker 1061*da0073e9SAndroid Build Coastguard Worker optimizer.zero_grad(set_to_none=True) 1062*da0073e9SAndroid Build Coastguard Worker model.to_empty(device=device) 1063*da0073e9SAndroid Build Coastguard Worker for p in model.parameters(): 1064*da0073e9SAndroid Build Coastguard Worker p.grad = torch.rand_like(p) 1065*da0073e9SAndroid Build Coastguard Worker optimizer.step() 1066*da0073e9SAndroid Build Coastguard Worker 1067*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1068*da0073e9SAndroid Build Coastguard Worker @largeTensorTest("64GB") 1069*da0073e9SAndroid Build Coastguard Worker @optims( 1070*da0073e9SAndroid Build Coastguard Worker [optim for optim in optim_db if "fused" in optim.supported_impls], 1071*da0073e9SAndroid Build Coastguard Worker dtypes=[torch.float16], 1072*da0073e9SAndroid Build Coastguard Worker ) 1073*da0073e9SAndroid Build Coastguard Worker def test_fused_large_tensor(self, device, dtype, optim_info): 1074*da0073e9SAndroid Build Coastguard Worker if device not in optim_info.supports_fused_on: 1075*da0073e9SAndroid Build Coastguard Worker self.skipTest( 1076*da0073e9SAndroid Build Coastguard Worker f"{device} is not supported for fused on {optim_info.optim_cls.__name__}" 1077*da0073e9SAndroid Build Coastguard Worker ) 1078*da0073e9SAndroid Build Coastguard Worker optim_cls = optim_info.optim_cls 1079*da0073e9SAndroid Build Coastguard Worker optim_inputs = optim_info.optim_inputs_func(device=device) 1080*da0073e9SAndroid Build Coastguard Worker for optim_input in optim_inputs: 1081*da0073e9SAndroid Build Coastguard Worker params = [torch.ones(2**32, device=device, dtype=dtype)] 1082*da0073e9SAndroid Build Coastguard Worker params[0].grad = torch.zeros_like(params[0]) 1083*da0073e9SAndroid Build Coastguard Worker optimizer = optim_cls(params, fused=True, **optim_input.kwargs) 1084*da0073e9SAndroid Build Coastguard Worker optimizer.step() 1085*da0073e9SAndroid Build Coastguard Worker 1086*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1087*da0073e9SAndroid Build Coastguard Worker @optims( 1088*da0073e9SAndroid Build Coastguard Worker [optim for optim in optim_db if "fused" in optim.supported_impls], 1089*da0073e9SAndroid Build Coastguard Worker dtypes=[torch.float32], 1090*da0073e9SAndroid Build Coastguard Worker ) 1091*da0073e9SAndroid Build Coastguard Worker def test_fused_does_not_step_if_foundinf(self, device, dtype, optim_info): 1092*da0073e9SAndroid Build Coastguard Worker if device not in optim_info.supports_fused_on: 1093*da0073e9SAndroid Build Coastguard Worker self.skipTest( 1094*da0073e9SAndroid Build Coastguard Worker f"{device} is not supported for fused on {optim_info.optim_cls.__name__}" 1095*da0073e9SAndroid Build Coastguard Worker ) 1096*da0073e9SAndroid Build Coastguard Worker optim_cls = optim_info.optim_cls 1097*da0073e9SAndroid Build Coastguard Worker optim_inputs = optim_info.optim_inputs_func(device=device) 1098*da0073e9SAndroid Build Coastguard Worker num_params = 5 1099*da0073e9SAndroid Build Coastguard Worker for optim_input in optim_inputs: 1100*da0073e9SAndroid Build Coastguard Worker for no_grad_scale in (False, True): 1101*da0073e9SAndroid Build Coastguard Worker params = [ 1102*da0073e9SAndroid 
Build Coastguard Worker torch.ones((1,), device=device, dtype=dtype) 1103*da0073e9SAndroid Build Coastguard Worker for _ in range(num_params) 1104*da0073e9SAndroid Build Coastguard Worker ] 1105*da0073e9SAndroid Build Coastguard Worker params_c = [param.clone().detach() for param in params] 1106*da0073e9SAndroid Build Coastguard Worker for p in params: 1107*da0073e9SAndroid Build Coastguard Worker p.grad = torch.ones_like(p) 1108*da0073e9SAndroid Build Coastguard Worker optimizer = optim_cls(params, fused=True, **optim_input.kwargs) 1109*da0073e9SAndroid Build Coastguard Worker optimizer.grad_scale = ( 1110*da0073e9SAndroid Build Coastguard Worker None 1111*da0073e9SAndroid Build Coastguard Worker if no_grad_scale 1112*da0073e9SAndroid Build Coastguard Worker else torch.ones((1,), dtype=dtype, device=device) 1113*da0073e9SAndroid Build Coastguard Worker ) 1114*da0073e9SAndroid Build Coastguard Worker optimizer.found_inf = torch.ones((), dtype=dtype, device=device) 1115*da0073e9SAndroid Build Coastguard Worker optimizer.step() 1116*da0073e9SAndroid Build Coastguard Worker for p in params: 1117*da0073e9SAndroid Build Coastguard Worker if "step" in optimizer.state[p]: 1118*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1119*da0073e9SAndroid Build Coastguard Worker torch.zeros((), dtype=dtype, device=device), 1120*da0073e9SAndroid Build Coastguard Worker optimizer.state[p]["step"], 1121*da0073e9SAndroid Build Coastguard Worker ) 1122*da0073e9SAndroid Build Coastguard Worker self.assertEqual(params, params_c) 1123*da0073e9SAndroid Build Coastguard Worker 1124*da0073e9SAndroid Build Coastguard Worker @parametrize("impl", ["fused", "capturable"]) 1125*da0073e9SAndroid Build Coastguard Worker @optims( 1126*da0073e9SAndroid Build Coastguard Worker [optim for optim in optim_db if "fused" in optim.supported_impls], 1127*da0073e9SAndroid Build Coastguard Worker dtypes=[torch.float32], 1128*da0073e9SAndroid Build Coastguard Worker ) 1129*da0073e9SAndroid Build Coastguard Worker def test_cpu_load_state_dict(self, device, dtype, impl, optim_info): 1130*da0073e9SAndroid Build Coastguard Worker # NOTE: This SIMULATES a fused/capturable optimizer with state moved to CPU, issue 103256 1131*da0073e9SAndroid Build Coastguard Worker # How do we get there? Users typically create CUDA models with fused optimizers and then 1132*da0073e9SAndroid Build Coastguard Worker # save checkpoints; because CUDA memory is limited, they later reload them onto CPU with torch.load(...map_location="cpu"). 1133*da0073e9SAndroid Build Coastguard Worker # Since this is a unit test, it is more expedient to simulate what the state_dict 1134*da0073e9SAndroid Build Coastguard Worker # would look like, which is basically CPU tensors with fused/capturable flag = True.
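# [Illustrative sketch, not part of the test suite] The user-level round trip
# being simulated above, in plain code; it assumes a CUDA build where fused
# Adam is available, and the checkpoint path is arbitrary.
def _sketch_cpu_checkpoint_roundtrip(path="optim_ckpt.pt"):
    import torch

    param = torch.nn.Parameter(torch.rand(2, 2, device="cuda"))
    opt = torch.optim.Adam([param], fused=True)
    param.grad = torch.rand_like(param)
    opt.step()
    torch.save(opt.state_dict(), path)

    # Later, when GPU memory is tight: load the checkpoint onto CPU first ...
    state_dict_cpu = torch.load(path, map_location="cpu")
    fresh = torch.optim.Adam([param], fused=True)
    # ... load_state_dict then moves the loaded state back onto the params'
    # device, so the fused step keeps working.
    fresh.load_state_dict(state_dict_cpu)
    fresh.step()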
1135*da0073e9SAndroid Build Coastguard Worker optim_cls = optim_info.optim_cls 1136*da0073e9SAndroid Build Coastguard Worker opt_name = optim_cls.__name__ 1137*da0073e9SAndroid Build Coastguard Worker if opt_name in ("SGD", "Adagrad") and impl == "capturable": 1138*da0073e9SAndroid Build Coastguard Worker # Capturable SGD/Adagrad does not exist 1139*da0073e9SAndroid Build Coastguard Worker self.skipTest(f"{opt_name} does not currently support capturable") 1140*da0073e9SAndroid Build Coastguard Worker if _get_device_type(device) == "cpu": 1141*da0073e9SAndroid Build Coastguard Worker self.skipTest("Test is only for non-cpu devices") 1142*da0073e9SAndroid Build Coastguard Worker elif ( 1143*da0073e9SAndroid Build Coastguard Worker impl == "fused" 1144*da0073e9SAndroid Build Coastguard Worker and _get_device_type(device) not in optim_info.supports_fused_on 1145*da0073e9SAndroid Build Coastguard Worker ): 1146*da0073e9SAndroid Build Coastguard Worker self.skipTest(f"{device} is not supported for fused on {opt_name}") 1147*da0073e9SAndroid Build Coastguard Worker elif impl == "capturable" and _get_device_type(device) == "mps": 1148*da0073e9SAndroid Build Coastguard Worker self.skipTest("MPS does not support capturable") 1149*da0073e9SAndroid Build Coastguard Worker 1150*da0073e9SAndroid Build Coastguard Worker cpu_optim_inputs = optim_info.optim_inputs_func(device="cpu") 1151*da0073e9SAndroid Build Coastguard Worker for optim_input in cpu_optim_inputs: 1152*da0073e9SAndroid Build Coastguard Worker param = torch.tensor([0.1, 0.2], dtype=dtype, device="cpu") 1153*da0073e9SAndroid Build Coastguard Worker optimizer = optim_cls([param], **optim_input.kwargs) 1154*da0073e9SAndroid Build Coastguard Worker param.grad = torch.rand_like(param) 1155*da0073e9SAndroid Build Coastguard Worker optimizer.step() 1156*da0073e9SAndroid Build Coastguard Worker optim_state_dict_cpu = deepcopy(optimizer.state_dict()) 1157*da0073e9SAndroid Build Coastguard Worker optim_state_dict_cpu["param_groups"][0][impl] = True 1158*da0073e9SAndroid Build Coastguard Worker 1159*da0073e9SAndroid Build Coastguard Worker # load 1160*da0073e9SAndroid Build Coastguard Worker optim_input.kwargs[impl] = True 1161*da0073e9SAndroid Build Coastguard Worker param_device = param.clone().detach().to(device=device) 1162*da0073e9SAndroid Build Coastguard Worker optimizer_device = optim_cls([param_device], **optim_input.kwargs) 1163*da0073e9SAndroid Build Coastguard Worker optimizer_device.load_state_dict(optim_state_dict_cpu) 1164*da0073e9SAndroid Build Coastguard Worker optimizer_device.zero_grad() 1165*da0073e9SAndroid Build Coastguard Worker param_device.grad = torch.rand_like(param_device) 1166*da0073e9SAndroid Build Coastguard Worker optimizer_device.step() 1167*da0073e9SAndroid Build Coastguard Worker 1168*da0073e9SAndroid Build Coastguard Worker @optims(optim_db, dtypes=[torch.float32]) 1169*da0073e9SAndroid Build Coastguard Worker def test_param_groups_weight_decay(self, device, dtype, optim_info): 1170*da0073e9SAndroid Build Coastguard Worker optim_cls = optim_info.optim_cls 1171*da0073e9SAndroid Build Coastguard Worker # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490 1172*da0073e9SAndroid Build Coastguard Worker all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( 1173*da0073e9SAndroid Build Coastguard Worker device, dtype, optim_info, skip=("differentiable",) 1174*da0073e9SAndroid Build Coastguard Worker ) 1175*da0073e9SAndroid Build Coastguard Worker for optim_input in
all_optim_inputs: 1176*da0073e9SAndroid Build Coastguard Worker weight_kwargs = optim_input.kwargs 1177*da0073e9SAndroid Build Coastguard Worker bias_kwargs = deepcopy(optim_input.kwargs) 1178*da0073e9SAndroid Build Coastguard Worker bias_kwargs["weight_decay"] = 0.0 1179*da0073e9SAndroid Build Coastguard Worker 1180*da0073e9SAndroid Build Coastguard Worker weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype)) 1181*da0073e9SAndroid Build Coastguard Worker bias = Parameter(torch.randn((10), device=device, dtype=dtype)) 1182*da0073e9SAndroid Build Coastguard Worker input = torch.randn(5, device=device, dtype=dtype) 1183*da0073e9SAndroid Build Coastguard Worker 1184*da0073e9SAndroid Build Coastguard Worker optimizer = optim_cls( 1185*da0073e9SAndroid Build Coastguard Worker [ 1186*da0073e9SAndroid Build Coastguard Worker dict(params=[weight], **weight_kwargs), 1187*da0073e9SAndroid Build Coastguard Worker dict(params=[bias], **bias_kwargs), 1188*da0073e9SAndroid Build Coastguard Worker ] 1189*da0073e9SAndroid Build Coastguard Worker ) 1190*da0073e9SAndroid Build Coastguard Worker 1191*da0073e9SAndroid Build Coastguard Worker loss = (weight.mv(input) + bias).pow(2).sum() 1192*da0073e9SAndroid Build Coastguard Worker initial_value = loss.item() 1193*da0073e9SAndroid Build Coastguard Worker for _ in range(20): 1194*da0073e9SAndroid Build Coastguard Worker optimizer.zero_grad() 1195*da0073e9SAndroid Build Coastguard Worker loss = (weight.mv(input) + bias).pow(2).sum() 1196*da0073e9SAndroid Build Coastguard Worker loss.backward() 1197*da0073e9SAndroid Build Coastguard Worker if optim_info.only_supports_sparse_grads: 1198*da0073e9SAndroid Build Coastguard Worker # For this test, we naively convert the Tensor layout, which we know does 1199*da0073e9SAndroid Build Coastguard Worker # NOT represent the expected use case for optims like SparseAdam! 
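# [Illustrative sketch, not part of the test suite] What the naive layout
# conversion above amounts to for an optimizer that only accepts sparse
# gradients; SparseAdam plus an embedding is the usual example, and the sizes
# here are arbitrary.
def _sketch_sparse_grads():
    import torch

    emb = torch.nn.Embedding(10, 4, sparse=True)
    opt = torch.optim.SparseAdam(emb.parameters())
    emb(torch.tensor([1, 3, 3])).sum().backward()
    assert emb.weight.grad.is_sparse  # autograd already produced a sparse COO grad
    opt.step()

    # The test instead takes a dense grad and converts its layout directly,
    # which is what the next lines of the test do for weight and bias.
    dense = torch.rand(10, 4)
    assert dense.to_sparse().layout == torch.sparse_coo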
1200*da0073e9SAndroid Build Coastguard Worker weight.grad = weight.grad.to_sparse() 1201*da0073e9SAndroid Build Coastguard Worker bias.grad = bias.grad.to_sparse() 1202*da0073e9SAndroid Build Coastguard Worker optimizer.step() 1203*da0073e9SAndroid Build Coastguard Worker 1204*da0073e9SAndroid Build Coastguard Worker # Test that the direction of loss moved appropriately 1205*da0073e9SAndroid Build Coastguard Worker if optim_input.kwargs.get("maximize", False): 1206*da0073e9SAndroid Build Coastguard Worker self.assertGreater(loss.item(), initial_value) 1207*da0073e9SAndroid Build Coastguard Worker else: 1208*da0073e9SAndroid Build Coastguard Worker self.assertLess(loss.item(), initial_value) 1209*da0073e9SAndroid Build Coastguard Worker 1210*da0073e9SAndroid Build Coastguard Worker @optims(optim_db, dtypes=[torch.float32]) 1211*da0073e9SAndroid Build Coastguard Worker def test_param_groups_lr(self, device, dtype, optim_info): 1212*da0073e9SAndroid Build Coastguard Worker optim_cls = optim_info.optim_cls 1213*da0073e9SAndroid Build Coastguard Worker # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490 1214*da0073e9SAndroid Build Coastguard Worker all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( 1215*da0073e9SAndroid Build Coastguard Worker device, dtype, optim_info, skip=("differentiable",) 1216*da0073e9SAndroid Build Coastguard Worker ) 1217*da0073e9SAndroid Build Coastguard Worker for optim_input in all_optim_inputs: 1218*da0073e9SAndroid Build Coastguard Worker # optim_input.kwargs will be the param group kwargs, which should have >0 lr 1219*da0073e9SAndroid Build Coastguard Worker if "lr" not in optim_input.kwargs or optim_input.kwargs["lr"] == 0: 1220*da0073e9SAndroid Build Coastguard Worker optim_input.kwargs["lr"] = 1e-3 1221*da0073e9SAndroid Build Coastguard Worker outer_kwargs = {"lr": 1e-28} 1222*da0073e9SAndroid Build Coastguard Worker if optim_cls.__name__ == "Rprop": 1223*da0073e9SAndroid Build Coastguard Worker # Allow min step size to be 0 1224*da0073e9SAndroid Build Coastguard Worker outer_kwargs["step_sizes"] = (0, 50) 1225*da0073e9SAndroid Build Coastguard Worker 1226*da0073e9SAndroid Build Coastguard Worker weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype)) 1227*da0073e9SAndroid Build Coastguard Worker bias = Parameter(torch.randn((10), device=device, dtype=dtype)) 1228*da0073e9SAndroid Build Coastguard Worker irrelevant = Parameter(torch.randn(2, device=device, dtype=dtype)) 1229*da0073e9SAndroid Build Coastguard Worker irrelevant_clone = irrelevant.clone() 1230*da0073e9SAndroid Build Coastguard Worker input = torch.randn(5, device=device, dtype=dtype) 1231*da0073e9SAndroid Build Coastguard Worker optimizer = optim_cls( 1232*da0073e9SAndroid Build Coastguard Worker [ 1233*da0073e9SAndroid Build Coastguard Worker dict(params=[weight, bias], **optim_input.kwargs), 1234*da0073e9SAndroid Build Coastguard Worker dict(params=[irrelevant]), 1235*da0073e9SAndroid Build Coastguard Worker ], 1236*da0073e9SAndroid Build Coastguard Worker **outer_kwargs, 1237*da0073e9SAndroid Build Coastguard Worker ) 1238*da0073e9SAndroid Build Coastguard Worker 1239*da0073e9SAndroid Build Coastguard Worker loss = (weight.mv(input) + bias).pow(2).sum() 1240*da0073e9SAndroid Build Coastguard Worker initial_value = loss.item() 1241*da0073e9SAndroid Build Coastguard Worker for _ in range(20): 1242*da0073e9SAndroid Build Coastguard Worker optimizer.zero_grad() 1243*da0073e9SAndroid Build Coastguard Worker loss = 
(weight.mv(input) + bias).pow(2).sum() 1244*da0073e9SAndroid Build Coastguard Worker loss.backward() 1245*da0073e9SAndroid Build Coastguard Worker irrelevant.grad = torch.rand_like(irrelevant) 1246*da0073e9SAndroid Build Coastguard Worker if optim_info.only_supports_sparse_grads: 1247*da0073e9SAndroid Build Coastguard Worker # For this test, we naively convert the Tensor layout, which we know does 1248*da0073e9SAndroid Build Coastguard Worker # NOT represent the expected use case for optims like SparseAdam! 1249*da0073e9SAndroid Build Coastguard Worker weight.grad = weight.grad.to_sparse() 1250*da0073e9SAndroid Build Coastguard Worker bias.grad = bias.grad.to_sparse() 1251*da0073e9SAndroid Build Coastguard Worker irrelevant.grad = irrelevant.grad.to_sparse() 1252*da0073e9SAndroid Build Coastguard Worker optimizer.step() 1253*da0073e9SAndroid Build Coastguard Worker 1254*da0073e9SAndroid Build Coastguard Worker # Test that the direction of loss moved appropriately 1255*da0073e9SAndroid Build Coastguard Worker if optim_input.kwargs.get("maximize", False): 1256*da0073e9SAndroid Build Coastguard Worker self.assertGreater(loss.item(), initial_value) 1257*da0073e9SAndroid Build Coastguard Worker else: 1258*da0073e9SAndroid Build Coastguard Worker self.assertLess(loss.item(), initial_value) 1259*da0073e9SAndroid Build Coastguard Worker 1260*da0073e9SAndroid Build Coastguard Worker # Test that irrelevant parameters were not updated since lr was almost 0 1261*da0073e9SAndroid Build Coastguard Worker self.assertEqual(irrelevant, irrelevant_clone) 1262*da0073e9SAndroid Build Coastguard Worker 1263*da0073e9SAndroid Build Coastguard Worker @optims(optim_db, dtypes=[torch.float32]) 1264*da0073e9SAndroid Build Coastguard Worker def test_step_is_noop_when_params_have_no_grad(self, device, dtype, optim_info): 1265*da0073e9SAndroid Build Coastguard Worker optim_cls = optim_info.optim_cls 1266*da0073e9SAndroid Build Coastguard Worker all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( 1267*da0073e9SAndroid Build Coastguard Worker device, dtype, optim_info 1268*da0073e9SAndroid Build Coastguard Worker ) 1269*da0073e9SAndroid Build Coastguard Worker params = [ 1270*da0073e9SAndroid Build Coastguard Worker torch.randn(2, 3, requires_grad=False, device=device, dtype=dtype) 1271*da0073e9SAndroid Build Coastguard Worker for _ in range(2) 1272*da0073e9SAndroid Build Coastguard Worker ] 1273*da0073e9SAndroid Build Coastguard Worker old_params = [p.clone().detach() for p in params] 1274*da0073e9SAndroid Build Coastguard Worker 1275*da0073e9SAndroid Build Coastguard Worker def closure(): 1276*da0073e9SAndroid Build Coastguard Worker return torch.tensor([1], device=device, dtype=dtype) 1277*da0073e9SAndroid Build Coastguard Worker 1278*da0073e9SAndroid Build Coastguard Worker for optim_input in all_optim_inputs: 1279*da0073e9SAndroid Build Coastguard Worker optimizer = optim_cls(params, **optim_input.kwargs) 1280*da0073e9SAndroid Build Coastguard Worker optimizer.step(closure) 1281*da0073e9SAndroid Build Coastguard Worker 1282*da0073e9SAndroid Build Coastguard Worker @optims(optim_db, dtypes=[torch.float32]) 1283*da0073e9SAndroid Build Coastguard Worker def test_step_is_noop_for_zero_grads(self, device, dtype, optim_info): 1284*da0073e9SAndroid Build Coastguard Worker optim_cls = optim_info.optim_cls 1285*da0073e9SAndroid Build Coastguard Worker all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( 1286*da0073e9SAndroid Build Coastguard Worker device, dtype, optim_info 
1287*da0073e9SAndroid Build Coastguard Worker ) 1288*da0073e9SAndroid Build Coastguard Worker param = torch.randn((5, 1), device=device, dtype=dtype, requires_grad=True) 1289*da0073e9SAndroid Build Coastguard Worker old_param = param.clone().detach() 1290*da0073e9SAndroid Build Coastguard Worker 1291*da0073e9SAndroid Build Coastguard Worker def closure(): 1292*da0073e9SAndroid Build Coastguard Worker return torch.tensor([1], device=device, dtype=dtype) 1293*da0073e9SAndroid Build Coastguard Worker 1294*da0073e9SAndroid Build Coastguard Worker for optim_input in all_optim_inputs: 1295*da0073e9SAndroid Build Coastguard Worker kwargs = optim_input.kwargs 1296*da0073e9SAndroid Build Coastguard Worker 1297*da0073e9SAndroid Build Coastguard Worker # params will decay even if grads are empty if weight_decay != 0, 1298*da0073e9SAndroid Build Coastguard Worker # and capturable doesn't work for CPU tensors 1299*da0073e9SAndroid Build Coastguard Worker if kwargs.get("weight_decay", 0) != 0: 1300*da0073e9SAndroid Build Coastguard Worker continue 1301*da0073e9SAndroid Build Coastguard Worker 1302*da0073e9SAndroid Build Coastguard Worker # AdamW params will be updated regardless of grads due to lr, so make lr smaller 1303*da0073e9SAndroid Build Coastguard Worker if optim_cls.__name__ == "AdamW": 1304*da0073e9SAndroid Build Coastguard Worker kwargs["lr"] = ( 1305*da0073e9SAndroid Build Coastguard Worker torch.tensor(1e-5) 1306*da0073e9SAndroid Build Coastguard Worker if isinstance(kwargs.get("lr", 1e-5), torch.Tensor) 1307*da0073e9SAndroid Build Coastguard Worker else 1e-5 1308*da0073e9SAndroid Build Coastguard Worker ) 1309*da0073e9SAndroid Build Coastguard Worker 1310*da0073e9SAndroid Build Coastguard Worker if kwargs.get("differentiable", False): 1311*da0073e9SAndroid Build Coastguard Worker params = [param.clone()] 1312*da0073e9SAndroid Build Coastguard Worker else: 1313*da0073e9SAndroid Build Coastguard Worker params = [param] 1314*da0073e9SAndroid Build Coastguard Worker 1315*da0073e9SAndroid Build Coastguard Worker optimizer = optim_cls(params, **kwargs) 1316*da0073e9SAndroid Build Coastguard Worker if optim_info.only_supports_sparse_grads: 1317*da0073e9SAndroid Build Coastguard Worker # Intentionally construct a multidimensional empty v for the sparse grad 1318*da0073e9SAndroid Build Coastguard Worker # Single dim v passes the test while multidim correctly repros the issue 1319*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/82486 1320*da0073e9SAndroid Build Coastguard Worker i = torch.empty((1, 0), device=device, dtype=dtype) 1321*da0073e9SAndroid Build Coastguard Worker v = torch.empty((0, 1), device=device, dtype=dtype) 1322*da0073e9SAndroid Build Coastguard Worker params[0].grad = torch.sparse_coo_tensor( 1323*da0073e9SAndroid Build Coastguard Worker i, v, (5, 1), device=device, dtype=dtype 1324*da0073e9SAndroid Build Coastguard Worker ) 1325*da0073e9SAndroid Build Coastguard Worker else: 1326*da0073e9SAndroid Build Coastguard Worker params[0].grad = torch.zeros_like(params[0]) 1327*da0073e9SAndroid Build Coastguard Worker optimizer.step(closure) 1328*da0073e9SAndroid Build Coastguard Worker self.assertEqual(old_param, params[0]) 1329*da0073e9SAndroid Build Coastguard Worker 1330*da0073e9SAndroid Build Coastguard Worker @optims(optim_db, dtypes=[torch.float32]) 1331*da0073e9SAndroid Build Coastguard Worker def test_optimizer_can_be_printed(self, device, dtype, optim_info): 1332*da0073e9SAndroid Build Coastguard Worker optim_cls = optim_info.optim_cls 
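# [Illustrative sketch, not part of the test suite] A standalone version of
# the zero-grad no-op property exercised by the previous test; the empty,
# multidimensional sparse value mirrors the repro shape from
# https://github.com/pytorch/pytorch/issues/82486, and SparseAdam is an
# arbitrary choice of sparse-capable optimizer.
def _sketch_step_is_noop_for_empty_sparse_grad():
    import torch

    param = torch.randn(5, 1, requires_grad=True)
    before = param.detach().clone()
    indices = torch.empty((1, 0), dtype=torch.int64)
    values = torch.empty((0, 1))
    param.grad = torch.sparse_coo_tensor(indices, values, (5, 1))
    torch.optim.SparseAdam([param]).step()
    torch.testing.assert_close(param.detach(), before)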
1333*da0073e9SAndroid Build Coastguard Worker all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( 1334*da0073e9SAndroid Build Coastguard Worker device, dtype, optim_info 1335*da0073e9SAndroid Build Coastguard Worker ) 1336*da0073e9SAndroid Build Coastguard Worker params = [ 1337*da0073e9SAndroid Build Coastguard Worker Parameter(torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)) 1338*da0073e9SAndroid Build Coastguard Worker for _ in range(2) 1339*da0073e9SAndroid Build Coastguard Worker ] 1340*da0073e9SAndroid Build Coastguard Worker for optim_input in all_optim_inputs: 1341*da0073e9SAndroid Build Coastguard Worker optimizer = optim_cls(params, **optim_input.kwargs) 1342*da0073e9SAndroid Build Coastguard Worker optimizer.__repr__() 1343*da0073e9SAndroid Build Coastguard Worker 1344*da0073e9SAndroid Build Coastguard Worker @optims(optim_db, dtypes=[torch.float32]) 1345*da0073e9SAndroid Build Coastguard Worker def test_state_dict_deterministic(self, device, dtype, optim_info): 1346*da0073e9SAndroid Build Coastguard Worker optim_cls = optim_info.optim_cls 1347*da0073e9SAndroid Build Coastguard Worker 1348*da0073e9SAndroid Build Coastguard Worker # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490 1349*da0073e9SAndroid Build Coastguard Worker all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( 1350*da0073e9SAndroid Build Coastguard Worker device, dtype, optim_info, skip=("differentiable",) 1351*da0073e9SAndroid Build Coastguard Worker ) 1352*da0073e9SAndroid Build Coastguard Worker weight = Parameter( 1353*da0073e9SAndroid Build Coastguard Worker torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype) 1354*da0073e9SAndroid Build Coastguard Worker ) 1355*da0073e9SAndroid Build Coastguard Worker bias = Parameter(torch.randn(2, requires_grad=True, device=device, dtype=dtype)) 1356*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, requires_grad=True, device=device, dtype=dtype) 1357*da0073e9SAndroid Build Coastguard Worker params = [weight, bias] 1358*da0073e9SAndroid Build Coastguard Worker 1359*da0073e9SAndroid Build Coastguard Worker def fwd_bwd(optim, w, b, i): 1360*da0073e9SAndroid Build Coastguard Worker optim.zero_grad() 1361*da0073e9SAndroid Build Coastguard Worker loss = (w.mv(i) + b).pow(2).sum() 1362*da0073e9SAndroid Build Coastguard Worker loss.backward() 1363*da0073e9SAndroid Build Coastguard Worker if optim_info.only_supports_sparse_grads: 1364*da0073e9SAndroid Build Coastguard Worker if w.grad is not None: 1365*da0073e9SAndroid Build Coastguard Worker w.grad = w.grad.to_sparse() 1366*da0073e9SAndroid Build Coastguard Worker if b.grad is not None: 1367*da0073e9SAndroid Build Coastguard Worker b.grad = b.grad.to_sparse() 1368*da0073e9SAndroid Build Coastguard Worker return loss 1369*da0073e9SAndroid Build Coastguard Worker 1370*da0073e9SAndroid Build Coastguard Worker for optim_input in all_optim_inputs: 1371*da0073e9SAndroid Build Coastguard Worker optimizer = optim_cls(params, **optim_input.kwargs) 1372*da0073e9SAndroid Build Coastguard Worker closure = functools.partial(fwd_bwd, optimizer, weight, bias, input) 1373*da0073e9SAndroid Build Coastguard Worker 1374*da0073e9SAndroid Build Coastguard Worker # Prime the optimizer 1375*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 1376*da0073e9SAndroid Build Coastguard Worker if optim_info.step_requires_closure: 1377*da0073e9SAndroid Build Coastguard Worker optimizer.step(closure) 1378*da0073e9SAndroid Build 
Coastguard Worker else: 1379*da0073e9SAndroid Build Coastguard Worker closure() 1380*da0073e9SAndroid Build Coastguard Worker optimizer.step() 1381*da0073e9SAndroid Build Coastguard Worker 1382*da0073e9SAndroid Build Coastguard Worker # Clone the weights and construct a new optimizer for them 1383*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 1384*da0073e9SAndroid Build Coastguard Worker weight_c = Parameter(weight.clone()) 1385*da0073e9SAndroid Build Coastguard Worker bias_c = Parameter(bias.clone()) 1386*da0073e9SAndroid Build Coastguard Worker 1387*da0073e9SAndroid Build Coastguard Worker optimizer_c = optim_cls([weight_c, bias_c], **optim_input.kwargs) 1388*da0073e9SAndroid Build Coastguard Worker closure_c = functools.partial(fwd_bwd, optimizer_c, weight_c, bias_c, input) 1389*da0073e9SAndroid Build Coastguard Worker 1390*da0073e9SAndroid Build Coastguard Worker # Load the state dict from the original optimizer into the new one 1391*da0073e9SAndroid Build Coastguard Worker optimizer_c.load_state_dict(deepcopy(optimizer.state_dict())) 1392*da0073e9SAndroid Build Coastguard Worker 1393*da0073e9SAndroid Build Coastguard Worker # Run both optimizers in parallel 1394*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 1395*da0073e9SAndroid Build Coastguard Worker if optim_info.step_requires_closure: 1396*da0073e9SAndroid Build Coastguard Worker optimizer.step(closure) 1397*da0073e9SAndroid Build Coastguard Worker optimizer_c.step(closure_c) 1398*da0073e9SAndroid Build Coastguard Worker else: 1399*da0073e9SAndroid Build Coastguard Worker closure() 1400*da0073e9SAndroid Build Coastguard Worker closure_c() 1401*da0073e9SAndroid Build Coastguard Worker optimizer.step() 1402*da0073e9SAndroid Build Coastguard Worker optimizer_c.step() 1403*da0073e9SAndroid Build Coastguard Worker 1404*da0073e9SAndroid Build Coastguard Worker self.assertEqual(weight, weight_c) 1405*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bias, bias_c) 1406*da0073e9SAndroid Build Coastguard Worker 1407*da0073e9SAndroid Build Coastguard Worker # Make sure state dict is deterministic with equal (not identical) parameters 1408*da0073e9SAndroid Build Coastguard Worker self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict()) 1409*da0073e9SAndroid Build Coastguard Worker 1410*da0073e9SAndroid Build Coastguard Worker # Make sure repeated parameters have identical representation (see #36831) 1411*da0073e9SAndroid Build Coastguard Worker optimizer_c.param_groups.extend(optimizer_c.param_groups) 1412*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1413*da0073e9SAndroid Build Coastguard Worker optimizer.state_dict()["param_groups"][-1], 1414*da0073e9SAndroid Build Coastguard Worker optimizer_c.state_dict()["param_groups"][-1], 1415*da0073e9SAndroid Build Coastguard Worker ) 1416*da0073e9SAndroid Build Coastguard Worker 1417*da0073e9SAndroid Build Coastguard Worker @optims(optim_db, dtypes=[torch.float32]) 1418*da0073e9SAndroid Build Coastguard Worker def test_can_load_older_state_dict(self, device, dtype, optim_info): 1419*da0073e9SAndroid Build Coastguard Worker optim_cls = optim_info.optim_cls 1420*da0073e9SAndroid Build Coastguard Worker 1421*da0073e9SAndroid Build Coastguard Worker # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490 1422*da0073e9SAndroid Build Coastguard Worker all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( 1423*da0073e9SAndroid Build Coastguard Worker device, dtype, optim_info, 
skip=("differentiable",) 1424*da0073e9SAndroid Build Coastguard Worker ) 1425*da0073e9SAndroid Build Coastguard Worker for optim_input in all_optim_inputs: 1426*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(1) 1427*da0073e9SAndroid Build Coastguard Worker model = torch.nn.Sequential( 1428*da0073e9SAndroid Build Coastguard Worker torch.nn.Conv2d(4, 2, 1, stride=2), 1429*da0073e9SAndroid Build Coastguard Worker torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1), 1430*da0073e9SAndroid Build Coastguard Worker ) 1431*da0073e9SAndroid Build Coastguard Worker model.to(dtype=dtype, device=device) 1432*da0073e9SAndroid Build Coastguard Worker input = torch.rand(1, 4, 16, 16, device=device, dtype=dtype) 1433*da0073e9SAndroid Build Coastguard Worker optimizer = optim_cls(model.parameters(), **optim_input.kwargs) 1434*da0073e9SAndroid Build Coastguard Worker 1435*da0073e9SAndroid Build Coastguard Worker def fwd_bwd(optim, mod, i): 1436*da0073e9SAndroid Build Coastguard Worker optim.zero_grad() 1437*da0073e9SAndroid Build Coastguard Worker loss = mod(i).sum() 1438*da0073e9SAndroid Build Coastguard Worker loss.backward() 1439*da0073e9SAndroid Build Coastguard Worker return loss 1440*da0073e9SAndroid Build Coastguard Worker 1441*da0073e9SAndroid Build Coastguard Worker for _ in range(3): 1442*da0073e9SAndroid Build Coastguard Worker if optim_info.step_requires_closure: 1443*da0073e9SAndroid Build Coastguard Worker optimizer.step(functools.partial(fwd_bwd, optimizer, model, input)) 1444*da0073e9SAndroid Build Coastguard Worker else: 1445*da0073e9SAndroid Build Coastguard Worker fwd_bwd(optimizer, model, input) 1446*da0073e9SAndroid Build Coastguard Worker optimizer.step() 1447*da0073e9SAndroid Build Coastguard Worker 1448*da0073e9SAndroid Build Coastguard Worker # old_state_dict has all new flags del'd 1449*da0073e9SAndroid Build Coastguard Worker old_state_dict = deepcopy(optimizer.state_dict()) 1450*da0073e9SAndroid Build Coastguard Worker old_state_dict_pg = old_state_dict["param_groups"] 1451*da0073e9SAndroid Build Coastguard Worker for group in old_state_dict_pg: 1452*da0073e9SAndroid Build Coastguard Worker for flag in optim_info.not_og_supported_flags: 1453*da0073e9SAndroid Build Coastguard Worker if flag in group: 1454*da0073e9SAndroid Build Coastguard Worker del group[flag] 1455*da0073e9SAndroid Build Coastguard Worker 1456*da0073e9SAndroid Build Coastguard Worker optimizer.load_state_dict(old_state_dict) 1457*da0073e9SAndroid Build Coastguard Worker 1458*da0073e9SAndroid Build Coastguard Worker # Make sure we can still step 1459*da0073e9SAndroid Build Coastguard Worker if optim_info.step_requires_closure: 1460*da0073e9SAndroid Build Coastguard Worker optimizer.step(functools.partial(fwd_bwd, optimizer, model, input)) 1461*da0073e9SAndroid Build Coastguard Worker else: 1462*da0073e9SAndroid Build Coastguard Worker fwd_bwd(optimizer, model, input) 1463*da0073e9SAndroid Build Coastguard Worker optimizer.step() 1464*da0073e9SAndroid Build Coastguard Worker 1465*da0073e9SAndroid Build Coastguard Worker @optims(optim_db, dtypes=[torch.float32]) 1466*da0073e9SAndroid Build Coastguard Worker def test_save_load_equality_with_weights_only(self, device, dtype, optim_info): 1467*da0073e9SAndroid Build Coastguard Worker optim_cls = optim_info.optim_cls 1468*da0073e9SAndroid Build Coastguard Worker 1469*da0073e9SAndroid Build Coastguard Worker # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490 1470*da0073e9SAndroid Build Coastguard Worker all_optim_inputs = 

    @optims(optim_db, dtypes=[torch.float32])
    def test_save_load_equality_with_weights_only(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        weight = Parameter(
            torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)
        )
        bias = Parameter(torch.randn(2, requires_grad=True, device=device, dtype=dtype))
        input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
        params = [weight, bias]

        def fwd_bwd(optim, w, b, i):
            optim.zero_grad()
            loss = (w.mv(i) + b).pow(2).sum()
            loss.backward()
            if optim_info.only_supports_sparse_grads:
                weight.grad = weight.grad.to_sparse()
                bias.grad = bias.grad.to_sparse()
            return loss

        for optim_input in all_optim_inputs:
            optimizer = optim_cls(params, **optim_input.kwargs)
            closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)

            # Prime the optimizer
            for _ in range(3):
                optimizer.step(closure)

            sd = optimizer.state_dict()

            # === Check saved/loaded state_dict are the same (including weights_only load). ===
            with tempfile.TemporaryFile() as f:
                torch.save(sd, f)
                f.seek(0)
                sd_copy = torch.load(f)
                self.assertEqual(sd_copy, sd)
                del sd_copy
                f.seek(0)
                sd_copy_wo = torch.load(f, weights_only=True)
                self.assertEqual(sd_copy_wo, sd)
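
    # Note on the weights_only=True load above: it uses a restricted unpickler, so the round
    # trip is expected to work because an optimizer state_dict is built only from plain
    # containers, numbers, and tensors, roughly shaped like:
    #
    #     {"state": {0: {"step": tensor(3.), "exp_avg": tensor(...), ...}, ...},
    #      "param_groups": [{"lr": 0.001, "params": [0, 1], ...}]}
    #
    # (the exact keys depend on the optimizer; the ones shown are just an Adam-like example)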

    @optims(optim_db, dtypes=[torch.float32])
    def test_load_nontensor_step(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        params = [
            Parameter(torch.randn(2, 3, device=device, dtype=dtype)) for _ in range(2)
        ]
        for p in params:
            p.grad = torch.rand_like(p)
            if optim_info.only_supports_sparse_grads:
                # For this test, we naively convert the Tensor layout, which we know does
                # NOT represent the expected use case for optims like SparseAdam!
                p.grad = p.grad.to_sparse()

        # Needed for second order optims like LBFGS
        closure_loss = torch.rand(1, device=device, dtype=dtype)

        def closure():
            return closure_loss if optim_info.step_requires_closure else None

        for optim_input in all_optim_inputs:
            kwargs = optim_input.kwargs
            optimizer = optim_cls(params, **optim_input.kwargs)
            for _ in range(3):
                optimizer.step(closure)
            state_dict = deepcopy(optimizer.state_dict())
            for p_state in state_dict["state"].values():
                if "step" in p_state and torch.is_tensor(p_state["step"]):
                    p_state["step"] = p_state["step"].item()
            optimizer.load_state_dict(state_dict)
            optimizer.step(closure)

    @onlyCUDA
    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_with_cuda_params(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        # We limit our configs to CPU only, because we will be moving them to CUDA later
        cpu_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            "cpu", dtype, optim_info, skip=("differentiable",)
        )

        # Needed for second order optims like LBFGS
        closure_loss = torch.rand(1, device=device, dtype=dtype)

        def closure():
            return closure_loss if optim_info.step_requires_closure else None

        for optim_input in cpu_optim_inputs:
            if (
                "fused" in optim_input.kwargs
                and "cuda" not in optim_info.supports_fused_on
            ):
                self.skipTest(
                    f"cuda is not supported for fused on {optim_cls.__name__}"
                )
            params = [
                Parameter(torch.randn(2, 3, device="cpu", dtype=dtype))
                for _ in range(2)
            ]
            for p in params:
                p.grad = torch.randn_like(p)
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    p.grad = p.grad.to_sparse()

            optimizer = optim_cls(params, **optim_input.kwargs)

            for _ in range(3):
                optimizer.step(closure)

            with torch.no_grad():
                params_cuda = [p.to(device="cuda") for p in params]
                for i, p in enumerate(params_cuda):
                    p.grad = params[i].grad.to(device="cuda")
            optimizer_cuda = optim_cls(params_cuda, **optim_input.kwargs)

            state_dict_cpu = deepcopy(optimizer.state_dict())
            state_dict_cuda = deepcopy(optimizer.state_dict())
            optimizer_cuda.load_state_dict(state_dict_cuda)

            # Make sure state_dict_cuda isn't modified by merely calling load_state_dict
            self.assertEqual(state_dict_cpu, state_dict_cuda)

            # Make sure that the device of state['step'] is still CPU _unless_ capturable or
            # fused is on (e.g. torch.compile() may have flipped on capturable)
            capturable = state_dict_cpu["param_groups"][0].get("capturable", False)
            fused = state_dict_cpu["param_groups"][0].get("fused", False)
            new_state_dict = optimizer_cuda.state_dict()
            for state_cpu, state_cuda in zip(
                state_dict_cpu["state"].values(), new_state_dict["state"].values()
            ):
                if "step" in state_cpu and torch.is_tensor(state_cpu["step"]):
                    self.assertEqual(
                        state_cuda["step"].device.type,
                        "cuda" if capturable or fused else "cpu",
                    )

            for _ in range(5):
                optimizer.step(closure)
                optimizer_cuda.step(closure)
                self.assertEqual(params, params_cuda)
                self.assertEqual(optimizer.state_dict(), optimizer_cuda.state_dict())
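
    # Behavior assumed by the test above, sketched with hypothetical names:
    #
    #     sd = cpu_optimizer.state_dict()
    #     cuda_optimizer.load_state_dict(sd)
    #     cuda_optimizer.state_dict()["state"][0]["step"].device.type
    #     # -> "cpu" normally, "cuda" when capturable or fused is set, since those variants
    #     #    keep "step" on the params' device so the whole update can stay on the GPU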

    @staticmethod
    def _state_dict_pre_hook(optimizer: Optimizer) -> None:
        optimizer.state["test"] = 1

    @staticmethod
    def _state_dict_post_hook(
        optimizer: Optimizer, state_dict: Dict[str, Any]
    ) -> Dict[str, Any]:
        if "test" in state_dict["state"]:
            state_dict["state"].pop("test")
            state_dict["ran_state_dict_pre_hook"] = True
        else:
            state_dict["ran_state_dict_pre_hook"] = False
        return state_dict

    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_pre_hook(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)
            optim.register_state_dict_pre_hook(self.__class__._state_dict_pre_hook)
            state_dict = optim.state_dict()
            self.assertEqual(state_dict["state"]["test"], 1)

    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_post_hook(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)
            optim.register_state_dict_post_hook(self.__class__._state_dict_post_hook)
            state_dict = optim.state_dict()
            self.assertFalse(state_dict["ran_state_dict_pre_hook"])

    @optims(optim_db, dtypes=[torch.float32])
    def test_state_dict_pre_post_hook(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)
            optim.register_state_dict_pre_hook(self.__class__._state_dict_pre_hook)
            optim.register_state_dict_post_hook(self.__class__._state_dict_post_hook)
            state_dict = optim.state_dict()
            self.assertFalse("test" in state_dict["state"])
            self.assertTrue(state_dict["ran_state_dict_pre_hook"])

    @staticmethod
    def _load_state_dict_pre_hook1(
        optimizer: Optimizer, state_dict: Dict[str, Any]
    ) -> None:
        state_dict["param_groups"][0]["lr"] = 0.002

    @staticmethod
    def _load_state_dict_pre_hook2(
        optimizer: Optimizer, state_dict: Dict[str, Any]
    ) -> Dict[str, Any]:
        # The typical use case for returning a state dict is to drastically modify it.
        # We simulate that here by making a deep copy and ensuring that my_state_dict,
        # not the original, is what actually gets used downstream.
        my_state_dict = deepcopy(state_dict)
        my_state_dict["param_groups"][0]["lr"] = 0.003
        return my_state_dict

    @staticmethod
    def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
        optimizer.state["ran_load_state_dict_pre_hook2"] = (
            optimizer.param_groups[0]["lr"] == 0.003
        )
        optimizer.state["ran_load_state_dict_post_hook"] = True

    @optims(optim_db, dtypes=[torch.float32])
    def test_load_state_dict_pre_hook_and_prepend(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)
            state_dict = optim.state_dict()

            # Usually one would load into a fresh optimizer instance, but reusing the
            # same one is equivalent for this test
            optim.register_load_state_dict_pre_hook(
                self.__class__._load_state_dict_pre_hook1
            )
            optim.load_state_dict(state_dict)
            self.assertEqual(optim.param_groups[0]["lr"], 0.002)

            optim.register_load_state_dict_pre_hook(
                self.__class__._load_state_dict_pre_hook2, prepend=True
            )
            optim.load_state_dict(state_dict)
            # If prepend were False, lr would end up as 0.003; since prepend is True,
            # hook2 runs first and hook1 then overrides the lr back to 0.002
            self.assertEqual(optim.param_groups[0]["lr"], 0.002)

    @optims(optim_db, dtypes=[torch.float32])
    def test_load_state_dict_post_hook(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)

            optim.register_load_state_dict_post_hook(
                self.__class__._load_state_dict_post_hook
            )
            optim.load_state_dict(optim.state_dict())
            self.assertFalse(optim.state["ran_load_state_dict_pre_hook2"])
            self.assertTrue(optim.state["ran_load_state_dict_post_hook"])

    @optims(optim_db, dtypes=[torch.float32])
    def test_load_state_dict_pre_post_hook(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
            optim = optim_cls([param], **optim_input.kwargs)

            optim.register_load_state_dict_pre_hook(
                self.__class__._load_state_dict_pre_hook2
            )
            optim.register_load_state_dict_post_hook(
                self.__class__._load_state_dict_post_hook
            )
            optim.load_state_dict(optim.state_dict())
            self.assertTrue(optim.state["ran_load_state_dict_pre_hook2"])
            self.assertTrue(optim.state["ran_load_state_dict_post_hook"])
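
    # Illustrative-only sketch of the load_state_dict pre-hook contract relied on above
    # (the names cast_lr/opt are made up): a pre-hook receives (optimizer, state_dict) and may
    # either mutate the dict in place or return a new dict, in which case the returned dict is
    # the one that gets loaded.
    #
    #     def cast_lr(optimizer, state_dict):
    #         state_dict = deepcopy(state_dict)
    #         state_dict["param_groups"][0]["lr"] = 0.01
    #         return state_dict
    #
    #     handle = opt.register_load_state_dict_pre_hook(cast_lr)
    #     opt.load_state_dict(opt.state_dict())  # opt.param_groups[0]["lr"] is now 0.01
    #     handle.remove()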

    @optims(optim_db, dtypes=[torch.float32])
    def test_step_post_hook(self, device, dtype, optim_info):
        def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data += 2

        params = [torch.tensor([1, 1], device=device, dtype=dtype)]

        def dummy_closure():
            return 1

        closure = dummy_closure if optim_info.step_requires_closure else None

        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            optim = optim_info.optim_cls(params, **optim_input.kwargs)
            data = 2
            hook_handle = optim.register_step_post_hook(post_hook)

            optim.step(closure)
            optim.step(closure)
            # check that the post hook ran on both steps
            self.assertEqual(data, 6)

            # remove the handle, take a step, and verify that the hook no longer runs
            hook_handle.remove()

            optim.step(closure)
            self.assertEqual(data, 6)

    @optims(optim_db, dtypes=[torch.float32])
    def test_step_pre_hook(self, device, dtype, optim_info):
        def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data += 2

        params = [torch.tensor([1, 1], device=device, dtype=dtype)]

        def dummy_closure():
            return 1

        closure = dummy_closure if optim_info.step_requires_closure else None

        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            optim = optim_info.optim_cls(params, **optim_input.kwargs)
            data = 5
            hook_handle = optim.register_step_pre_hook(pre_hook)

            optim.step(closure)
            optim.step(closure)
            # check that the pre hook ran on both steps
            self.assertEqual(data, 9)

            # remove the handle, take a step, and verify that the hook no longer runs
            hook_handle.remove()

            optim.step(closure)
            self.assertEqual(data, 9)

    @optims(optim_db, dtypes=[torch.float32])
    def test_step_all_hooks(self, device, dtype, optim_info):
        def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data.append(0)

        def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data.append(5)

        def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data.append(1)

        def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data.append(2)

        params = [torch.tensor([1, 1], device=device, dtype=dtype)]

        def dummy_closure():
            return 1

        closure = dummy_closure if optim_info.step_requires_closure else None

        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info
        )
        for optim_input in all_optim_inputs:
            optim = optim_info.optim_cls(params, **optim_input.kwargs)
            optim2 = SGD(params)
            data = []

            # register global hooks to both optimizers
            global_pre_handle = register_optimizer_step_pre_hook(global_pre_hook)
            global_post_handle = register_optimizer_step_post_hook(global_post_hook)

            # register local hooks
            first_pre_handle = optim.register_step_pre_hook(local_pre_hook)
            first_post_handle = optim.register_step_post_hook(local_post_hook)
            second_pre_handle = optim2.register_step_pre_hook(local_pre_hook)
            second_post_handle = optim2.register_step_post_hook(local_post_hook)

            optim.step(closure)
            self.assertListEqual(data, [0, 1, 2, 5])
            optim2.step(closure)
            self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5])
            optim.step(closure)
            self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])

            # remove all hooks
            global_pre_handle.remove()
            global_post_handle.remove()
            first_pre_handle.remove()
            first_post_handle.remove()
            second_pre_handle.remove()
            second_post_handle.remove()

            optim.step(closure)
            optim2.step(closure)
            self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
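
    # Ordering demonstrated by test_step_all_hooks above, shown as a minimal sketch (opt is a
    # stand-in for any optimizer instance): global pre-hooks fire before local pre-hooks, and
    # local post-hooks fire before global post-hooks.
    #
    #     g = register_optimizer_step_pre_hook(lambda o, args, kwargs: print("global pre"))
    #     l = opt.register_step_pre_hook(lambda o, args, kwargs: print("local pre"))
    #     opt.step()  # prints "global pre" then "local pre"
    #     g.remove(); l.remove()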

    @optims(optim_db, dtypes=[torch.float32])
    def test_deepcopy_copies_all_public_attrs(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )

        params = [
            Parameter(torch.randn(2, 3, device=device, dtype=dtype)) for _ in range(2)
        ]
        for p in params:
            p.grad = torch.rand_like(p)
            if optim_info.only_supports_sparse_grads:
                # For this test, we naively convert the Tensor layout, which we know does
                # NOT represent the expected use case for optims like SparseAdam!
                p.grad = p.grad.to_sparse()

        # Needed for second order optims like LBFGS
        def closure():
            return 1 if optim_info.step_requires_closure else None

        def getPublicAttrs(obj):
            return {k for k in obj.__dict__ if not k.startswith("_")}

        for optim_input in all_optim_inputs:
            optimizer = optim_cls(params, **optim_input.kwargs)

            # Make some state
            for _ in range(3):
                if optim_info.step_requires_closure:
                    optimizer.step(closure)
                else:
                    closure()
                    optimizer.step()

            self.assertEqual(
                getPublicAttrs(optimizer), getPublicAttrs(deepcopy(optimizer))
            )

    @optims(
        [optim for optim in optim_db if optim.step_requires_closure],
        dtypes=[torch.float32],
    )
    def test_second_order_optims_return_consistent_types(
        self, device, dtype, optim_info
    ):
        # Motivated by #7586
        optim_cls = optim_info.optim_cls
        params = [
            torch.randn(10, 5, device=device, dtype=dtype),
            torch.randn(10, device=device, dtype=dtype),
        ]

        def closure():
            return torch.tensor([10], device=device, dtype=dtype)

        for optim_input in optim_info.optim_inputs_func(device=device):
            # Currently, the only second order optim is LBFGS, so we just go ahead and modify
            # "tolerance_grad", but this may not scale if we add second order optims in the future
            kwargs = optim_input.kwargs
            kwargs["tolerance_grad"] = math.inf
            optim_inf = optim_cls(params, **kwargs)
            kwargs["tolerance_grad"] = -math.inf
            optim_neg_inf = optim_cls(params, **kwargs)

            res1 = optim_inf.step(closure)
            res2 = optim_neg_inf.step(closure)
            self.assertEqual(type(res1), type(res2))

    @onlyCUDA
    @optims(
        [
            optim
            for optim in optim_db
            if "cpu" in optim.supports_fused_on and "cuda" in optim.supports_fused_on
        ],
        dtypes=floating_types_and(
            torch.bfloat16,
            torch.float16,
        ),
    )
    def test_fused_cpu_matches_cuda(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        optim_inputs = optim_info.optim_inputs_func(device="cpu")
        for optim_input in optim_inputs:
            inpts, models, optimizers = [], [], []
            for dev in ("cpu", "cuda"):
                kwargs = optim_input.kwargs
                kwargs["fused"] = True
                inpt = torch.tensor(
                    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=dev
                ).reshape(3, 2)

                torch.manual_seed(1)
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Sigmoid(),
                    torch.nn.Linear(3, 1),
                    torch.nn.Sigmoid(),
                )
                model.to(dtype=dtype, device=dev)

                # foreach/fused optimizers should be tested with a
                # zero-size tensor as the last param
                # ref: https://github.com/pytorch/pytorch/issues/100701
                empty_param = torch.empty(
                    (), device=dev, dtype=dtype, requires_grad=True
                )
                empty_param.grad = torch.rand_like(empty_param)
                params = list(model.parameters()) + [empty_param]

                optimizer = optim_cls(params, **kwargs)
                inpts.append(inpt)
                models.append(model)
                optimizers.append(optimizer)
            self._compare_between(inpts, models, optimizers)

    @onlyCUDA
    @optims(
        [
            o
            for o in optim_db
            if ("foreach" in o.supported_impls and o.optim_cls.__name__ != "Adafactor")
        ],
        dtypes=[torch.float32],
    )
    def test_defaults_changed_to_foreach(self, device, dtype, optim_info):
        # Test that the default implementations for optimizers are changed to foreach
        # except Adafactor, which defaults to the single tensor impl for memory efficiency.
        optim_cls = optim_info.optim_cls
        model = torch.nn.Linear(5, 5)
        model.to(dtype=dtype, device=device)
        inpt = torch.rand(2, 5, dtype=dtype, device=device)

        import inspect

        module = inspect.getmodule(optim_cls)

        for optim_input in optim_info.optim_inputs_func(device=device):
            optim = optim_cls(model.parameters(), **optim_input.kwargs)
            optim.zero_grad()
            output = model(inpt)
            loss = output.sum()
            loss.backward()
            with patch.object(
                module, f"_multi_tensor_{optim_cls.__name__.lower()}"
            ) as mocked_foreach_impl:
                optim.step()
                self.assertTrue(mocked_foreach_impl.called)

    @optims(optim_db, dtypes=[torch.float32])
    def test_non_empty_state(self, device, dtype, optim_info):
        # There are internal tests that check that the state is not empty
        optim_cls = optim_info.optim_cls
        model = torch.nn.Linear(5, 5)
        model.to(dtype=dtype, device=device)
        inpt = torch.rand(2, 5, dtype=dtype, device=device)

        for optim_input in optim_info.optim_inputs_func(device=device):
            optim = optim_cls(model.parameters(), **optim_input.kwargs)
            optim.zero_grad()
            output = model(inpt)
            loss = output.sum()
            loss.backward()

            if optim_info.only_supports_sparse_grads:
                for param in model.parameters():
                    if param.grad is not None:
                        param.grad = param.grad.to_sparse()

            if optim_info.step_requires_closure:
                optim.step(lambda: 1.0)
            else:
                optim.step()

            for state in optim.state.values():
                self.assertGreater(len(state), 0)


instantiate_device_type_tests(TestOptimRenewed, globals(), allow_mps=True)


if __name__ == "__main__":
    run_tests()