# Owner(s): ["module: dynamo"]
"""
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_adam in OptimizerTests)
"""
import functools

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch.nn import Parameter


# Deliberately tiny optimizer used by test_init_group below: step() adds 1 to the
# first parameter, or subtracts 1 when any parameter is complex, so eager and
# compiled runs can be checked with a plain equality assertion.
class MyOptimizer(torch.optim.Optimizer):
    def __init__(self, params):
        super().__init__(params, {})

    def _init_group(self, params, group):
        any_complex = False
        for p in group["params"]:
            params.append(p)
            any_complex |= p.is_complex()
        return any_complex

    def step(self):
        for group in self.param_groups:
            params = []
            any_complex = self._init_group(params, group)
            if any_complex:
                params[0] -= 1
            else:
                params[0] += 1


class End2EndTests(torch._dynamo.test_case.TestCase):
    # https://github.com/pytorch/torchdynamo/issues/1604
    def test_optimizing_over_tensor_with_requires_grad(self):
        class Net(torch.nn.Module):
            def forward(self, x, y):
                z = torch.bmm(x, y)
                z = torch.flatten(z, 1)
                return z

        def training_iter_fn(batch, model, optimizer):
            optimizer.zero_grad()
            out = model(**batch)
            target = torch.tensor([0, 7])
            loss = torch.nn.CrossEntropyLoss()(out, target)
            loss.backward()
            optimizer.step()
            return loss

        net = Net()
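        # Note: Adam optimizes the graph input `input2` directly rather than any
        # module parameters; bmm((2, 1, 4), (2, 4, 8)) -> (2, 1, 8), flattened to
        # (2, 8) logits for CrossEntropyLoss over the two targets.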
        input1 = torch.randn(2, 1, 4)
        input2 = torch.randn(2, 4, 8, requires_grad=True)
        optimizer = torch.optim.Adam([input2], lr=0.1)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_training_iter_fn = torch._dynamo.optimize(cnts)(training_iter_fn)
        batch = {"x": input1, "y": input2}
        for _ in range(2):
            opt_training_iter_fn(batch, net, optimizer)
        # Both iterations should reuse the compiled code: exactly two frames are
        # compiled in total, and the count does not grow on the second pass.
        self.assertEqual(cnts.frame_count, 2)

    def test_state_dict(self):
        @torch.compile(backend="eager")
        def _test_state_dict(weight, bias, input):
            def fn_base(optimizer, weight, bias):
                optimizer.zero_grad()
                i = input
                loss = (weight.mv(i) + bias).pow(2).sum()
                loss.backward()
                return loss

            optimizer = torch.optim.Adagrad([weight, bias])
            fn = functools.partial(fn_base, optimizer, weight, bias)
            return optimizer, fn

        optimizer, fn = _test_state_dict(
            Parameter(torch.randn(10, 5)),
            Parameter(torch.randn(10)),
            torch.randn(5, requires_grad=True),
        )
        # step() accepts an optional closure that re-evaluates the loss before
        # applying the update.
        optimizer.step(fn)
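
    # Compare MyOptimizer.step() run eagerly against the same step compiled with
    # fullgraph=True, which turns any graph break inside the custom
    # _init_group/step into a hard error, for both a real and a complex parameter.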
    def test_init_group(self):
        for dtype in [torch.float32, torch.cfloat]:
            tensor = torch.randn(5, 5, dtype=dtype)
            params = Parameter(tensor.detach().clone(), requires_grad=False)
            opt_params = Parameter(tensor.detach().clone(), requires_grad=False)

            optim = MyOptimizer([params])
            optim.step()

            opt_optim = MyOptimizer([opt_params])
            opt_step = torch.compile(backend="eager", fullgraph=True)(opt_optim.step)
            opt_step()

            self.assertEqual(params, opt_params)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()