# Owner(s): ["module: autograd"]

import contextlib
import warnings

import numpy as np

import torch
from torch.library import _scoped_library, Library
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


# Temporarily switches the autograd fallback mode ("nothing" or "warn"),
# restoring the previous mode on exit.
@contextlib.contextmanager
def autograd_fallback_mode(mode):
    prev = torch._C._get_autograd_fallback_mode()
    try:
        torch._C._set_autograd_fallback_mode(mode)
        yield
    finally:
        torch._C._set_autograd_fallback_mode(prev)


class TestAutogradFallback(TestCase):
    test_ns = "_test_autograd_fallback"

    def tearDown(self):
        if hasattr(torch.ops, self.test_ns):
            delattr(torch.ops, self.test_ns)
        if hasattr(self, "lib"):
            del self.lib.m
            del self.lib

    def get_op(self, name):
        return getattr(getattr(torch.ops, self.test_ns), name).default

    def get_lib(self):
        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        self.lib = lib
        return lib

    @parametrize("mode", ("nothing", "warn"))
    def test_no_grad(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b, int c) -> Tensor")
            lib.impl("foo", lambda a, b, c: a + b + c, "CPU")
            op = self.get_op("foo")

            with warnings.catch_warnings():
                warnings.simplefilter("error")
                with torch.no_grad():
                    a = torch.randn([], requires_grad=True)
                    b = torch.randn([], requires_grad=True)
                    out = op(a, b, 1)
                self.assertFalse(out.requires_grad)

            with warnings.catch_warnings():
                warnings.simplefilter("error")
                a = torch.randn([])
                b = torch.randn([])
                out = op(a, b, 1)
                self.assertFalse(out.requires_grad)

    @parametrize("mode", ("nothing", "warn"))
    def test_no_autograd_kernel(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b, int c) -> Tensor")
            op = self.get_op("foo")

            def foo_impl(a, b, c):
                result = a.detach().numpy() + b.detach().numpy() + c
                return torch.tensor(result)

            lib.impl("foo", foo_impl, "CPU")

            # Some inputs requiring grad
            a = torch.randn([], requires_grad=False)
            b = torch.randn([], requires_grad=True)
            out = op(a, b, 1).sum()
            with self._check_ctx(mode, mode_nothing_raises=True):
                out.backward()
                self.assertIsNone(b.grad)

    def _check_ctx(self, mode, *, mode_nothing_raises=False):
        if mode == "warn":
            return self.assertWarnsRegex(
                UserWarning, "an autograd kernel was not registered"
            )
        assert mode == "nothing"
        if mode_nothing_raises:
            return self.assertRaisesRegex(RuntimeError, "does not require grad")
        return contextlib.nullcontext()

    @parametrize("mode", ("nothing", "warn"))
    def test_no_autograd_kernel_inplace(self, mode):
        with autograd_fallback_mode(mode):
            # input modified in-place gets returned as output
            lib = self.get_lib()
            lib.define("foo(Tensor(a!) self, Tensor(b!) y) -> (Tensor(a!), Tensor(b!))")
            op = self.get_op("foo")

            def foo_impl(x, y):
                with torch.no_grad():
                    x.sin_()
                    y.cos_()
                return x, y

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            w = x.clone()
            v = x.clone()
            y0 = w[0]
            y1 = v[1]
            z0, z1 = op(y0, y1)
            for tensor in [w, v, z0, z1, y0, y1]:
                with self._check_ctx(mode):
                    tensor.sum().backward(retain_graph=True)

            # no outputs: we don't do anything. Maybe we should in the future.
            # This is not a common failure mode.
            lib.define("bar(Tensor(a!) self) -> ()")
            op = self.get_op("bar")

            def bar_impl(x):
                with torch.no_grad():
                    x.sin_()

            lib.impl("bar", bar_impl, "CPU")
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                x = torch.randn([], requires_grad=True)
                y = x.clone()
                z = op(y)
                y.backward()
                self.assertEqual(x.grad, torch.ones_like(x))

    @parametrize("mode", ("nothing", "warn"))
    def test_cpu_return_self(self, mode):
        with autograd_fallback_mode(mode):
            # To be clear, none of these situations are OK and will lead
            # to other problems down the line. We're testing them because
            # it is fairly common to actually do these things.
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")
                lib.impl("foo", lambda x: x, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, torch.ones_like(x))

                lib.define("bar(Tensor(a!) self) -> Tensor(a!)")
                lib.impl("bar", lambda x: x, "CPU")
                op = self.get_op("bar")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, torch.ones_like(x))

    @parametrize("mode", ("nothing", "warn"))
    def test_composite_registered_to_cpu(self, mode):
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")
                lib.impl("foo", lambda x: x.sin().sum(), "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x)
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, x.cos())

    @parametrize("mode", ("nothing", "warn"))
    def test_autograd_function_registered_to_cpu(self, mode):
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")

                class NumpySin(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, x):
                        ctx.save_for_backward(x)
                        return torch.tensor(np.sin(x.cpu().numpy()))

                    @staticmethod
                    def backward(ctx, gx):
                        (x,) = ctx.saved_tensors
                        return gx * x.cos()

                lib.impl("foo", NumpySin.apply, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, x.cos())

    @parametrize("mode", ("nothing", "warn"))
    def test_inplace_autograd_function_registered_to_cpu(self, mode):
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor(a!) self) -> Tensor(a!)")

                class NumpySin_(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, x):
                        ctx.save_for_backward(x.clone())
                        x_np = x.detach().numpy()
                        np.sin(x_np, out=x_np)
                        ctx.mark_dirty(x)
                        return x

                    @staticmethod
                    def backward(ctx, gx):
                        (x,) = ctx.saved_tensors
                        return gx * x.cos()

                lib.impl("foo", NumpySin_.apply, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                z = x.clone()
                w = z[0]
                y = op(w)

                expected = torch.zeros_like(x)
                expected[0] = x[0].cos()
                with self._check_ctx(mode):
                    (gx,) = torch.autograd.grad(
                        y, x, torch.ones_like(y), retain_graph=True
                    )
                    self.assertEqual(gx, expected)

                expected = torch.ones_like(x)
                expected[0] = x[0].cos()
                with self._check_ctx(mode):
                    (gx,) = torch.autograd.grad(z, x, torch.ones_like(z))
                    self.assertEqual(gx, expected)

    @parametrize("mode", ("nothing", "warn"))
    def test_inplace_on_tensor_that_does_not_require_grad(self, mode):
        # We don't do anything special (that is, we don't rebase history).
        # See NOTE [autograd fallback and in-place operations] for why
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                # Correct usage of (a!)
                lib.define("foo(Tensor(a!) self, Tensor other) -> Tensor(a!)")

                def foo_impl(x, y):
                    x_d = x.detach()
                    y = y.detach()
                    x_d.add_(y)
                    return x

                lib.impl("foo", foo_impl, "CPU")
                foo = self.get_op("foo")

                # Incorrect usage of (a!): user doesn't return tensor as-is
                lib.define("bar(Tensor(a!) self, Tensor other) -> Tensor(a!)")

                def bar_impl(x, y):
                    x_d = x.detach()
                    y = y.detach()
                    x_d.add_(y)
                    return x_d.clone()

                lib.impl("bar", bar_impl, "CPU")
                bar = self.get_op("bar")

                # User mutated input tensor but didn't return it.
                lib.define("baz(Tensor(a!) self, Tensor other) -> ()")

                def baz_impl(x, y):
                    x_d = x.detach()
                    y = y.detach()
                    x_d.add_(y)

                lib.impl("baz", baz_impl, "CPU")
                baz = self.get_op("baz")

                # Test in-place on non-view
                for op in (foo, bar, baz):
                    x = torch.randn(3)
                    y = torch.randn(3, requires_grad=True)
                    with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                        z = x.clone()
                        op(z, y)
                        torch.autograd.grad(z, y, torch.ones_like(z), allow_unused=True)

                # Test in-place on view
                for op in (foo, bar, baz):
                    x = torch.randn(3)
                    y = torch.randn(3, requires_grad=True)
                    with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                        z = x[:]
                        op(z, y)
                        torch.autograd.grad(z, x, torch.ones_like(z), allow_unused=True)

    @parametrize("mode", ("nothing", "warn"))
    def test_post_autograd_returns_leaf(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a) -> (Tensor, Tensor)")
            op = self.get_op("foo")

            lib.impl(
                "foo", lambda a: (a.clone(), a.clone().detach().requires_grad_()), "CPU"
            )
            x = torch.randn(3, requires_grad=True)
            y, z = op(x)
            with self._check_ctx(mode):
                z.sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_undefined_inputs_outputs(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                return None, b.clone()

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            # NB: PyTorch dispatcher treats "None" as undefined Tensor.
            y, z = op(None, x)
            with self._check_ctx(mode):
                z.sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_undefined_grads(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                return a.sin(), b.cos()

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            y = torch.randn(3)
            w, z = op(x, y)
            w = torch._C._functions.UndefinedGrad()(w)
            z = torch._C._functions.UndefinedGrad()(z)
            with self._check_ctx(mode):
                (z + w).sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_base_does_not_require_grad(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
            op = self.get_op("foo")

            def foo_impl(a):
                with torch.no_grad():
                    return a.zero_()

            lib.impl("foo", foo_impl, "CPU")
            x = torch.randn(3)
            y = x[:]
            y.requires_grad_()
            w = y[:]
            self.assertTrue(w._base is x)

            # Hook should be registered on w, but not w._base
            op(w)
            with self._check_ctx(mode):
                w.sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_post_autograd_returns_mix_of_requires_grad_tensors(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                with torch.no_grad():
                    x = a.clone()
                    z = b.clone()
                y = a * b
                return x, y, z

            lib.impl("foo", foo_impl, "CPU")
            a = torch.randn(3, requires_grad=True)
            b = torch.randn(3, requires_grad=True)
            x, y, z = op(a, b)

            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    x, (a, b), torch.ones_like(x), allow_unused=True, retain_graph=True
                )

            with self._check_ctx(mode, mode_nothing_raises=False):
                torch.autograd.grad(
                    y, (a, b), torch.ones_like(y), allow_unused=True, retain_graph=True
                )

            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    z, (a, b), torch.ones_like(z), allow_unused=True, retain_graph=True
                )

    @parametrize("mode", ("nothing", "warn"))
    def test_supports_tensor_lists(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor[] a) -> Tensor[]")
            op = self.get_op("foo")

            def foo_impl(a):
                x, y, z = a
                with torch.no_grad():
                    return x + y + z, x * y * z

            lib.impl("foo", foo_impl, "CPU")
            x = torch.randn(3, requires_grad=True)
            y = torch.randn(1, requires_grad=True)
            z = torch.randn(2, 1, requires_grad=True)
            a, b = op([x, y, z])
            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    a,
                    (x, y, z),
                    torch.ones_like(a),
                    allow_unused=True,
                    retain_graph=True,
                )
            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    b,
                    (x, y, z),
                    torch.ones_like(b),
                    allow_unused=True,
                    retain_graph=True,
                )


instantiate_parametrized_tests(TestAutogradFallback)

if __name__ == "__main__":
    run_tests()