xref: /aosp_15_r20/external/pytorch/test/autograd/test_complex.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: autograd"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport torch
4*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import gradcheck, run_tests, TestCase
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerclass TestAutogradComplex(TestCase):
8*da0073e9SAndroid Build Coastguard Worker    def test_view_func_for_complex_views(self):
9*da0073e9SAndroid Build Coastguard Worker        # case 1: both parent and child have view_func
10*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True)
11*da0073e9SAndroid Build Coastguard Worker        y = x.detach().requires_grad_(True)
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker        x0 = x.clone()
14*da0073e9SAndroid Build Coastguard Worker        x1 = torch.view_as_complex(x0)
15*da0073e9SAndroid Build Coastguard Worker        x2 = torch.view_as_real(x1)
16*da0073e9SAndroid Build Coastguard Worker        x2.mul_(2)
17*da0073e9SAndroid Build Coastguard Worker        x2.sum().abs().backward()
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker        y0 = y.clone()
20*da0073e9SAndroid Build Coastguard Worker        y0.mul_(2)
21*da0073e9SAndroid Build Coastguard Worker        y0.sum().abs().backward()
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, y.grad)
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker        # case 2: parent has view_func but child does not
26*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True)
27*da0073e9SAndroid Build Coastguard Worker        y = x.detach().requires_grad_(True)
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker        def fn(a):
30*da0073e9SAndroid Build Coastguard Worker            b = a.clone()
31*da0073e9SAndroid Build Coastguard Worker            b1 = torch.view_as_complex(b)
32*da0073e9SAndroid Build Coastguard Worker            b2 = b1.reshape(b1.numel())
33*da0073e9SAndroid Build Coastguard Worker            return b2
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker        x0 = fn(x)
36*da0073e9SAndroid Build Coastguard Worker        x0.mul_(2)
37*da0073e9SAndroid Build Coastguard Worker        x0.sum().abs().backward()
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker        y0 = fn(y)
40*da0073e9SAndroid Build Coastguard Worker        y1 = y0.mul(2)
41*da0073e9SAndroid Build Coastguard Worker        y1.sum().abs().backward()
42*da0073e9SAndroid Build Coastguard Worker
43*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, y.grad)
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker        # case 3: parent does not have a view_func but child does
46*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, dtype=torch.cdouble, requires_grad=True)
47*da0073e9SAndroid Build Coastguard Worker        y = x.detach().requires_grad_(True)
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker        def fn(a, dim0_size=5):
50*da0073e9SAndroid Build Coastguard Worker            b = a.clone()
51*da0073e9SAndroid Build Coastguard Worker            b1 = b.reshape(dim0_size, 2)
52*da0073e9SAndroid Build Coastguard Worker            b2 = torch.view_as_real(b1)
53*da0073e9SAndroid Build Coastguard Worker            return b2
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker        x0 = fn(x)
56*da0073e9SAndroid Build Coastguard Worker        x0.mul_(2)
57*da0073e9SAndroid Build Coastguard Worker        x0.sum().abs().backward()
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker        y0 = fn(y)
60*da0073e9SAndroid Build Coastguard Worker        y1 = y0.mul(2)
61*da0073e9SAndroid Build Coastguard Worker        y1.sum().abs().backward()
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, y.grad)
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker    def test_view_with_multi_output(self):
66*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2, 2, dtype=torch.double)
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker        x1 = torch.view_as_complex(x)
69*da0073e9SAndroid Build Coastguard Worker        # Taking an invalid view should always be allowed as long as it is not
70*da0073e9SAndroid Build Coastguard Worker        # modified inplace
71*da0073e9SAndroid Build Coastguard Worker        res = x1.unbind(0)
72*da0073e9SAndroid Build Coastguard Worker
73*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
74*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "output of a function that returns multiple views"
75*da0073e9SAndroid Build Coastguard Worker        ):
76*da0073e9SAndroid Build Coastguard Worker            res[0] += torch.rand(2, requires_grad=True)
77*da0073e9SAndroid Build Coastguard Worker
78*da0073e9SAndroid Build Coastguard Worker        x.requires_grad_(True)
79*da0073e9SAndroid Build Coastguard Worker        x1 = torch.view_as_complex(x)
80*da0073e9SAndroid Build Coastguard Worker        # Taking an invalid view should always be allowed as long as it is not
81*da0073e9SAndroid Build Coastguard Worker        # modified inplace
82*da0073e9SAndroid Build Coastguard Worker        res = x1.unbind(0)
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
85*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "output of a function that returns multiple views"
86*da0073e9SAndroid Build Coastguard Worker        ):
87*da0073e9SAndroid Build Coastguard Worker            res[0] += torch.rand(2, requires_grad=True)
88*da0073e9SAndroid Build Coastguard Worker
89*da0073e9SAndroid Build Coastguard Worker    def as_identity(self):
90*da0073e9SAndroid Build Coastguard Worker        # view_as_real and view_as_complex behavior should be like an identity
91*da0073e9SAndroid Build Coastguard Worker        def func(z):
92*da0073e9SAndroid Build Coastguard Worker            z_ = torch.view_as_complex(z)
93*da0073e9SAndroid Build Coastguard Worker            z_select = torch.select(z_, z_.dim() - 1, 0)
94*da0073e9SAndroid Build Coastguard Worker            z_select_real = torch.view_as_real(z_select)
95*da0073e9SAndroid Build Coastguard Worker            return z_select_real.sum()
96*da0073e9SAndroid Build Coastguard Worker
97*da0073e9SAndroid Build Coastguard Worker        z = torch.randn(10, 2, 2, dtype=torch.double, requires_grad=True)
98*da0073e9SAndroid Build Coastguard Worker        gradcheck(func, [z])
99*da0073e9SAndroid Build Coastguard Worker        func(z).backward()
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Worker        z1 = z.clone().detach().requires_grad_(True)
102*da0073e9SAndroid Build Coastguard Worker        torch.select(z1, z1.dim() - 2, 0).sum().backward()
103*da0073e9SAndroid Build Coastguard Worker
104*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z.grad, z1.grad)
105*da0073e9SAndroid Build Coastguard Worker
106*da0073e9SAndroid Build Coastguard Worker
# Run the suite only when this file is executed directly as a script;
# a plain import of the module (e.g. by a test collector) skips this.
if __name__ == "__main__":
    run_tests()