xref: /aosp_15_r20/external/pytorch/test/jit/test_autodiff.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom typing import List
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import skipIfTorchDynamo
7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo()
11*da0073e9SAndroid Build Coastguard Workerclass TestAutodiffJit(JitTestCase):
12*da0073e9SAndroid Build Coastguard Worker    def test_undefined_tensor_lists(self):
13*da0073e9SAndroid Build Coastguard Worker        def fn(tensor_list: List[torch.Tensor], add_tensor):
14*da0073e9SAndroid Build Coastguard Worker            cat = torch.cat(tensor_list, dim=1)
15*da0073e9SAndroid Build Coastguard Worker            r = torch.sin(cat + add_tensor)
16*da0073e9SAndroid Build Coastguard Worker            return r
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker        fn_s = torch.jit.script(fn)
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((3, 6), requires_grad=True)
21*da0073e9SAndroid Build Coastguard Worker        b = torch.rand((3, 10), requires_grad=True)
22*da0073e9SAndroid Build Coastguard Worker        x = [a, b]
23*da0073e9SAndroid Build Coastguard Worker        y = torch.rand((3, 16), requires_grad=True)
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker        ret = fn_s(x, y)
26*da0073e9SAndroid Build Coastguard Worker        ret.sum().backward()
27*da0073e9SAndroid Build Coastguard Worker        ret = fn_s(x, y)
28*da0073e9SAndroid Build Coastguard Worker        ret.sum().backward()
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Worker        ret = fn_s(x, y)
31*da0073e9SAndroid Build Coastguard Worker        s = ret.sum()
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker        # backward_fn expects 2 inputs: (grad_output, current_grad_r)
34*da0073e9SAndroid Build Coastguard Worker        # current_grad_r is provided because we need to add this contribution
35*da0073e9SAndroid Build Coastguard Worker        # to grad_r when we return it.
36*da0073e9SAndroid Build Coastguard Worker        backward_fn = s.grad_fn.next_functions[0][0]
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker        # check behavior with defined tensor
39*da0073e9SAndroid Build Coastguard Worker        grad_out = torch.rand((3, 16))
40*da0073e9SAndroid Build Coastguard Worker        grad_inputs = backward_fn(grad_out, None)
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Worker        # expect 3 tensors: grad_y, grad_a, grad_b
43*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(3, len(grad_inputs))
44*da0073e9SAndroid Build Coastguard Worker        for x in grad_inputs:
45*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(isinstance(x, torch.Tensor))
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker        # now test with undefined grad_out
48*da0073e9SAndroid Build Coastguard Worker        grad_inputs = backward_fn(None, None)
49*da0073e9SAndroid Build Coastguard Worker
50*da0073e9SAndroid Build Coastguard Worker        # expect all of them to be None
51*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(3, len(grad_inputs))
52*da0073e9SAndroid Build Coastguard Worker        for x in grad_inputs:
53*da0073e9SAndroid Build Coastguard Worker            if x is not None:
54*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(0, torch.max(torch.abs(x)).item())
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Worker    def test_requires_grad_outputs(self):
57*da0073e9SAndroid Build Coastguard Worker        # outputs should require_grad only if eager outputs would require_grad.
58*da0073e9SAndroid Build Coastguard Worker        def fn(a, b, c):
59*da0073e9SAndroid Build Coastguard Worker            return a.relu() + b.relu(), c.relu()
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((10, 10), requires_grad=False)
62*da0073e9SAndroid Build Coastguard Worker        b = torch.rand((10, 10), requires_grad=False)
63*da0073e9SAndroid Build Coastguard Worker        c = torch.rand((10, 10), requires_grad=True)
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker        fn_s = torch.jit.script(fn)
66*da0073e9SAndroid Build Coastguard Worker
67*da0073e9SAndroid Build Coastguard Worker        for i in range(4):
68*da0073e9SAndroid Build Coastguard Worker            x, y = fn_s(a, b, c)
69*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(x.requires_grad)
70*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(y.requires_grad)
71*da0073e9SAndroid Build Coastguard Worker
72*da0073e9SAndroid Build Coastguard Worker    def test_requires_grad_outputs_profiled_twice(self):
73*da0073e9SAndroid Build Coastguard Worker        # the value "r" is used twice, by gammaln and by entr, so it is profiled twice.
74*da0073e9SAndroid Build Coastguard Worker        # So during autodiff graph formation the profile nodes are unmerged because
75*da0073e9SAndroid Build Coastguard Worker        # they are aliasing. Then the DifferentiableGraph doesn't have a profile
76*da0073e9SAndroid Build Coastguard Worker        # node on the output. The requires_grad info should then be added onto the
77*da0073e9SAndroid Build Coastguard Worker        # output value (otherwise autodiff will make the output require_grad).
78*da0073e9SAndroid Build Coastguard Worker        # Note: this relies on gammaln and entr not having autodiff implementations.
79*da0073e9SAndroid Build Coastguard Worker        def fn(a, b, c):
80*da0073e9SAndroid Build Coastguard Worker            r = a.relu().relu()
81*da0073e9SAndroid Build Coastguard Worker            return torch.special.gammaln(r), torch.special.entr(r), c.cos().relu()
82*da0073e9SAndroid Build Coastguard Worker
83*da0073e9SAndroid Build Coastguard Worker        fn_s = torch.jit.script(fn)
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((10, 10), requires_grad=False)
86*da0073e9SAndroid Build Coastguard Worker        b = torch.rand((10, 10), requires_grad=False)
87*da0073e9SAndroid Build Coastguard Worker        c = torch.rand((10, 10), requires_grad=True)
88*da0073e9SAndroid Build Coastguard Worker
89*da0073e9SAndroid Build Coastguard Worker        for i in range(4):
90*da0073e9SAndroid Build Coastguard Worker            x_s, y_s, z_s = fn_s(a, b, c)
91*da0073e9SAndroid Build Coastguard Worker            x, y, z = fn(a, b, c)
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_s.requires_grad, x.requires_grad)
94*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_s.requires_grad, y.requires_grad)
95*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(z_s.requires_grad, z.requires_grad)
96*da0073e9SAndroid Build Coastguard Worker
97*da0073e9SAndroid Build Coastguard Worker    def test_requires_grad_outputs_side_effects(self):
98*da0073e9SAndroid Build Coastguard Worker        # same as above, but also add a CallFunction in between.
99*da0073e9SAndroid Build Coastguard Worker        @torch.jit.ignore
100*da0073e9SAndroid Build Coastguard Worker        def python_fn(x):
101*da0073e9SAndroid Build Coastguard Worker            return x.relu()
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker        def fn(a, b, c):
104*da0073e9SAndroid Build Coastguard Worker            r = a.relu().relu()
105*da0073e9SAndroid Build Coastguard Worker            z = python_fn(r)
106*da0073e9SAndroid Build Coastguard Worker            return torch.relu(r), torch.nn.functional.gelu(r), c.cos().relu()
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Worker        fn_s = torch.jit.script(fn)
109*da0073e9SAndroid Build Coastguard Worker
110*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((10, 10), requires_grad=False)
111*da0073e9SAndroid Build Coastguard Worker        b = torch.rand((10, 10), requires_grad=False)
112*da0073e9SAndroid Build Coastguard Worker        c = torch.rand((10, 10), requires_grad=True)
113*da0073e9SAndroid Build Coastguard Worker
114*da0073e9SAndroid Build Coastguard Worker        for i in range(4):
115*da0073e9SAndroid Build Coastguard Worker            x_s, y_s, z_s = fn_s(a, b, c)
116*da0073e9SAndroid Build Coastguard Worker            x, y, z = fn(a, b, c)
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_s.requires_grad, x.requires_grad)
119*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_s.requires_grad, y.requires_grad)
120*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(z_s.requires_grad, z.requires_grad)
121*da0073e9SAndroid Build Coastguard Worker
122*da0073e9SAndroid Build Coastguard Worker    def test_autodiff_requires_grad_nograd(self):
123*da0073e9SAndroid Build Coastguard Worker        @torch.jit.ignore
124*da0073e9SAndroid Build Coastguard Worker        def python_fn(x):
125*da0073e9SAndroid Build Coastguard Worker            return x.relu()
126*da0073e9SAndroid Build Coastguard Worker
127*da0073e9SAndroid Build Coastguard Worker        def fn(a, b, c):
128*da0073e9SAndroid Build Coastguard Worker            x = a.sin().relu()
129*da0073e9SAndroid Build Coastguard Worker            y = python_fn(b)
130*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
131*da0073e9SAndroid Build Coastguard Worker                z = x + c
132*da0073e9SAndroid Build Coastguard Worker            return x, y, z
133*da0073e9SAndroid Build Coastguard Worker
134*da0073e9SAndroid Build Coastguard Worker        fn_s = torch.jit.script(fn)
135*da0073e9SAndroid Build Coastguard Worker
136*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((10, 10), requires_grad=True)
137*da0073e9SAndroid Build Coastguard Worker        b = torch.rand((10, 10), requires_grad=True)
138*da0073e9SAndroid Build Coastguard Worker        c = torch.rand((10, 10), requires_grad=True)
139*da0073e9SAndroid Build Coastguard Worker
140*da0073e9SAndroid Build Coastguard Worker        for i in range(4):
141*da0073e9SAndroid Build Coastguard Worker            x_s, y_s, z_s = fn_s(a, b, c)
142*da0073e9SAndroid Build Coastguard Worker            x, y, z = fn(a, b, c)
143*da0073e9SAndroid Build Coastguard Worker
144*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_s.requires_grad, x.requires_grad)
145*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_s.requires_grad, y.requires_grad)
146*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(z_s.requires_grad, z.requires_grad)
147