1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerfrom typing import List 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerimport torch 6*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import skipIfTorchDynamo 7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo() 11*da0073e9SAndroid Build Coastguard Workerclass TestAutodiffJit(JitTestCase): 12*da0073e9SAndroid Build Coastguard Worker def test_undefined_tensor_lists(self): 13*da0073e9SAndroid Build Coastguard Worker def fn(tensor_list: List[torch.Tensor], add_tensor): 14*da0073e9SAndroid Build Coastguard Worker cat = torch.cat(tensor_list, dim=1) 15*da0073e9SAndroid Build Coastguard Worker r = torch.sin(cat + add_tensor) 16*da0073e9SAndroid Build Coastguard Worker return r 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Worker fn_s = torch.jit.script(fn) 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Worker a = torch.rand((3, 6), requires_grad=True) 21*da0073e9SAndroid Build Coastguard Worker b = torch.rand((3, 10), requires_grad=True) 22*da0073e9SAndroid Build Coastguard Worker x = [a, b] 23*da0073e9SAndroid Build Coastguard Worker y = torch.rand((3, 16), requires_grad=True) 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker ret = fn_s(x, y) 26*da0073e9SAndroid Build Coastguard Worker ret.sum().backward() 27*da0073e9SAndroid Build Coastguard Worker ret = fn_s(x, y) 28*da0073e9SAndroid Build Coastguard Worker ret.sum().backward() 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Worker ret = fn_s(x, y) 31*da0073e9SAndroid Build Coastguard Worker s = ret.sum() 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Worker # backward_fn expects 2 inputs: (grad_output, current_grad_r) 34*da0073e9SAndroid Build Coastguard Worker # current_grad_r is provided because we need to add this contribution 35*da0073e9SAndroid Build Coastguard Worker # to grad_r when we return it. 36*da0073e9SAndroid Build Coastguard Worker backward_fn = s.grad_fn.next_functions[0][0] 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker # check behavior with defined tensor 39*da0073e9SAndroid Build Coastguard Worker grad_out = torch.rand((3, 16)) 40*da0073e9SAndroid Build Coastguard Worker grad_inputs = backward_fn(grad_out, None) 41*da0073e9SAndroid Build Coastguard Worker 42*da0073e9SAndroid Build Coastguard Worker # expect 3 tensors: grad_y, grad_a, grad_b 43*da0073e9SAndroid Build Coastguard Worker self.assertEqual(3, len(grad_inputs)) 44*da0073e9SAndroid Build Coastguard Worker for x in grad_inputs: 45*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(x, torch.Tensor)) 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Worker # now test with undefined grad_out 48*da0073e9SAndroid Build Coastguard Worker grad_inputs = backward_fn(None, None) 49*da0073e9SAndroid Build Coastguard Worker 50*da0073e9SAndroid Build Coastguard Worker # expect all of them to be None 51*da0073e9SAndroid Build Coastguard Worker self.assertEqual(3, len(grad_inputs)) 52*da0073e9SAndroid Build Coastguard Worker for x in grad_inputs: 53*da0073e9SAndroid Build Coastguard Worker if x is not None: 54*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, torch.max(torch.abs(x)).item()) 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Worker def test_requires_grad_outputs(self): 57*da0073e9SAndroid Build Coastguard Worker # outputs should require_grad only if eager outputs would require_grad. 58*da0073e9SAndroid Build Coastguard Worker def fn(a, b, c): 59*da0073e9SAndroid Build Coastguard Worker return a.relu() + b.relu(), c.relu() 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Worker a = torch.rand((10, 10), requires_grad=False) 62*da0073e9SAndroid Build Coastguard Worker b = torch.rand((10, 10), requires_grad=False) 63*da0073e9SAndroid Build Coastguard Worker c = torch.rand((10, 10), requires_grad=True) 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker fn_s = torch.jit.script(fn) 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Worker for i in range(4): 68*da0073e9SAndroid Build Coastguard Worker x, y = fn_s(a, b, c) 69*da0073e9SAndroid Build Coastguard Worker self.assertFalse(x.requires_grad) 70*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y.requires_grad) 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Worker def test_requires_grad_outputs_profiled_twice(self): 73*da0073e9SAndroid Build Coastguard Worker # the value "r" is used twice, by gammaln and by entr, so it is profiled twice. 74*da0073e9SAndroid Build Coastguard Worker # So during autodiff graph formation the profile nodes are unmerged because 75*da0073e9SAndroid Build Coastguard Worker # they are aliasing. Then the DifferentiableGraph doesn't have a profile 76*da0073e9SAndroid Build Coastguard Worker # node on the output. The requires_grad info should then be added onto the 77*da0073e9SAndroid Build Coastguard Worker # output value (otherwise autodiff will make the output require_grad). 78*da0073e9SAndroid Build Coastguard Worker # Note: this relies on gammaln and entr not having autodiff implementations. 79*da0073e9SAndroid Build Coastguard Worker def fn(a, b, c): 80*da0073e9SAndroid Build Coastguard Worker r = a.relu().relu() 81*da0073e9SAndroid Build Coastguard Worker return torch.special.gammaln(r), torch.special.entr(r), c.cos().relu() 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker fn_s = torch.jit.script(fn) 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker a = torch.rand((10, 10), requires_grad=False) 86*da0073e9SAndroid Build Coastguard Worker b = torch.rand((10, 10), requires_grad=False) 87*da0073e9SAndroid Build Coastguard Worker c = torch.rand((10, 10), requires_grad=True) 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Worker for i in range(4): 90*da0073e9SAndroid Build Coastguard Worker x_s, y_s, z_s = fn_s(a, b, c) 91*da0073e9SAndroid Build Coastguard Worker x, y, z = fn(a, b, c) 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x_s.requires_grad, x.requires_grad) 94*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y_s.requires_grad, y.requires_grad) 95*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z_s.requires_grad, z.requires_grad) 96*da0073e9SAndroid Build Coastguard Worker 97*da0073e9SAndroid Build Coastguard Worker def test_requires_grad_outputs_side_effects(self): 98*da0073e9SAndroid Build Coastguard Worker # same as above, but also add a CallFunction in between. 99*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 100*da0073e9SAndroid Build Coastguard Worker def python_fn(x): 101*da0073e9SAndroid Build Coastguard Worker return x.relu() 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker def fn(a, b, c): 104*da0073e9SAndroid Build Coastguard Worker r = a.relu().relu() 105*da0073e9SAndroid Build Coastguard Worker z = python_fn(r) 106*da0073e9SAndroid Build Coastguard Worker return torch.relu(r), torch.nn.functional.gelu(r), c.cos().relu() 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker fn_s = torch.jit.script(fn) 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker a = torch.rand((10, 10), requires_grad=False) 111*da0073e9SAndroid Build Coastguard Worker b = torch.rand((10, 10), requires_grad=False) 112*da0073e9SAndroid Build Coastguard Worker c = torch.rand((10, 10), requires_grad=True) 113*da0073e9SAndroid Build Coastguard Worker 114*da0073e9SAndroid Build Coastguard Worker for i in range(4): 115*da0073e9SAndroid Build Coastguard Worker x_s, y_s, z_s = fn_s(a, b, c) 116*da0073e9SAndroid Build Coastguard Worker x, y, z = fn(a, b, c) 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x_s.requires_grad, x.requires_grad) 119*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y_s.requires_grad, y.requires_grad) 120*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z_s.requires_grad, z.requires_grad) 121*da0073e9SAndroid Build Coastguard Worker 122*da0073e9SAndroid Build Coastguard Worker def test_autodiff_requires_grad_nograd(self): 123*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 124*da0073e9SAndroid Build Coastguard Worker def python_fn(x): 125*da0073e9SAndroid Build Coastguard Worker return x.relu() 126*da0073e9SAndroid Build Coastguard Worker 127*da0073e9SAndroid Build Coastguard Worker def fn(a, b, c): 128*da0073e9SAndroid Build Coastguard Worker x = a.sin().relu() 129*da0073e9SAndroid Build Coastguard Worker y = python_fn(b) 130*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 131*da0073e9SAndroid Build Coastguard Worker z = x + c 132*da0073e9SAndroid Build Coastguard Worker return x, y, z 133*da0073e9SAndroid Build Coastguard Worker 134*da0073e9SAndroid Build Coastguard Worker fn_s = torch.jit.script(fn) 135*da0073e9SAndroid Build Coastguard Worker 136*da0073e9SAndroid Build Coastguard Worker a = torch.rand((10, 10), requires_grad=True) 137*da0073e9SAndroid Build Coastguard Worker b = torch.rand((10, 10), requires_grad=True) 138*da0073e9SAndroid Build Coastguard Worker c = torch.rand((10, 10), requires_grad=True) 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker for i in range(4): 141*da0073e9SAndroid Build Coastguard Worker x_s, y_s, z_s = fn_s(a, b, c) 142*da0073e9SAndroid Build Coastguard Worker x, y, z = fn(a, b, c) 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x_s.requires_grad, x.requires_grad) 145*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y_s.requires_grad, y.requires_grad) 146*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z_s.requires_grad, z.requires_grad) 147