1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: autograd"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport torch 4*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import gradcheck, run_tests, TestCase 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerclass TestAutogradComplex(TestCase): 8*da0073e9SAndroid Build Coastguard Worker def test_view_func_for_complex_views(self): 9*da0073e9SAndroid Build Coastguard Worker # case 1: both parent and child have view_func 10*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True) 11*da0073e9SAndroid Build Coastguard Worker y = x.detach().requires_grad_(True) 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker x0 = x.clone() 14*da0073e9SAndroid Build Coastguard Worker x1 = torch.view_as_complex(x0) 15*da0073e9SAndroid Build Coastguard Worker x2 = torch.view_as_real(x1) 16*da0073e9SAndroid Build Coastguard Worker x2.mul_(2) 17*da0073e9SAndroid Build Coastguard Worker x2.sum().abs().backward() 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Worker y0 = y.clone() 20*da0073e9SAndroid Build Coastguard Worker y0.mul_(2) 21*da0073e9SAndroid Build Coastguard Worker y0.sum().abs().backward() 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, y.grad) 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker # case 2: parent has view_func but child does not 26*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True) 27*da0073e9SAndroid Build Coastguard Worker y = x.detach().requires_grad_(True) 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker def fn(a): 30*da0073e9SAndroid Build Coastguard Worker b = a.clone() 31*da0073e9SAndroid Build Coastguard Worker b1 = torch.view_as_complex(b) 32*da0073e9SAndroid Build Coastguard Worker b2 = b1.reshape(b1.numel()) 33*da0073e9SAndroid Build Coastguard Worker return b2 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker x0 = fn(x) 36*da0073e9SAndroid Build Coastguard Worker x0.mul_(2) 37*da0073e9SAndroid Build Coastguard Worker x0.sum().abs().backward() 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker y0 = fn(y) 40*da0073e9SAndroid Build Coastguard Worker y1 = y0.mul(2) 41*da0073e9SAndroid Build Coastguard Worker y1.sum().abs().backward() 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, y.grad) 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker # case 3: parent does not have a view_func but child does 46*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, dtype=torch.cdouble, requires_grad=True) 47*da0073e9SAndroid Build Coastguard Worker y = x.detach().requires_grad_(True) 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker def fn(a, dim0_size=5): 50*da0073e9SAndroid Build Coastguard Worker b = a.clone() 51*da0073e9SAndroid Build Coastguard Worker b1 = b.reshape(dim0_size, 2) 52*da0073e9SAndroid Build Coastguard Worker b2 = torch.view_as_real(b1) 53*da0073e9SAndroid Build Coastguard Worker return b2 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker x0 = fn(x) 56*da0073e9SAndroid Build Coastguard Worker x0.mul_(2) 57*da0073e9SAndroid Build Coastguard Worker x0.sum().abs().backward() 58*da0073e9SAndroid Build Coastguard Worker 59*da0073e9SAndroid Build Coastguard Worker y0 = fn(y) 60*da0073e9SAndroid Build Coastguard Worker y1 = y0.mul(2) 61*da0073e9SAndroid Build Coastguard Worker y1.sum().abs().backward() 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, y.grad) 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker def test_view_with_multi_output(self): 66*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2, 2, dtype=torch.double) 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker x1 = torch.view_as_complex(x) 69*da0073e9SAndroid Build Coastguard Worker # Taking an invalid view should always be allowed as long as it is not 70*da0073e9SAndroid Build Coastguard Worker # modified inplace 71*da0073e9SAndroid Build Coastguard Worker res = x1.unbind(0) 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 74*da0073e9SAndroid Build Coastguard Worker RuntimeError, "output of a function that returns multiple views" 75*da0073e9SAndroid Build Coastguard Worker ): 76*da0073e9SAndroid Build Coastguard Worker res[0] += torch.rand(2, requires_grad=True) 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker x.requires_grad_(True) 79*da0073e9SAndroid Build Coastguard Worker x1 = torch.view_as_complex(x) 80*da0073e9SAndroid Build Coastguard Worker # Taking an invalid view should always be allowed as long as it is not 81*da0073e9SAndroid Build Coastguard Worker # modified inplace 82*da0073e9SAndroid Build Coastguard Worker res = x1.unbind(0) 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 85*da0073e9SAndroid Build Coastguard Worker RuntimeError, "output of a function that returns multiple views" 86*da0073e9SAndroid Build Coastguard Worker ): 87*da0073e9SAndroid Build Coastguard Worker res[0] += torch.rand(2, requires_grad=True) 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Worker def as_identity(self): 90*da0073e9SAndroid Build Coastguard Worker # view_as_real and view_as_complex behavior should be like an identity 91*da0073e9SAndroid Build Coastguard Worker def func(z): 92*da0073e9SAndroid Build Coastguard Worker z_ = torch.view_as_complex(z) 93*da0073e9SAndroid Build Coastguard Worker z_select = torch.select(z_, z_.dim() - 1, 0) 94*da0073e9SAndroid Build Coastguard Worker z_select_real = torch.view_as_real(z_select) 95*da0073e9SAndroid Build Coastguard Worker return z_select_real.sum() 96*da0073e9SAndroid Build Coastguard Worker 97*da0073e9SAndroid Build Coastguard Worker z = torch.randn(10, 2, 2, dtype=torch.double, requires_grad=True) 98*da0073e9SAndroid Build Coastguard Worker gradcheck(func, [z]) 99*da0073e9SAndroid Build Coastguard Worker func(z).backward() 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker z1 = z.clone().detach().requires_grad_(True) 102*da0073e9SAndroid Build Coastguard Worker torch.select(z1, z1.dim() - 2, 0).sum().backward() 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z.grad, z1.grad) 105*da0073e9SAndroid Build Coastguard Worker 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 108*da0073e9SAndroid Build Coastguard Worker run_tests() 109