1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: autograd"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport contextlib 4*da0073e9SAndroid Build Coastguard Workerimport warnings 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport numpy as np 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workerimport torch 9*da0073e9SAndroid Build Coastguard Workerfrom torch.library import _scoped_library, Library 10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 11*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests, 12*da0073e9SAndroid Build Coastguard Worker parametrize, 13*da0073e9SAndroid Build Coastguard Worker run_tests, 14*da0073e9SAndroid Build Coastguard Worker TestCase, 15*da0073e9SAndroid Build Coastguard Worker) 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Worker@contextlib.contextmanager 19*da0073e9SAndroid Build Coastguard Workerdef autograd_fallback_mode(mode): 20*da0073e9SAndroid Build Coastguard Worker prev = torch._C._get_autograd_fallback_mode() 21*da0073e9SAndroid Build Coastguard Worker try: 22*da0073e9SAndroid Build Coastguard Worker torch._C._set_autograd_fallback_mode(mode) 23*da0073e9SAndroid Build Coastguard Worker yield 24*da0073e9SAndroid Build Coastguard Worker finally: 25*da0073e9SAndroid Build Coastguard Worker torch._C._set_autograd_fallback_mode(prev) 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Workerclass TestAutogradFallback(TestCase): 29*da0073e9SAndroid Build Coastguard Worker test_ns = "_test_autograd_fallback" 30*da0073e9SAndroid Build Coastguard Worker 31*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 32*da0073e9SAndroid Build Coastguard Worker if hasattr(torch.ops, self.test_ns): 33*da0073e9SAndroid Build Coastguard Worker delattr(torch.ops, self.test_ns) 34*da0073e9SAndroid Build Coastguard Worker if hasattr(self, "lib"): 35*da0073e9SAndroid Build Coastguard Worker del self.lib.m 36*da0073e9SAndroid Build Coastguard Worker del self.lib 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker def get_op(self, name): 39*da0073e9SAndroid Build Coastguard Worker return getattr(getattr(torch.ops, self.test_ns), name).default 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker def get_lib(self): 42*da0073e9SAndroid Build Coastguard Worker lib = Library(self.test_ns, "FRAGMENT") # noqa: TOR901 43*da0073e9SAndroid Build Coastguard Worker self.lib = lib 44*da0073e9SAndroid Build Coastguard Worker return lib 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker @parametrize("mode", ("nothing", "warn")) 47*da0073e9SAndroid Build Coastguard Worker def test_no_grad(self, mode): 48*da0073e9SAndroid Build Coastguard Worker with autograd_fallback_mode(mode): 49*da0073e9SAndroid Build Coastguard Worker lib = self.get_lib() 50*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor a, Tensor b, int c) -> Tensor") 51*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", lambda a, b, c: a + b + c, "CPU") 52*da0073e9SAndroid Build Coastguard Worker op = self.get_op("foo") 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(): 55*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("error") 56*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 57*da0073e9SAndroid Build Coastguard Worker a = torch.randn([], requires_grad=True) 58*da0073e9SAndroid Build Coastguard Worker b = torch.randn([], requires_grad=True) 59*da0073e9SAndroid Build Coastguard Worker out = op(a, b, 1) 60*da0073e9SAndroid Build Coastguard Worker self.assertFalse(out.requires_grad) 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(): 63*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("error") 64*da0073e9SAndroid Build Coastguard Worker a = torch.randn([]) 65*da0073e9SAndroid Build Coastguard Worker b = torch.randn([]) 66*da0073e9SAndroid Build Coastguard Worker out = op(a, b, 1) 67*da0073e9SAndroid Build Coastguard Worker self.assertFalse(out.requires_grad) 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker @parametrize("mode", ("nothing", "warn")) 70*da0073e9SAndroid Build Coastguard Worker def test_no_autograd_kernel(self, mode): 71*da0073e9SAndroid Build Coastguard Worker with autograd_fallback_mode(mode): 72*da0073e9SAndroid Build Coastguard Worker lib = self.get_lib() 73*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor a, Tensor b, int c) -> Tensor") 74*da0073e9SAndroid Build Coastguard Worker op = self.get_op("foo") 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker def foo_impl(a, b, c): 77*da0073e9SAndroid Build Coastguard Worker result = a.detach().numpy() + b.detach().numpy() + c 78*da0073e9SAndroid Build Coastguard Worker return torch.tensor(result) 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Worker # Some inputs requiring grad 83*da0073e9SAndroid Build Coastguard Worker a = torch.randn([], requires_grad=False) 84*da0073e9SAndroid Build Coastguard Worker b = torch.randn([], requires_grad=True) 85*da0073e9SAndroid Build Coastguard Worker out = op(a, b, 1).sum() 86*da0073e9SAndroid Build Coastguard Worker with self._check_ctx(mode, mode_nothing_raises=True): 87*da0073e9SAndroid Build Coastguard Worker out.backward() 88*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(b.grad) 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Worker def _check_ctx(self, mode, *, mode_nothing_raises=False): 91*da0073e9SAndroid Build Coastguard Worker if mode == "warn": 92*da0073e9SAndroid Build Coastguard Worker return self.assertWarnsRegex( 93*da0073e9SAndroid Build Coastguard Worker UserWarning, "an autograd kernel was not registered" 94*da0073e9SAndroid Build Coastguard Worker ) 95*da0073e9SAndroid Build Coastguard Worker assert mode == "nothing" 96*da0073e9SAndroid Build Coastguard Worker if mode_nothing_raises: 97*da0073e9SAndroid Build Coastguard Worker return self.assertRaisesRegex(RuntimeError, "does not require grad") 98*da0073e9SAndroid Build Coastguard Worker return contextlib.nullcontext() 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker @parametrize("mode", ("nothing", "warn")) 101*da0073e9SAndroid Build Coastguard Worker def test_no_autograd_kernel_inplace(self, mode): 102*da0073e9SAndroid Build Coastguard Worker with autograd_fallback_mode(mode): 103*da0073e9SAndroid Build Coastguard Worker # input modified in-place gets returned as output 104*da0073e9SAndroid Build Coastguard Worker lib = self.get_lib() 105*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor(a!) self, Tensor(b!) y) -> (Tensor(a!), Tensor(b!))") 106*da0073e9SAndroid Build Coastguard Worker op = self.get_op("foo") 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker def foo_impl(x, y): 109*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 110*da0073e9SAndroid Build Coastguard Worker x.sin_() 111*da0073e9SAndroid Build Coastguard Worker y.cos_() 112*da0073e9SAndroid Build Coastguard Worker return x, y 113*da0073e9SAndroid Build Coastguard Worker 114*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 117*da0073e9SAndroid Build Coastguard Worker w = x.clone() 118*da0073e9SAndroid Build Coastguard Worker v = x.clone() 119*da0073e9SAndroid Build Coastguard Worker y0 = w[0] 120*da0073e9SAndroid Build Coastguard Worker y1 = v[1] 121*da0073e9SAndroid Build Coastguard Worker z0, z1 = op(y0, y1) 122*da0073e9SAndroid Build Coastguard Worker for tensor in [w, v, z0, z1, y0, y1]: 123*da0073e9SAndroid Build Coastguard Worker with self._check_ctx(mode): 124*da0073e9SAndroid Build Coastguard Worker tensor.sum().backward(retain_graph=True) 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker # no outputs: we don't do anything. Maybe we should in the future. 127*da0073e9SAndroid Build Coastguard Worker # This is not a common failure mode. 128*da0073e9SAndroid Build Coastguard Worker lib.define("bar(Tensor(a!) self) -> ()") 129*da0073e9SAndroid Build Coastguard Worker op = self.get_op("bar") 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker def bar_impl(x): 132*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 133*da0073e9SAndroid Build Coastguard Worker x.sin_() 134*da0073e9SAndroid Build Coastguard Worker 135*da0073e9SAndroid Build Coastguard Worker lib.impl("bar", bar_impl, "CPU") 136*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(): 137*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("error") 138*da0073e9SAndroid Build Coastguard Worker x = torch.randn([], requires_grad=True) 139*da0073e9SAndroid Build Coastguard Worker y = x.clone() 140*da0073e9SAndroid Build Coastguard Worker z = op(y) 141*da0073e9SAndroid Build Coastguard Worker y.backward() 142*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, torch.ones_like(x)) 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker @parametrize("mode", ("nothing", "warn")) 145*da0073e9SAndroid Build Coastguard Worker def test_cpu_return_self(self, mode): 146*da0073e9SAndroid Build Coastguard Worker with autograd_fallback_mode(mode): 147*da0073e9SAndroid Build Coastguard Worker # To be clear, none of these situations are OK and will lead 148*da0073e9SAndroid Build Coastguard Worker # to other problems down the line. We're testing them because 149*da0073e9SAndroid Build Coastguard Worker # it is fairly common to actually do these things. 150*da0073e9SAndroid Build Coastguard Worker with _scoped_library(self.test_ns, "FRAGMENT") as lib: 151*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor self) -> Tensor") 152*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", lambda x: x, "CPU") 153*da0073e9SAndroid Build Coastguard Worker op = self.get_op("foo") 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 156*da0073e9SAndroid Build Coastguard Worker y = op(x).sum() 157*da0073e9SAndroid Build Coastguard Worker with self._check_ctx(mode): 158*da0073e9SAndroid Build Coastguard Worker y.backward() 159*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, torch.ones_like(x)) 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker lib.define("bar(Tensor(a!) self) -> Tensor(a!)") 162*da0073e9SAndroid Build Coastguard Worker lib.impl("bar", lambda x: x, "CPU") 163*da0073e9SAndroid Build Coastguard Worker op = self.get_op("bar") 164*da0073e9SAndroid Build Coastguard Worker 165*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 166*da0073e9SAndroid Build Coastguard Worker y = op(x).sum() 167*da0073e9SAndroid Build Coastguard Worker with self._check_ctx(mode): 168*da0073e9SAndroid Build Coastguard Worker y.backward() 169*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, torch.ones_like(x)) 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Worker @parametrize("mode", ("nothing", "warn")) 172*da0073e9SAndroid Build Coastguard Worker def test_composite_registered_to_cpu(self, mode): 173*da0073e9SAndroid Build Coastguard Worker with autograd_fallback_mode(mode): 174*da0073e9SAndroid Build Coastguard Worker with _scoped_library(self.test_ns, "FRAGMENT") as lib: 175*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor self) -> Tensor") 176*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", lambda x: x.sin().sum(), "CPU") 177*da0073e9SAndroid Build Coastguard Worker op = self.get_op("foo") 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 180*da0073e9SAndroid Build Coastguard Worker y = op(x) 181*da0073e9SAndroid Build Coastguard Worker with self._check_ctx(mode): 182*da0073e9SAndroid Build Coastguard Worker y.backward() 183*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, x.cos()) 184*da0073e9SAndroid Build Coastguard Worker 185*da0073e9SAndroid Build Coastguard Worker @parametrize("mode", ("nothing", "warn")) 186*da0073e9SAndroid Build Coastguard Worker def test_autograd_function_registered_to_cpu(self, mode): 187*da0073e9SAndroid Build Coastguard Worker with autograd_fallback_mode(mode): 188*da0073e9SAndroid Build Coastguard Worker with _scoped_library(self.test_ns, "FRAGMENT") as lib: 189*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor self) -> Tensor") 190*da0073e9SAndroid Build Coastguard Worker 191*da0073e9SAndroid Build Coastguard Worker class NumpySin(torch.autograd.Function): 192*da0073e9SAndroid Build Coastguard Worker @staticmethod 193*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 194*da0073e9SAndroid Build Coastguard Worker ctx.save_for_backward(x) 195*da0073e9SAndroid Build Coastguard Worker return torch.tensor(np.sin(x.cpu().numpy())) 196*da0073e9SAndroid Build Coastguard Worker 197*da0073e9SAndroid Build Coastguard Worker @staticmethod 198*da0073e9SAndroid Build Coastguard Worker def backward(ctx, gx): 199*da0073e9SAndroid Build Coastguard Worker (x,) = ctx.saved_tensors 200*da0073e9SAndroid Build Coastguard Worker return gx * x.cos() 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", NumpySin.apply, "CPU") 203*da0073e9SAndroid Build Coastguard Worker op = self.get_op("foo") 204*da0073e9SAndroid Build Coastguard Worker 205*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 206*da0073e9SAndroid Build Coastguard Worker y = op(x).sum() 207*da0073e9SAndroid Build Coastguard Worker with self._check_ctx(mode): 208*da0073e9SAndroid Build Coastguard Worker y.backward() 209*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, x.cos()) 210*da0073e9SAndroid Build Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker @parametrize("mode", ("nothing", "warn")) 212*da0073e9SAndroid Build Coastguard Worker def test_inplace_autograd_function_registered_to_cpu(self, mode): 213*da0073e9SAndroid Build Coastguard Worker with autograd_fallback_mode(mode): 214*da0073e9SAndroid Build Coastguard Worker with _scoped_library(self.test_ns, "FRAGMENT") as lib: 215*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor(a!) self) -> Tensor(a!)") 216*da0073e9SAndroid Build Coastguard Worker 217*da0073e9SAndroid Build Coastguard Worker class NumpySin_(torch.autograd.Function): 218*da0073e9SAndroid Build Coastguard Worker @staticmethod 219*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 220*da0073e9SAndroid Build Coastguard Worker ctx.save_for_backward(x.clone()) 221*da0073e9SAndroid Build Coastguard Worker x_np = x.detach().numpy() 222*da0073e9SAndroid Build Coastguard Worker np.sin(x_np, out=x_np) 223*da0073e9SAndroid Build Coastguard Worker ctx.mark_dirty(x) 224*da0073e9SAndroid Build Coastguard Worker return x 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker @staticmethod 227*da0073e9SAndroid Build Coastguard Worker def backward(ctx, gx): 228*da0073e9SAndroid Build Coastguard Worker (x,) = ctx.saved_tensors 229*da0073e9SAndroid Build Coastguard Worker return gx * x.cos() 230*da0073e9SAndroid Build Coastguard Worker 231*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", NumpySin_.apply, "CPU") 232*da0073e9SAndroid Build Coastguard Worker op = self.get_op("foo") 233*da0073e9SAndroid Build Coastguard Worker 234*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 235*da0073e9SAndroid Build Coastguard Worker z = x.clone() 236*da0073e9SAndroid Build Coastguard Worker w = z[0] 237*da0073e9SAndroid Build Coastguard Worker y = op(w) 238*da0073e9SAndroid Build Coastguard Worker 239*da0073e9SAndroid Build Coastguard Worker expected = torch.zeros_like(x) 240*da0073e9SAndroid Build Coastguard Worker expected[0] = x[0].cos() 241*da0073e9SAndroid Build Coastguard Worker with self._check_ctx(mode): 242*da0073e9SAndroid Build Coastguard Worker (gx,) = torch.autograd.grad( 243*da0073e9SAndroid Build Coastguard Worker y, x, torch.ones_like(y), retain_graph=True 244*da0073e9SAndroid Build Coastguard Worker ) 245*da0073e9SAndroid Build Coastguard Worker self.assertEqual(gx, expected) 246*da0073e9SAndroid Build Coastguard Worker 247*da0073e9SAndroid Build Coastguard Worker expected = torch.ones_like(x) 248*da0073e9SAndroid Build Coastguard Worker expected[0] = x[0].cos() 249*da0073e9SAndroid Build Coastguard Worker with self._check_ctx(mode): 250*da0073e9SAndroid Build Coastguard Worker (gx,) = torch.autograd.grad(z, x, torch.ones_like(z)) 251*da0073e9SAndroid Build Coastguard Worker self.assertEqual(gx, expected) 252*da0073e9SAndroid Build Coastguard Worker 253*da0073e9SAndroid Build Coastguard Worker @parametrize("mode", ("nothing", "warn")) 254*da0073e9SAndroid Build Coastguard Worker def test_inplace_on_tensor_that_does_not_require_grad(self, mode): 255*da0073e9SAndroid Build Coastguard Worker # We don't do anything special (that is, we don't rebase history). 256*da0073e9SAndroid Build Coastguard Worker # See NOTE [autograd fallback and in-place operations] for why 257*da0073e9SAndroid Build Coastguard Worker with autograd_fallback_mode(mode): 258*da0073e9SAndroid Build Coastguard Worker with _scoped_library(self.test_ns, "FRAGMENT") as lib: 259*da0073e9SAndroid Build Coastguard Worker # Correct usage of (a!) 260*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor(a!) self, Tensor other) -> Tensor(a!)") 261*da0073e9SAndroid Build Coastguard Worker 262*da0073e9SAndroid Build Coastguard Worker def foo_impl(x, y): 263*da0073e9SAndroid Build Coastguard Worker x_d = x.detach() 264*da0073e9SAndroid Build Coastguard Worker y = y.detach() 265*da0073e9SAndroid Build Coastguard Worker x_d.add_(y) 266*da0073e9SAndroid Build Coastguard Worker return x 267*da0073e9SAndroid Build Coastguard Worker 268*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 269*da0073e9SAndroid Build Coastguard Worker foo = self.get_op("foo") 270*da0073e9SAndroid Build Coastguard Worker 271*da0073e9SAndroid Build Coastguard Worker # Incorrect usage of (a!): user doesn't return tensor as-is 272*da0073e9SAndroid Build Coastguard Worker lib.define("bar(Tensor(a!) self, Tensor other) -> Tensor(a!)") 273*da0073e9SAndroid Build Coastguard Worker 274*da0073e9SAndroid Build Coastguard Worker def bar_impl(x, y): 275*da0073e9SAndroid Build Coastguard Worker x_d = x.detach() 276*da0073e9SAndroid Build Coastguard Worker y = y.detach() 277*da0073e9SAndroid Build Coastguard Worker x_d.add_(y) 278*da0073e9SAndroid Build Coastguard Worker return x_d.clone() 279*da0073e9SAndroid Build Coastguard Worker 280*da0073e9SAndroid Build Coastguard Worker lib.impl("bar", bar_impl, "CPU") 281*da0073e9SAndroid Build Coastguard Worker bar = self.get_op("bar") 282*da0073e9SAndroid Build Coastguard Worker 283*da0073e9SAndroid Build Coastguard Worker # User mutated input tensor but didn't return it. 284*da0073e9SAndroid Build Coastguard Worker lib.define("baz(Tensor(a!) self, Tensor other) -> ()") 285*da0073e9SAndroid Build Coastguard Worker 286*da0073e9SAndroid Build Coastguard Worker def baz_impl(x, y): 287*da0073e9SAndroid Build Coastguard Worker x_d = x.detach() 288*da0073e9SAndroid Build Coastguard Worker y = y.detach() 289*da0073e9SAndroid Build Coastguard Worker x_d.add_(y) 290*da0073e9SAndroid Build Coastguard Worker 291*da0073e9SAndroid Build Coastguard Worker lib.impl("baz", baz_impl, "CPU") 292*da0073e9SAndroid Build Coastguard Worker baz = self.get_op("baz") 293*da0073e9SAndroid Build Coastguard Worker 294*da0073e9SAndroid Build Coastguard Worker # Test in-place on non-view 295*da0073e9SAndroid Build Coastguard Worker for op in (foo, bar, baz): 296*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 297*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3, requires_grad=True) 298*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "does not require grad"): 299*da0073e9SAndroid Build Coastguard Worker z = x.clone() 300*da0073e9SAndroid Build Coastguard Worker op(z, y) 301*da0073e9SAndroid Build Coastguard Worker torch.autograd.grad(z, y, torch.ones_like(z), allow_unused=True) 302*da0073e9SAndroid Build Coastguard Worker 303*da0073e9SAndroid Build Coastguard Worker # Test in-place on view 304*da0073e9SAndroid Build Coastguard Worker for op in (foo, bar, baz): 305*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 306*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3, requires_grad=True) 307*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "does not require grad"): 308*da0073e9SAndroid Build Coastguard Worker z = x[:] 309*da0073e9SAndroid Build Coastguard Worker op(z, y) 310*da0073e9SAndroid Build Coastguard Worker torch.autograd.grad(z, x, torch.ones_like(z), allow_unused=True) 311*da0073e9SAndroid Build Coastguard Worker 312*da0073e9SAndroid Build Coastguard Worker @parametrize("mode", ("nothing", "warn")) 313*da0073e9SAndroid Build Coastguard Worker def test_post_autograd_returns_leaf(self, mode): 314*da0073e9SAndroid Build Coastguard Worker with autograd_fallback_mode(mode): 315*da0073e9SAndroid Build Coastguard Worker lib = self.get_lib() 316*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor a) -> (Tensor, Tensor)") 317*da0073e9SAndroid Build Coastguard Worker op = self.get_op("foo") 318*da0073e9SAndroid Build Coastguard Worker 319*da0073e9SAndroid Build Coastguard Worker lib.impl( 320*da0073e9SAndroid Build Coastguard Worker "foo", lambda a: (a.clone(), a.clone().detach().requires_grad_()), "CPU" 321*da0073e9SAndroid Build Coastguard Worker ) 322*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 323*da0073e9SAndroid Build Coastguard Worker y, z = op(x) 324*da0073e9SAndroid Build Coastguard Worker with self._check_ctx(mode): 325*da0073e9SAndroid Build Coastguard Worker z.sum().backward() 326*da0073e9SAndroid Build Coastguard Worker 327*da0073e9SAndroid Build Coastguard Worker @parametrize("mode", ("nothing", "warn")) 328*da0073e9SAndroid Build Coastguard Worker def test_undefined_inputs_outputs(self, mode): 329*da0073e9SAndroid Build Coastguard Worker with autograd_fallback_mode(mode): 330*da0073e9SAndroid Build Coastguard Worker lib = self.get_lib() 331*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)") 332*da0073e9SAndroid Build Coastguard Worker op = self.get_op("foo") 333*da0073e9SAndroid Build Coastguard Worker 334*da0073e9SAndroid Build Coastguard Worker def foo_impl(a, b): 335*da0073e9SAndroid Build Coastguard Worker return None, b.clone() 336*da0073e9SAndroid Build Coastguard Worker 337*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 338*da0073e9SAndroid Build Coastguard Worker 339*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 340*da0073e9SAndroid Build Coastguard Worker # NB: PyTorch dispatcher treats "None" as undefined Tensor. 341*da0073e9SAndroid Build Coastguard Worker y, z = op(None, x) 342*da0073e9SAndroid Build Coastguard Worker with self._check_ctx(mode): 343*da0073e9SAndroid Build Coastguard Worker z.sum().backward() 344*da0073e9SAndroid Build Coastguard Worker 345*da0073e9SAndroid Build Coastguard Worker @parametrize("mode", ("nothing", "warn")) 346*da0073e9SAndroid Build Coastguard Worker def test_undefined_grads(self, mode): 347*da0073e9SAndroid Build Coastguard Worker with autograd_fallback_mode(mode): 348*da0073e9SAndroid Build Coastguard Worker lib = self.get_lib() 349*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)") 350*da0073e9SAndroid Build Coastguard Worker op = self.get_op("foo") 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker def foo_impl(a, b): 353*da0073e9SAndroid Build Coastguard Worker return a.sin(), b.cos() 354*da0073e9SAndroid Build Coastguard Worker 355*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 356*da0073e9SAndroid Build Coastguard Worker 357*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 358*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3) 359*da0073e9SAndroid Build Coastguard Worker w, z = op(x, y) 360*da0073e9SAndroid Build Coastguard Worker w = torch._C._functions.UndefinedGrad()(w) 361*da0073e9SAndroid Build Coastguard Worker z = torch._C._functions.UndefinedGrad()(z) 362*da0073e9SAndroid Build Coastguard Worker with self._check_ctx(mode): 363*da0073e9SAndroid Build Coastguard Worker (z + w).sum().backward() 364*da0073e9SAndroid Build Coastguard Worker 365*da0073e9SAndroid Build Coastguard Worker @parametrize("mode", ("nothing", "warn")) 366*da0073e9SAndroid Build Coastguard Worker def test_base_does_not_require_grad(self, mode): 367*da0073e9SAndroid Build Coastguard Worker with autograd_fallback_mode(mode): 368*da0073e9SAndroid Build Coastguard Worker lib = self.get_lib() 369*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor(a!) x) -> Tensor(a!)") 370*da0073e9SAndroid Build Coastguard Worker op = self.get_op("foo") 371*da0073e9SAndroid Build Coastguard Worker 372*da0073e9SAndroid Build Coastguard Worker def foo_impl(a): 373*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 374*da0073e9SAndroid Build Coastguard Worker return a.zero_() 375*da0073e9SAndroid Build Coastguard Worker 376*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 377*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 378*da0073e9SAndroid Build Coastguard Worker y = x[:] 379*da0073e9SAndroid Build Coastguard Worker y.requires_grad_() 380*da0073e9SAndroid Build Coastguard Worker w = y[:] 381*da0073e9SAndroid Build Coastguard Worker self.assertTrue(w._base is x) 382*da0073e9SAndroid Build Coastguard Worker 383*da0073e9SAndroid Build Coastguard Worker # Hook should be registered on w, but not w._base 384*da0073e9SAndroid Build Coastguard Worker op(w) 385*da0073e9SAndroid Build Coastguard Worker with self._check_ctx(mode): 386*da0073e9SAndroid Build Coastguard Worker w.sum().backward() 387*da0073e9SAndroid Build Coastguard Worker 388*da0073e9SAndroid Build Coastguard Worker @parametrize("mode", ("nothing", "warn")) 389*da0073e9SAndroid Build Coastguard Worker def test_post_autograd_returns_mix_of_requires_grad_tensors(self, mode): 390*da0073e9SAndroid Build Coastguard Worker with autograd_fallback_mode(mode): 391*da0073e9SAndroid Build Coastguard Worker lib = self.get_lib() 392*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor, Tensor)") 393*da0073e9SAndroid Build Coastguard Worker op = self.get_op("foo") 394*da0073e9SAndroid Build Coastguard Worker 395*da0073e9SAndroid Build Coastguard Worker def foo_impl(a, b): 396*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 397*da0073e9SAndroid Build Coastguard Worker x = a.clone() 398*da0073e9SAndroid Build Coastguard Worker z = b.clone() 399*da0073e9SAndroid Build Coastguard Worker y = a * b 400*da0073e9SAndroid Build Coastguard Worker return x, y, z 401*da0073e9SAndroid Build Coastguard Worker 402*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 403*da0073e9SAndroid Build Coastguard Worker a = torch.randn(3, requires_grad=True) 404*da0073e9SAndroid Build Coastguard Worker b = torch.randn(3, requires_grad=True) 405*da0073e9SAndroid Build Coastguard Worker x, y, z = op(a, b) 406*da0073e9SAndroid Build Coastguard Worker 407*da0073e9SAndroid Build Coastguard Worker with self._check_ctx(mode, mode_nothing_raises=True): 408*da0073e9SAndroid Build Coastguard Worker torch.autograd.grad( 409*da0073e9SAndroid Build Coastguard Worker x, (a, b), torch.ones_like(x), allow_unused=True, retain_graph=True 410*da0073e9SAndroid Build Coastguard Worker ) 411*da0073e9SAndroid Build Coastguard Worker 412*da0073e9SAndroid Build Coastguard Worker with self._check_ctx(mode, mode_nothing_raises=False): 413*da0073e9SAndroid Build Coastguard Worker torch.autograd.grad( 414*da0073e9SAndroid Build Coastguard Worker y, (a, b), torch.ones_like(y), allow_unused=True, retain_graph=True 415*da0073e9SAndroid Build Coastguard Worker ) 416*da0073e9SAndroid Build Coastguard Worker 417*da0073e9SAndroid Build Coastguard Worker with self._check_ctx(mode, mode_nothing_raises=True): 418*da0073e9SAndroid Build Coastguard Worker torch.autograd.grad( 419*da0073e9SAndroid Build Coastguard Worker z, (a, b), torch.ones_like(z), allow_unused=True, retain_graph=True 420*da0073e9SAndroid Build Coastguard Worker ) 421*da0073e9SAndroid Build Coastguard Worker 422*da0073e9SAndroid Build Coastguard Worker @parametrize("mode", ("nothing", "warn")) 423*da0073e9SAndroid Build Coastguard Worker def test_supports_tensor_lists(self, mode): 424*da0073e9SAndroid Build Coastguard Worker with autograd_fallback_mode(mode): 425*da0073e9SAndroid Build Coastguard Worker lib = self.get_lib() 426*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor[] a) -> Tensor[]") 427*da0073e9SAndroid Build Coastguard Worker op = self.get_op("foo") 428*da0073e9SAndroid Build Coastguard Worker 429*da0073e9SAndroid Build Coastguard Worker def foo_impl(a): 430*da0073e9SAndroid Build Coastguard Worker x, y, z = a 431*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 432*da0073e9SAndroid Build Coastguard Worker return x + y + z, x * y * z 433*da0073e9SAndroid Build Coastguard Worker 434*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 435*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 436*da0073e9SAndroid Build Coastguard Worker y = torch.randn(1, requires_grad=True) 437*da0073e9SAndroid Build Coastguard Worker z = torch.randn(2, 1, requires_grad=True) 438*da0073e9SAndroid Build Coastguard Worker a, b = op([x, y, z]) 439*da0073e9SAndroid Build Coastguard Worker with self._check_ctx(mode, mode_nothing_raises=True): 440*da0073e9SAndroid Build Coastguard Worker torch.autograd.grad( 441*da0073e9SAndroid Build Coastguard Worker a, 442*da0073e9SAndroid Build Coastguard Worker (x, y, z), 443*da0073e9SAndroid Build Coastguard Worker torch.ones_like(a), 444*da0073e9SAndroid Build Coastguard Worker allow_unused=True, 445*da0073e9SAndroid Build Coastguard Worker retain_graph=True, 446*da0073e9SAndroid Build Coastguard Worker ) 447*da0073e9SAndroid Build Coastguard Worker with self._check_ctx(mode, mode_nothing_raises=True): 448*da0073e9SAndroid Build Coastguard Worker torch.autograd.grad( 449*da0073e9SAndroid Build Coastguard Worker b, 450*da0073e9SAndroid Build Coastguard Worker (x, y, z), 451*da0073e9SAndroid Build Coastguard Worker torch.ones_like(b), 452*da0073e9SAndroid Build Coastguard Worker allow_unused=True, 453*da0073e9SAndroid Build Coastguard Worker retain_graph=True, 454*da0073e9SAndroid Build Coastguard Worker ) 455*da0073e9SAndroid Build Coastguard Worker 456*da0073e9SAndroid Build Coastguard Worker 457*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestAutogradFallback) 458*da0073e9SAndroid Build Coastguard Worker 459*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 460*da0073e9SAndroid Build Coastguard Worker run_tests() 461