# Owner(s): ["oncall: jit"] from typing import List import torch from torch.testing._internal.common_utils import skipIfTorchDynamo from torch.testing._internal.jit_utils import JitTestCase @skipIfTorchDynamo() class TestAutodiffJit(JitTestCase): def test_undefined_tensor_lists(self): def fn(tensor_list: List[torch.Tensor], add_tensor): cat = torch.cat(tensor_list, dim=1) r = torch.sin(cat + add_tensor) return r fn_s = torch.jit.script(fn) a = torch.rand((3, 6), requires_grad=True) b = torch.rand((3, 10), requires_grad=True) x = [a, b] y = torch.rand((3, 16), requires_grad=True) ret = fn_s(x, y) ret.sum().backward() ret = fn_s(x, y) ret.sum().backward() ret = fn_s(x, y) s = ret.sum() # backward_fn expects 2 inputs: (grad_output, current_grad_r) # current_grad_r is provided because we need to add this contribution # to grad_r when we return it. backward_fn = s.grad_fn.next_functions[0][0] # check behavior with defined tensor grad_out = torch.rand((3, 16)) grad_inputs = backward_fn(grad_out, None) # expect 3 tensors: grad_y, grad_a, grad_b self.assertEqual(3, len(grad_inputs)) for x in grad_inputs: self.assertTrue(isinstance(x, torch.Tensor)) # now test with undefined grad_out grad_inputs = backward_fn(None, None) # expect all of them to be None self.assertEqual(3, len(grad_inputs)) for x in grad_inputs: if x is not None: self.assertEqual(0, torch.max(torch.abs(x)).item()) def test_requires_grad_outputs(self): # outputs should require_grad only if eager outputs would require_grad. def fn(a, b, c): return a.relu() + b.relu(), c.relu() a = torch.rand((10, 10), requires_grad=False) b = torch.rand((10, 10), requires_grad=False) c = torch.rand((10, 10), requires_grad=True) fn_s = torch.jit.script(fn) for i in range(4): x, y = fn_s(a, b, c) self.assertFalse(x.requires_grad) self.assertTrue(y.requires_grad) def test_requires_grad_outputs_profiled_twice(self): # the value "r" is used twice, by gammaln and by entr, so it is profiled twice. # So during autodiff graph formation the profile nodes are unmerged because # they are aliasing. Then the DifferentiableGraph doesn't have a profile # node on the output. The requires_grad info should then be added onto the # output value (otherwise autodiff will make the output require_grad). # Note: this relies on gammaln and entr not having autodiff implementations. def fn(a, b, c): r = a.relu().relu() return torch.special.gammaln(r), torch.special.entr(r), c.cos().relu() fn_s = torch.jit.script(fn) a = torch.rand((10, 10), requires_grad=False) b = torch.rand((10, 10), requires_grad=False) c = torch.rand((10, 10), requires_grad=True) for i in range(4): x_s, y_s, z_s = fn_s(a, b, c) x, y, z = fn(a, b, c) self.assertEqual(x_s.requires_grad, x.requires_grad) self.assertEqual(y_s.requires_grad, y.requires_grad) self.assertEqual(z_s.requires_grad, z.requires_grad) def test_requires_grad_outputs_side_effects(self): # same as above, but also add a CallFunction in between. 
        @torch.jit.ignore
        def python_fn(x):
            return x.relu()

        def fn(a, b, c):
            r = a.relu().relu()
            z = python_fn(r)
            return torch.relu(r), torch.nn.functional.gelu(r), c.cos().relu()

        fn_s = torch.jit.script(fn)

        a = torch.rand((10, 10), requires_grad=False)
        b = torch.rand((10, 10), requires_grad=False)
        c = torch.rand((10, 10), requires_grad=True)

        for i in range(4):
            x_s, y_s, z_s = fn_s(a, b, c)
            x, y, z = fn(a, b, c)

            self.assertEqual(x_s.requires_grad, x.requires_grad)
            self.assertEqual(y_s.requires_grad, y.requires_grad)
            self.assertEqual(z_s.requires_grad, z.requires_grad)

    def test_autodiff_requires_grad_nograd(self):
        # z is computed under torch.no_grad(), so it should not require grad;
        # the scripted outputs' requires_grad should match eager.
        @torch.jit.ignore
        def python_fn(x):
            return x.relu()

        def fn(a, b, c):
            x = a.sin().relu()
            y = python_fn(b)
            with torch.no_grad():
                z = x + c
            return x, y, z

        fn_s = torch.jit.script(fn)

        a = torch.rand((10, 10), requires_grad=True)
        b = torch.rand((10, 10), requires_grad=True)
        c = torch.rand((10, 10), requires_grad=True)

        for i in range(4):
            x_s, y_s, z_s = fn_s(a, b, c)
            x, y, z = fn(a, b, c)

            self.assertEqual(x_s.requires_grad, x.requires_grad)
            self.assertEqual(y_s.requires_grad, y.requires_grad)
            self.assertEqual(z_s.requires_grad, z.requires_grad)