# Owner(s): ["module: functorch"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import math
import os
import subprocess
import sys
import unittest
import warnings
from functools import partial, wraps

# NB: numpy is a testing dependency!
import numpy as np
from common_utils import expectedFailureIf

import functorch
import torch
import torch.autograd.forward_ad as fwAD
import torch.nn as nn
import torch.nn.functional as F
from functorch import (
    combine_state_for_ensemble,
    grad,
    grad_and_value,
    hessian,
    jacfwd,
    jacrev,
    jvp,
    make_functional,
    make_functional_with_buffers,
    make_fx,
    vjp,
    vmap,
)
from functorch.experimental import functionalize, replace_all_batch_norm_modules_
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
from torch._dynamo import allow_in_graph
from torch._functorch.eager_transforms import _slice_argnums
from torch._functorch.make_functional import (
    functional_init,
    functional_init_with_buffers,
)
from torch._functorch.utils import enable_single_level_autograd_function
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.func import functional_call, linearize, stack_module_state
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
    SM70OrLater,
    TEST_CUDA,
    tf32_on_and_off,
    with_tf32_off,
)
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
)
from torch.testing._internal.common_dtype import get_all_fp_dtypes
from torch.testing._internal.common_utils import (
    freeze_rng_state,
    instantiate_parametrized_tests,
    IS_FBCODE,
    IS_WINDOWS,
    markDynamoStrictTest,
    parametrize,
    run_tests,
    skipIfRocm,
    skipIfTorchDynamo,
    subtest,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
    xfailIfTorchDynamo,
)
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten


USE_TORCHVISION = False
try:
    import torchvision  # noqa: F401

    USE_TORCHVISION = True
except ImportError:
    warnings.warn(
        "Couldn't import torchvision. Some of our tests use it, try "
        "to install it with commands from pytorch.org, post-fixed with "
        "`--no-deps` to avoid overwriting the pytorch installation",
        UserWarning,
    )

# TestCase for _slice_argnums, an important helper function


class VmapTearDownMixin:
    def tearDown(self):
        # Ensure that in the case of a test failure, the next test won't fail
        # because of a previous call to _vmap_increment_nesting that wasn't undone
        # i.e. test_vmap_free_tensor fails when PYTORCH_TEST_WITH_DYNAMO=1
        # and the call to increment nesting is not undone
        if not TEST_WITH_TORCHDYNAMO:
            return

        warn = False
        while ci := torch._C._functorch.peek_interpreter_stack():
            if ci.key() == torch._C._functorch.TransformType.Vmap:
                warn = True
                torch._C._functorch._vmap_decrement_nesting()
            else:
                break

        if warn:
            msg = (
                "Interpreter stack is not empty. Test should have called "
                "'torch._C._functorch._vmap_decrement_nesting()'"
            )
            warnings.warn(msg)


@markDynamoStrictTest
class TestSliceArgnums(TestCase):
    def test_invalid_argnum_type(self):
        x = torch.randn(3)
        args = (x,)
        with self.assertRaisesRegex(RuntimeError, "int or Tuple"):
            _slice_argnums(args, 0.0)
        with self.assertRaisesRegex(RuntimeError, "int or Tuple"):
            _slice_argnums(args, [0])
        with self.assertRaisesRegex(RuntimeError, "must be int"):
            _slice_argnums(args, (0.0,))

        args = (0.1, 1.1, 2.1, 3.1, 4.1)

        with self.assertRaisesRegex(RuntimeError, "must be int"):
            _slice_argnums(args, ((0, 1), 2))

    def test_out_of_bounds_argnum_values(self):
        x = torch.randn(3)
        args = (x,)
        with self.assertRaisesRegex(RuntimeError, "positional inputs"):
            _slice_argnums(args, 1)
        with self.assertRaisesRegex(RuntimeError, "positional inputs"):
            _slice_argnums(args, -2)
        with self.assertRaisesRegex(RuntimeError, "positional inputs"):
            _slice_argnums(args, (-2,))

    def test_not_enough_argnums(self):
        x = torch.randn(3)
        args = (x,)
        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
            _slice_argnums(args, ())

    def test_duplicate_argnums(self):
        x = torch.randn(3)
        args = (x, x)
        with self.assertRaisesRegex(RuntimeError, "must be unique"):
            _slice_argnums(args, (0, 0))
        with self.assertRaisesRegex(RuntimeError, "must be unique"):
            _slice_argnums(args, (0, -2))

    def test_flat_args_with_positive_int_argnum(self):
        args = (0.1, 1.1, 2.1, 3.1, 4.1)

        res = _slice_argnums(args, 0)
        self.assertEqual(res, (0.1,))

        res = _slice_argnums(args, 4)
        self.assertEqual(res, (4.1,))

    def test_flat_args_with_negative_int_argnum(self):
        args = (0.1, 1.1, 2.1, 3.1, 4.1)

        res = _slice_argnums(args, -1)
        self.assertEqual(res, (4.1,))

        res = _slice_argnums(args, -5)
        self.assertEqual(res, (0.1,))

    def test_flat_args_with_tuple_argnum(self):
        args = (0.1, 1.1, 2.1, 3.1, 4.1)

        res = _slice_argnums(args, (0, 1, 2, 3, 4))
        self.assertEqual(res, args)

        res = _slice_argnums(args, (0, -3))
        self.assertEqual(res, (0.1, 2.1))

    def test_pytree_args(self):
        args = ((0.1, 1.1), 2.0, [3.1])

        res = _slice_argnums(args, 0)
        self.assertEqual(res, args[0:1])

        res = _slice_argnums(args, (0,))
        self.assertEqual(res, args[0:1])

        res = _slice_argnums(args, -1)
        self.assertEqual(res, args[-1:])

        res = _slice_argnums(args, (0, -2))
        self.assertEqual(res, args[0:2])

    def test_argnums_reorders(self):
        args = ((0.1, 1.1, 2.1), 3.1, 4.1)

        res = _slice_argnums(args, (1, 0))
        self.assertEqual(res, (args[1], args[0]))


def _get_weights_and_functional_call(net, mechanism):
    if mechanism == "make_functional":
        return make_functional(net)
    else:
        assert mechanism == "functional_call"
        # this makes it so the function from make_functional and this call have the same signature

        def net_func(weights, data):
            return functional_call(net, weights, (data,))

        return net_func, dict(net.named_parameters())


def _get_weights_and_functional_call_with_buffers(net, mechanism):
    if mechanism == "make_functional":
        return make_functional_with_buffers(net)
    else:
        assert mechanism == "functional_call"

        # this makes it so the function from make_functional and this call have the same signature
        def net_func(weights, buffers, data):
            return functional_call(net, (weights, buffers), (data,))

        return net_func, dict(net.named_parameters()), dict(net.named_buffers())

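# The TestGradTransform cases below exercise grad/vjp on plain functions:
# composition with regular autograd, argnums handling, has_aux, pytree
# inputs/outputs, and interaction with torch.no_grad.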
@markDynamoStrictTest
class TestGradTransform(TestCase):
    def test_primitive(self, device):
        x = torch.randn([], device=device)
        result = grad(torch.sin)(x)
        self.assertEqual(result, torch.cos(x))

    def test_composite_simple(self, device):
        x = torch.randn(2, 3, 4, device=device)
        result = grad(lambda x: torch.flatten(x).sum())(x)
        self.assertEqual(result, torch.ones_like(x))

    def test_fn_with_kwargs(self, device):
        def foo(x, y):
            return (x * y).sum()

        x = torch.randn(3, device=device)
        y = torch.randn(3, device=device)
        expected = grad(foo)(x, y)
        result = grad(foo)(x, y=y)
        self.assertEqual(result, expected)

    def test_composite_complicated(self, device):
        x = torch.randn(3, device=device)
        y = torch.randn(3, 5, device=device)

        def foo(x, y):
            result = x @ y
            return result.sum()

        result = grad(foo)(x, y)

        x.requires_grad_()
        out = foo(x, y)
        (expected,) = torch.autograd.grad(out, x)

        self.assertEqual(result, expected)

    def test_composite_two_ops(self, device):
        N, C = 2, 5
        y = torch.randn(N, C, device=device)
        targets = torch.randint(0, C, (N,), device=device)

        def foo(y, targets):
            return F.cross_entropy(y, targets)

        result = grad(foo)(y, targets)

        y.requires_grad_()
        (expected,) = torch.autograd.grad(foo(y, targets), y)

        self.assertEqual(result, expected)

    def _test_attributes(self, get_attr_lambda, device):
        x = torch.randn(2, 3, 5, dtype=torch.double, device=device)
        expected = get_attr_lambda(x)

        def foo(x):
            self.assertEqual(get_attr_lambda(x), expected)
            return x.sum()

        grad(foo)(x)

    def test_shape(self, device):
        self._test_attributes(lambda x: x.shape, device)

    def test_dtype(self, device):
        self._test_attributes(lambda x: x.dtype, device)

    def test_is_cuda(self, device):
        self._test_attributes(lambda x: x.is_cuda, device)

    def test_numel(self, device):
        self._test_attributes(lambda x: x.numel(), device)

    def test_inplace(self, device):
        x = torch.randn([], device=device)

        def foo(x):
            return x.clone().sin_()

        result = grad(foo)(x)
        self.assertEqual(result, x.cos())

    def test_inplace_on_view(self, device):
        x = torch.randn(3, device=device)

        def foo(x):
            y = x.clone()
            y0 = y[0]
            y0.sin_()
            return y.sum()

        result = grad(foo)(x)

        x.requires_grad_()
        out = foo(x)
        (expected,) = torch.autograd.grad(out, x)

        self.assertEqual(result, expected)

    def test_inplace_on_view_base(self, device):
        x = torch.randn(3, device=device)

        def foo(x):
            y = x.clone()
            y0 = y[0]
            y.sin_()
            return y0

        result = grad(foo)(x)

        x.requires_grad_()
        out = foo(x)
        (expected,) = torch.autograd.grad(out, x)

        self.assertEqual(result, expected)

    def test_inplace_on_captures(self, device):
        x = torch.tensor([1.0, 2.0, 3.0], device=device)
        captured = torch.randn(3, device=device)

        def foo(x):
            captured.copy_(x)
            return (x * captured).sum()

        with self.assertRaisesRegex(RuntimeError, "mutate a captured Tensor"):
            grad(foo)(x)

    def test_nesting_simple(self, device):
        x = torch.randn([], device=device)
        result = grad(grad(torch.sin))(x)
        self.assertEqual(result, -torch.sin(x))

    @skipIfTorchDynamo("Ref: https://github.com/pytorch/pytorch/issues/103613")
    def test_escaped_wrappers_are_marked_as_dead(self, device):
        x = torch.randn([], device=device)
        escaped = []

        def foo(x):
            y = x.sin()
            escaped.append(y)
            return y

        grad(foo)(x)
        self.assertEqual(torch._C._functorch.dlevel(escaped[0]), -1)

    @skipIfTorchDynamo("Ref: https://github.com/pytorch/pytorch/issues/103613")
    def test_escaped_wrappers_are_ignored(self, device):
        x = torch.randn([], device=device)
        escaped = []

        def foo(x):
            y = x.sin()
            escaped.append(y)
            return y

        grad(foo)(x)

        something = escaped[0].sum()
        self.assertEqual(torch._C._functorch.dlevel(something), 0)
        self.assertEqual(something, x.sin().sum())

    def test_manual_seed_inside_grad(self, device):
        x = torch.randn([], device=device)

        def f(x):
            torch.manual_seed(0)
            return x * torch.randn_like(x)

        with freeze_rng_state():
            result = grad(f)(x)
            x.requires_grad_()
            (expected,) = torch.autograd.grad(f(x), x)
            self.assertEqual(result, expected)

    def test_vjp(self, device):
        x = torch.randn([], device=device)
        out, vjp_fn = vjp(torch.sin, x)
        self.assertEqual(out, x.sin())

        v = torch.randn([], device=device)
        (result,) = vjp_fn(v)
        self.assertEqual(result, v * x.cos())

    def test_vjp_two_outputs(self, device):
        def f(x):
            return x, x

        result, vjp_fn = vjp(f, torch.tensor(1.0))
        vjp_fn(result)

    def test_conj_bit(self):
        x = torch.tensor(1 + 1j)

        def foo(x):
            assert not x.is_conj()
            y = x.conj()
            assert y.is_conj()
            return y.abs()

        res = grad(foo)(x)
        with torch.no_grad():
            self.assertEqual(res, torch.ones_like(res) * torch.sgn(x))

    def test_composed_with_autograd(self, device):
        x = torch.randn([], requires_grad=True, device=device)

        y = grad(torch.sin)(x)
        (result,) = torch.autograd.grad(y, x)
        self.assertEqual(result, -x.sin())

    def test_grad_of_vjp_composition(self, device):
        x = torch.randn([], device=device)
        y = torch.randn([], device=device)

        def foo(x, y):
            out, vjp_fn = vjp(torch.sin, x)
            return grad(lambda y: vjp_fn(y)[0])(y)

        result = foo(x, y)
        expected = x.cos()
        self.assertEqual(result, expected)

    def test_vjp_of_grad_composition(self, device):
        x = torch.randn([], device=device)
        y = torch.randn([], device=device)

        def foo(x, y):
            out, vjp_fn = vjp(grad(torch.sin), x)
            return vjp_fn(y)[0]

        result = foo(x, y)
        expected = -y * x.sin()
        self.assertEqual(result, expected)

    def test_grad_of_vjp_of_grad_composition(self, device):
        x = torch.randn([], device=device)
        y = torch.randn([], device=device)

        def foo(x, y):
            df, vjp_fn = vjp(grad(lambda x: -torch.cos(x)), x)
            return grad(lambda y: vjp_fn(y)[0])(y)

        result = foo(x, y)
        expected = x.cos()
        self.assertEqual(result, expected)

    def test_views(self, device):
        x = torch.randn([], requires_grad=True, device=device)
        y = torch.randn([], requires_grad=True, device=device)

        def silly_sin(x):
            x = x.view([])
            x = x.sin()
            return x

        def foo(x, y):
            z1 = grad(silly_sin)(x)
            z2 = torch.cos(y)
            return z1 + z2

        result = foo(x, y)
        grads = torch.autograd.grad(result, [x, y])
        self.assertEqual(grads[0], -x.sin())
        self.assertEqual(grads[1], -y.sin())

    def test_view_inplace_simple(self, device):
        def foo(x):
            x = x.clone()
            x.view([]).sin_()
            return x

        x = torch.randn([], requires_grad=True, device=device)
        result = grad(foo)(x)
        self.assertEqual(result, x.cos())

    def test_invalid_argnums(self, device):
        x = torch.randn([])
        y = torch.randn([])
        with self.assertRaisesRegex(RuntimeError, "but only"):
            grad(torch.mul, argnums=-3)(x, y)
        with self.assertRaisesRegex(RuntimeError, "but only"):
            grad(torch.mul, argnums=2)(x, y)
        with self.assertRaisesRegex(RuntimeError, "int or Tuple"):
            grad(torch.mul, argnums=[0])(x, y)
        with self.assertRaisesRegex(RuntimeError, "must be int"):
            grad(torch.mul, argnums=("0",))(x, y)
        with self.assertRaisesRegex(RuntimeError, "must be unique"):
            grad(torch.mul, argnums=(0, 0))(x, y)
        with self.assertRaisesRegex(RuntimeError, "must be unique"):
            grad(torch.mul, argnums=(0, -2))(x, y)

    def test_argnums(self, device):
        x = torch.randn([])
        y = torch.randn([])
        gx = grad(torch.mul, argnums=0)(x, y)
        self.assertEqual(gx, y)

        gy = grad(torch.mul, argnums=1)(x, y)
        self.assertEqual(gy, x)

        (gx,) = grad(torch.mul, argnums=(0,))(x, y)
        self.assertEqual(gx, y)

        gx, gy = grad(torch.mul, argnums=(0, 1))(x, y)
        self.assertEqual(gx, y)
        self.assertEqual(gy, x)

    def test_out_of_order_argnums(self, device):
        x = torch.randn([])
        y = torch.randn([])
        gy, gx = grad(torch.mul, argnums=(1, 0))(x, y)
        self.assertEqual(gx, y)
        self.assertEqual(gy, x)

    def test_negative_argnums(self, device):
        x = torch.randn([])
        y = torch.randn([])
        gx = grad(torch.mul, argnums=-2)(x, y)
        self.assertEqual(gx, y)

        gy = grad(torch.mul, argnums=-1)(x, y)
        self.assertEqual(gy, x)

        (gx,) = grad(torch.mul, argnums=(-2,))(x, y)
        self.assertEqual(gx, y)

        gx, gy = grad(torch.mul, argnums=(-2, -1))(x, y)
        self.assertEqual(gx, y)
        self.assertEqual(gy, x)

    def test_grad_pytree_inputs(self, device):
        x = torch.randn([], device=device)

        def f(a, b):
            x, y = a
            return 1 * x + 2 * y + 3 * b["foo"]

        args = ((x, x), {"foo": x})

        gx, gy = grad(f)(*args)
        self.assertEqual(gx, torch.tensor(1.0, device=device))
        self.assertEqual(gy, torch.tensor(2.0, device=device))

        ((gx, gy),) = grad(f, argnums=(0,))(*args)
        self.assertEqual(gx, torch.tensor(1.0, device=device))
        self.assertEqual(gy, torch.tensor(2.0, device=device))

        (gx, gy), gz = grad(f, argnums=(0, 1))(*args)
        self.assertEqual(gx, torch.tensor(1.0, device=device))
        self.assertEqual(gy, torch.tensor(2.0, device=device))
        self.assertEqual(gz["foo"], torch.tensor(3.0, device=device))

    def test_grad_aux_tensor(self, device):
        x = torch.randn(3, device=device)

        with self.assertRaisesRegex(
            RuntimeError,
            r"grad_and_value\(f\)\(\*args\): output of function f should be a tuple",
        ):
            grad(lambda t: [t, t], has_aux=True)(x)

        with self.assertRaisesRegex(
            RuntimeError,
            r"grad_and_value\(f\)\(\*args\): output of function f should be a tuple",
        ):
            grad(lambda t: (t, t + 2, t + 3), has_aux=True)(x)

        def f(t):
            y = t.sin()
            return y.sum(), t.cos()

        out, aux = grad(f, has_aux=True)(x)
        self.assertEqual(aux, x.cos())
        self.assertEqual(out, x.cos())

    def test_grad_aux_pytree(self, device):
        def f(x):
            y = x.sin()
            return y.sum(), {"a": x.cos(), "b": [x.tan()]}

        x = torch.randn(3, device=device)

        out, aux = grad(f, has_aux=True)(x)
        _, expected_aux = f(x)
        self.assertEqual(aux, expected_aux)
        self.assertEqual(out, x.cos())

        for aux in [1, 1.0, "abc"]:
            with self.assertRaisesRegex(
                RuntimeError, r"Expected tensors, got unsupported type"
            ):
                _ = grad(lambda x: (x.sum(), aux), has_aux=True)(x)
            with self.assertRaisesRegex(
                RuntimeError, r"Expected tensors, got unsupported type"
            ):
                _ = grad(lambda x: (x.sum(), [x, aux]), has_aux=True)(x)

    def test_zero_grad(self, device):
        def f(x):
            return (x["a"] ** 2.0).sum()

        inps = {
            "a": torch.randn(10, device=device) + 3,
            "b": torch.randn(10, device=device),
        }
        grads = grad(f)(inps)
        self.assertNotEqual(grads["a"].sum(), 0.0)
        self.assertEqual(grads["b"].sum(), 0.0)

    def test_unrelated_grad(self, device):
        x = torch.tensor(1.0, device=device)
        y = torch.tensor(2.0, device=device)

        def unrelated(x):
            return y

        result = grad(unrelated)(x)
        self.assertEqual(result, torch.zeros_like(x))

    def test_unrelated_vjp(self, device):
        x = torch.tensor(1.0, device=device)
        y = torch.tensor(2.0, device=device)
        v = torch.tensor(1.0, device=device)

        def unrelated(x):
            return y

        out, vjp_fn = vjp(unrelated, x)
        result = vjp_fn(v)
        expected = (torch.zeros_like(x),)
        self.assertEqual(result, expected)

    def test_unrelated_vjp_multiple_inputs_outputs(self, device):
        w = torch.tensor(3.0, device=device)
        x = torch.tensor(4.0, device=device)
        y = torch.tensor(2.0, device=device)
        v = torch.tensor(1.0, device=device)

        def unrelated(w, x):
            return y, y, x

        out, vjp_fn = vjp(unrelated, w, x)
        result = vjp_fn((v, v, v))
        expected = (torch.zeros_like(x), torch.ones_like(x))
        self.assertEqual(result, expected)

    # TODO: https://github.com/zou3519/functorch/issues/12
    @onlyCPU
    def test_unrelated_hessian(self, device):
        N = 5
        M = 3
        W = torch.randn(N, M, device=device)

        def f(x):
            return W @ x

        x = torch.randn(M)
        result = jacrev(jacrev(f))(x)
        expected = torch.zeros(N, M, M, device=device)
        self.assertEqual(result, expected)

    def test_vjp_pytree_input(self, device):
        def f(x):
            return x[0] * x[1][0]

        x = torch.randn([], device=device)
        v = torch.randn([], device=device)
        out, vjp_fn = vjp(f, (x, (x, x)))
        self.assertEqual(out, x * x)
        result = vjp_fn(v)
        self.assertEqual(result, ((x * v, (x * v, 0.0)),))

    def test_vjp_pytree_output(self, device):
        def f(x):
            return x, (x, x)

        x = torch.randn([], device=device)
        v1 = torch.randn([], device=device)
        v2 = torch.randn([], device=device)
        v3 = torch.randn([], device=device)
        _, vjp_fn = vjp(f, x)
        (result,) = vjp_fn((v1, (v2, v3)))
        self.assertEqual(result, v1 + v2 + v3)

    def test_vjp_outputs_can_any_pytree(self, device):
        x = torch.randn(2, 3, device=device)
        t = torch.randn(2, 3, device=device)

        for output in [None, ()]:
            with self.assertRaisesRegex(
                RuntimeError,
                r"vjp\(f, \*primals\): Expected f to be a function that has non-empty output",
            ):
                _, vjp_fn = vjp(lambda _: output, x)
                vjp_fn(t)

        for output in [1, True, 12.2, "abc"]:
            with self.assertRaisesRegex(
                RuntimeError,
                r"vjp\(f, \*primals\): expected f\(\*primals\) to return only tensors",
            ):
                _, vjp_fn = vjp(lambda _: output, x)
                vjp_fn(t)

        # Check list output
        output, vjp_fn = vjp(lambda x: [x, x.sum()], x)
        (vjp_out,) = vjp_fn([t, t.sum()])
        assert isinstance(output, list) and len(output) == 2
        assert isinstance(vjp_out, torch.Tensor)

        # Check dict output
        output, vjp_fn = vjp(lambda x: {"x": x, "xsum": x.sum()}, x)
        (vjp_out,) = vjp_fn({"x": t, "xsum": t.sum()})
        assert isinstance(output, dict) and len(output) == 2 and "xsum" in output
        assert isinstance(vjp_out, torch.Tensor)

        def composite_output(x):
            out = x.sum()
            return [
                (out, {"a": x, "out": [x, out]}),
            ]

        output, vjp_fn = vjp(composite_output, x)
        (vjp_out,) = vjp_fn(
            [
                (t.sum(), {"a": t, "out": [t, t.sum()]}),
            ]
        )
        assert isinstance(output, list)
        assert isinstance(output[0], tuple) and isinstance(output[0][1], dict)
        assert isinstance(vjp_out, torch.Tensor)

    def test_vjp_pytree_error(self, device):
        def f(x):
            return x, (x, x)

        x = torch.randn([], device=device)
        v1 = torch.randn([], device=device)
        v2 = torch.randn([], device=device)
        v3 = torch.randn([], device=device)
        _, vjp_fn = vjp(f, x)
        with self.assertRaisesRegex(RuntimeError, "Expected pytree structure"):
            (result,) = vjp_fn(((v1, (v2, v3)),))

    def test_vjp_aux_tensor(self, device):
        x = torch.randn(3, device=device)

        with self.assertRaisesRegex(
            RuntimeError, r"vjp\(f, \*primals\): output of function f should be a tuple"
        ):
            vjp(lambda t: [t, t], x, has_aux=True)

        with self.assertRaisesRegex(
            RuntimeError, r"vjp\(f, \*primals\): output of function f should be a tuple"
        ):
            vjp(lambda t: (t, t + 2, t + 3), x, has_aux=True)

        def f(t):
            y = t.sin()
            return y, t.cos()

        out, vjp_fn, aux = vjp(f, x, has_aux=True)
        self.assertEqual(aux, x.cos())
        self.assertEqual(out, x.sin())

        v = torch.randn(3, device=device)
        (grad_x,) = vjp_fn(v)
        self.assertEqual(grad_x, v * x.cos())

    def test_vjp_aux_pytree(self, device):
        def f(x):
            y = x.sin()
            return y, {"a": x.cos(), "b": [x.tan()]}

        x = torch.randn(3, device=device)

        out, vjp_fn, aux = vjp(f, x, has_aux=True)
        expected_out, expected_aux = f(x)
        self.assertEqual(out, expected_out)
        self.assertEqual(aux, expected_aux)

        v = torch.randn(3, device=device)
        (grad_x,) = vjp_fn(v)
        self.assertEqual(grad_x, v * x.cos())

        for aux in [1, 1.0, "abc"]:
            with self.assertRaisesRegex(
                RuntimeError, r"Expected tensors, got unsupported type"
            ):
                _ = vjp(lambda x: (x, aux), x, has_aux=True)
            with self.assertRaisesRegex(
                RuntimeError, r"Expected tensors, got unsupported type"
            ):
                _ = vjp(lambda x: (x, [x, aux]), x, has_aux=True)

    def test_functional_init(self, device):
        class MLPClassifier(nn.Module):
            def __init__(self, hidden_dim=32, n_classes=2):
                super().__init__()
                self.hidden_dim = hidden_dim
                self.n_classes = n_classes

                self.fc1 = nn.Linear(2, self.hidden_dim)
                self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

            def forward(self, x):
                x = self.fc1(x)
                x = F.relu(x)
                x = self.fc2(x)
                x = F.log_softmax(x, -1)
                return x

        B = 10
        weights, fn, _ = functional_init(MLPClassifier, (B,), device=device)(32, 2)
        inputs = torch.randn(B, 7, 2, device=device)
        vmap(fn)(weights, (inputs,))

    def test_functional_init_with_buffers(self, device):
        class MLPClassifier(nn.Module):
            def __init__(self, hidden_dim=32, n_classes=2):
                super().__init__()
                self.hidden_dim = hidden_dim
                self.n_classes = n_classes

                self.fc1 = nn.Linear(2, self.hidden_dim)
                self.bn = nn.BatchNorm1d(self.hidden_dim, affine=True)
                self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

            def forward(self, x):
                x = self.fc1(x)
                x = F.relu(x)
                x = self.bn(x)
                x = self.fc2(x)
                x = F.log_softmax(x, -1)
                return x

        B = 10
        weights, buffers, fn, _, _ = functional_init_with_buffers(
            MLPClassifier, [B], device=device
        )(32, 2)
        inputs = torch.randn(B, 7, 2, device=device)
        vmap(fn)(weights, buffers, (inputs,))

    def test_advanced_indexing(self, device):
        def f(value):
            log_prob = torch.ones((), device=device)
            val = torch.zeros(()) > 0
            log_prob[val] = 0
            return value

        result = grad(f)(torch.randn((), device=device))
        self.assertEqual(result, torch.ones_like(result))

        def f2(value):
            value = value.clone()
            value[value > 0] = 0
            return value.sum()

        x = torch.randn(100, device=device)
        result = grad(f2)(x)
        self.assertEqual(result, (x <= 0).type_as(x))

    def test_tensor_ctor_inside_grad(self, device):
        def foo(x):
            return x * torch.tensor(2.0, device=device)

        x = torch.tensor(3.14, device=device)
        functorch.grad(foo)(x)

    @parametrize(
        "op_list_data",
        [
            subtest(
                (
                    [
                        vmap,
                    ],
                    [(4, 2), (64, 3, 32, 32)],
                ),
                name="vmap",
            ),
            subtest(([vmap, vmap], [(4, 3, 2), (64, 3, 32, 32)]), name="vmap_vmap"),
            subtest(
                (
                    [
                        grad,
                    ],
                    [(0,), [], (4, 2), (64, 3, 32, 32)],
                ),
                name="grad",
            ),
            subtest(
                (
                    [grad, grad],
                    [
                        [],
                    ],
                ),
                name="grad_grad",
            ),
            subtest(([vmap, grad], [(4, 2)]), name="vmap_grad"),
        ],
    )
    def test_tensor_print(self, device, op_list_data):
        op_list, shapes = op_list_data

        for dt in get_all_fp_dtypes():
            data = [torch.randn(s, dtype=dt, device=device) for s in shapes]

            for x in data:
                buf = None

                def foo(t):
                    nonlocal buf
                    buf = repr(t)
                    return t.mean()

                fn = foo
                bdim = 0
                for op in reversed(op_list):
                    if op == vmap:
                        fn = op(fn, in_dims=bdim)
                        bdim += 1
                    else:
                        fn = op(fn)

                expected = f"{repr(x)}"
                level = 0
                for op in op_list:
                    level += 1
                    if op == grad:
                        expected = f"GradTrackingTensor(lvl={level}, value={expected})"
                    elif op == vmap:
                        bdim -= 1
                        expected = (
                            f"BatchedTensor(lvl={level}, bdim={bdim}, value={expected})"
                        )

                fn(x)
                buf = buf.replace("\n", "").replace(" ", "")
                expected = expected.replace("\n", "").replace(" ", "")
                self.assertEqual(expected, buf)

    def test_print_captured_tensor_inside_transform(self, device):
        x = torch.tensor([1.0, 2.0, 3.0], device=device)
        out = None

        def f(y):
            nonlocal out
            out = repr(x)
            return y

        vjp(f, torch.randn(4, device=device))
        self.assertEqual(out, repr(x))

    def test_no_grad_outside(self, device):
        x = torch.randn([], device=device, requires_grad=True)
        with torch.no_grad():
            y = grad(torch.sin)(x)
        self.assertEqual(y, x.cos())
        self.assertFalse(y.requires_grad)

    def test_no_grad_inside(self, device):
        def f(x):
            with torch.no_grad():
                shift = x**2
            return x**2 - shift

        x = torch.randn([], device=device)
        y = grad(f)(x)
        self.assertEqual(y, 2 * x)
        y = grad(grad(f))(x)
        self.assertEqual(y, 2)

        x = torch.randn([], device=device, requires_grad=True)
        y = grad(f)(x)
        (z,) = torch.autograd.grad(y, x)
        self.assertEqual(z, 2)

    def test_no_grad_mixed(self, device):
        def f(x):
            with torch.no_grad():
                shift = x**2
            return x**2 - shift

        x = torch.randn([], device=device, requires_grad=True)
        with torch.no_grad():
            y = grad(f)(x)

        self.assertEqual(y, 2 * x)
        self.assertFalse(y.requires_grad)

    def test_no_grad_nested_simple(self, device):
        def h(x):
            with torch.no_grad():
                shift = grad(lambda x: 0.25 * x**4)(x)
            return x**3 - shift

        x = torch.tensor(1.5, device=device, requires_grad=True)
        y = grad(h)(x)
        self.assertEqual(y, 3 * x**2)

        (z,) = torch.autograd.grad(y, x)
        self.assertEqual(z, 6 * x)

    def test_no_grad_nested_complicated(self, device):
        def f(x):
            with torch.no_grad():
                shift = x**3
            return x**3 - shift

        def g(x):
            r1 = grad(f)(x)
            with torch.no_grad():
                shift = grad(f)(x)
            return r1 - shift

        x = torch.randn([], requires_grad=True, device=device)
        y = grad(g)(x)
        # The only differentiable part of g is x ** 3
        self.assertEqual(y, 6 * x)

        (z,) = torch.autograd.grad(y, x)
        self.assertEqual(z, 6)

    def test_no_grad_value(self, device):
        def h(x):
            with torch.no_grad():
                gvalue, value = grad_and_value(lambda x: x**3)(x)
            return x**3 - value

        x = torch.tensor(1.6, device=device, requires_grad=True)
        y = grad(h)(x)
        self.assertEqual(y, 3 * x**2)

        (z,) = torch.autograd.grad(y, x)
        self.assertEqual(z, 6 * x)

    def test_no_grad_outside_vjp(self, device):
        def h(x):
            return x**2

        x = torch.tensor(2.0, requires_grad=True, device=device)
        with torch.no_grad():
            out, vjp_fn = vjp(h, x)
            (y,) = vjp_fn(torch.tensor(1.0, device=device))

        self.assertEqual(y, 2 * x)
        self.assertFalse(y.requires_grad)
        self.assertFalse(out.requires_grad)

    def test_no_grad_outside_vjp_fn(self, device):
        def h(x):
            return x**2

        x = torch.tensor(3.14, requires_grad=True, device=device)
        out, vjp_fn = vjp(h, x)
        with torch.no_grad():
            (y,) = vjp_fn(torch.tensor(1.0, device=device))

        self.assertEqual(y, 2 * x)
        self.assertFalse(y.requires_grad)
        self.assertTrue(out.requires_grad)

        (z,) = torch.autograd.grad(out, x)
        self.assertEqual(z, 2 * x)

    def test_no_grad_outside_vjp_only(self, device):
        def h(x):
            return x**2

        x = torch.tensor(3.14, requires_grad=True, device=device)
        with torch.no_grad():
            out, vjp_fn = vjp(h, x)
        (y,) = vjp_fn(torch.tensor(1.0, device=device))

        self.assertEqual(y, 2 * x)
        self.assertFalse(out.requires_grad)

        # This one is a little weird...
        self.assertTrue(y.requires_grad)

        (z,) = torch.autograd.grad(y, x)
        self.assertEqual(z, 2)


@markDynamoStrictTest
class TestAutogradFunction(TestCase):
    def test_set_materialize_grads(self, device):
        class A(torch.autograd.Function):
            @staticmethod
            def forward(x, y):
                return x, y

            @staticmethod
            def setup_context(ctx, inputs, output):
                ctx.set_materialize_grads(False)

            @staticmethod
            def backward(ctx, gx, gy):
                self.assertIsNotNone(gx)
                self.assertIsNone(gy)
                return gx, gy

        def f(y, x):
            x, y = A.apply(x, y)
            return x**2

        x = torch.tensor(2.0, device=device)
        y = torch.tensor(3.0, device=device)
        # grad differentiates w.r.t. arg 0 by default
        grad(f)(y, x)
        grad(grad(f))(y, x)

    @parametrize("inner_requires_grad", [True, False])
    @parametrize("save_for", ["jvp", "vjp"])
    @parametrize("save_tensors", ["input", "output", "neither"])
    @parametrize("mark_dirty", [True, False])
    def test_function_returns_input(
        self, device, inner_requires_grad, save_for, save_tensors, mark_dirty
    ):
        class A(torch.autograd.Function):
            @staticmethod
            def forward(x):
                return x

            @staticmethod
            def setup_context(ctx, inputs, output):
                if save_for == "jvp":
                    save_fn = ctx.save_for_forward
                else:
                    save_fn = ctx.save_for_backward

                if mark_dirty:
                    ctx.mark_dirty(inputs[0])

                if save_tensors == "input":
                    save_fn(inputs[0])
                elif save_tensors == "output":
                    save_fn(output)
                elif save_tensors == "neither":
                    pass

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output

            @staticmethod
            def jvp(ctx, x_t):
                # NB: the logic to check ctx.save_for_forward happens
                # before we reach this!
                if mark_dirty:
                    ret = x_t.add_(0)
                else:
                    ret = x_t.view_as(x_t)
                return ret

        def fn(x):
            return A.apply(x.clone())

        err_msg = "A input that has been returned as-is"

        a = torch.tensor(2.0, device=device, requires_grad=inner_requires_grad)
        a_t = torch.tensor(2.0, device=device, requires_grad=inner_requires_grad)
        if save_tensors in ("input", "output") and not mark_dirty:
            with self.assertRaisesRegex(RuntimeError, err_msg):
                grad(fn)(a)
            with self.assertRaisesRegex(RuntimeError, err_msg):
                jvp(fn, (a,), (a_t,))
        else:
            grad(fn)(a)
            jvp(fn, (a,), (a_t,))

        a = torch.tensor(2.0, device=device, requires_grad=inner_requires_grad).clone()
        a_t = torch.tensor(
            2.0, device=device, requires_grad=inner_requires_grad
        ).clone()

        if save_tensors in ("input", "output") and not mark_dirty:
            with self.assertRaisesRegex(RuntimeError, err_msg):
                A.apply(a)
            with self.assertRaisesRegex(RuntimeError, err_msg):
                with fwAD.dual_level():
                    A.apply(fwAD.make_dual(a, a_t))
        else:
            b = A.apply(a)
            if mark_dirty:
                self.assertTrue(a is b)
            if not (
                mark_dirty and save_for == "vjp" and save_tensors in ("input", "output")
            ):
                # TODO(soulitzer): https://github.com/pytorch/pytorch/issues/97827
                with fwAD.dual_level():
                    a_dual = fwAD.make_dual(a, a_t)
                    b_dual = A.apply(a_dual)
                if mark_dirty:
                    self.assertTrue(a_dual is b_dual)

    def test_needs_input_grads(self, device):
        class A(torch.autograd.Function):
            @staticmethod
            def forward(x, y):
                return x * y

            @staticmethod
            def setup_context(ctx, inputs, output):
                return

            @staticmethod
            def backward(ctx, grad_output):
                self.assertTrue(ctx.needs_input_grad[0])
                self.assertFalse(ctx.needs_input_grad[1])
                return None, None

        x = torch.tensor(2.0, device=device)
        y = torch.tensor(3.0, device=device)
        # grad differentiates w.r.t. arg 0 by default
        grad(A.apply)(x, y)
        grad(grad(A.apply))(x, y)

    def _get_NumpyCubeNotComposable(self):
        class NumpyCubeNotComposable(torch.autograd.Function):
            @staticmethod
            def forward(input):
                input_np = input.cpu().numpy()
                return torch.tensor(input_np**3, device=input.device), input_np

            @staticmethod
            def setup_context(ctx, inputs, output):
                ctx.input_np = output[1]
                ctx.device = inputs[0].device

            @staticmethod
            @torch.autograd.function.once_differentiable
            def backward(ctx, grad_output, grad_saved):
                result_np = 3 * (ctx.input_np**2)
                return torch.tensor(result_np, device=ctx.device)

        return NumpyCubeNotComposable

    def test_once_differentiable_autograd_vjp(self, device):
        NumpyCubeNotComposable = self._get_NumpyCubeNotComposable()

        def f(x):
            y, _ = NumpyCubeNotComposable.apply(x)
            return y

        # regular autograd x vjp
        x = torch.randn([], requires_grad=True, device=device)
        grad_y = torch.randn_like(x, requires_grad=True)
        _, vjp_fn = vjp(f, x)
        (gx,) = vjp_fn(grad_y)

        with self.assertRaisesRegex(RuntimeError, "marked with @once_differentiable"):
            gx.backward()

    # TODO: support torch.autograd.function.once_differentiable
    # (or, if impossible, figure out how to raise a nice error)
    # https://github.com/pytorch/pytorch/issues/90224
    @unittest.expectedFailure
    def test_once_differentiable_grad_vjp(self, device):
        NumpyCubeNotComposable = self._get_NumpyCubeNotComposable()

        # grad x vjp
        x = torch.randn([], device=device)
        grad_y = torch.randn_like(x)

        def h(x, grad_y):
            _, vjp_fn = vjp(f, x)  # noqa: F821
            (gx,) = vjp_fn(grad_y)
            return gx

        grad(h, argnums=(0, 1))(x, grad_y)

    def test_grad_fn_name(self, device):
        names = []

        class FooBar(torch.autograd.Function):
            @staticmethod
            def forward(x):
                return x.clone()

            @staticmethod
            def setup_context(ctx, inputs, output):
                return

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output

        def f(x):
            y = FooBar.apply(x)
            names.append(type(y.grad_fn).__name__)
            return y

        x = torch.tensor(1.0)
        grad(f)(x)
        self.assertEqual(names, ["FooBarGeneratedBackward"])

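# TestAutogradFunctionVmapAPI: coverage for the vmap staticmethod and the
# generate_vmap_rule hook on custom autograd.Function subclasses.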
@markDynamoStrictTest
class TestAutogradFunctionVmapAPI(TestCase):
    def test_no_vmap_staticmethod_and_no_generate_vmap_rule(self, device):
        class NumpyCube(torch.autograd.Function):
            @staticmethod
            def forward(input):
                input_np = to_numpy(input)  # noqa: F821
                dinput = torch.tensor(3 * input_np**2, device=input.device)
                return torch.tensor(input_np**3, device=input.device), dinput

            @staticmethod
            def setup_context(ctx, inputs, output):
                ctx.save_for_backward(inputs, output[1])

            @staticmethod
            def backward(ctx, grad_output, grad_saved):
                raise RuntimeError("foobar")

        x = torch.randn(3, device=device)
        with self.assertRaisesRegex(RuntimeError, "does not have vmap support"):
            vmap(NumpyCube.apply)(x)

    def test_has_vmap_staticmethod_and_has_generate_vmap_rule(self, device):
        class NumpyCube(torch.autograd.Function):
            generate_vmap_rule = True

            @staticmethod
            def forward(input):
                input_np = to_numpy(input)  # noqa: F821
                dinput = torch.tensor(3 * input_np**2, device=input.device)
                return torch.tensor(input_np**3, device=input.device), dinput

            @staticmethod
            def setup_context(ctx, outputs, input):
                ctx.save_for_backward(input, outputs[1])

            @staticmethod
            def backward(ctx, grad_output, grad_saved):
                raise RuntimeError("foobar")

            @staticmethod
            def vmap(infos, in_dims, x):
                raise RuntimeError("foobar")

        x = torch.randn(3, device=device)
        with self.assertRaisesRegex(RuntimeError, "generate_vmap_rule=True and"):
            vmap(NumpyCube.apply)(x)

    def test_info_object(self, device):
        batch_size = 10

        class Id(torch.autograd.Function):
            @staticmethod
            def forward(input):
                pass

            @staticmethod
            def setup_context(ctx, inputs, output):
                pass

            @staticmethod
            def backward(ctx, grad_output, grad_saved):
                pass

            @staticmethod
            def vmap(info, in_dims, input):
                self.assertEqual(info.batch_size, batch_size)
                self.assertEqual(info.randomness, randomness)
                return input, in_dims[0]

        x = torch.randn(batch_size, 3, device=device)

        for randomness in ("error", "different", "same"):
            vmap(Id.apply, randomness=randomness)(x)

    def test_in_dims_single_input(self, device):
        class Id(torch.autograd.Function):
            @staticmethod
            def forward(input):
                pass

            @staticmethod
            def setup_context(ctx, inputs, output):
                pass

            @staticmethod
            def backward(ctx, grad_output, grad_saved):
                pass

            @staticmethod
            def vmap(info, in_dims, input):
                self.assertEqual(in_dims, (1,))
                return input, in_dims[0]

        B = 10
        x = torch.randn(3, B, device=device)
        vmap(Id.apply, in_dims=1)(x)
        vmap(Id.apply, in_dims=(1,))(x)

    def test_in_dims_multiple_inputs(self, device):
        class Id(torch.autograd.Function):
            @staticmethod
            def forward(x, y):
                pass

            @staticmethod
            def setup_context(ctx, inputs, output):
                pass

            @staticmethod
            def backward(ctx, grad_output, grad_saved):
                pass

            @staticmethod
            def vmap(info, in_dims, x, y):
                self.assertEqual(in_dims, (0, [0, 0]))
                self.assertTrue(isinstance(in_dims, tuple))
                self.assertTrue(isinstance(in_dims[1], list))
                return (x, y), in_dims

        x = torch.randn(2, device=device)
        vmap(Id.apply)(x, [x, x])

    def test_skips_empty_layer(self, device):
        class Id(torch.autograd.Function):
            @staticmethod
            def forward(input):
                return input

            @staticmethod
            def setup_context(ctx, inputs, output):
                pass

            @staticmethod
            def backward(ctx, grad_output, grad_saved):
                pass

            @staticmethod
            def vmap(info, in_dims, input):
                raise RuntimeError("expected to not be called")

        def f(x):
            y = torch.tensor(1.0)
            y = Id.apply(y)
            return x * 1

        x = torch.randn(2, 3)
        vmap(f)(x)

    def test_none_returns(self, device):
        class Zeros(torch.autograd.Function):
            @staticmethod
            def forward(input):
                return torch.zeros(input.shape, device=input.device)

            @staticmethod
            def setup_context(ctx, inputs, output):
                pass

            @staticmethod
            def vmap(info, in_dims, input):
                assert in_dims == (0,)
                return torch.zeros(input.shape[1:], device=input.device), None

        B = 2
        x = torch.randn(B, 3)
        y = vmap(Zeros.apply)(x)
        self.assertEqual(y, torch.zeros_like(x))

        class TwoZeros(torch.autograd.Function):
            @staticmethod
            def forward(input):
                r = torch.zeros(input.shape, device=input.device)
                return r, r

            @staticmethod
            def setup_context(ctx, inputs, output):
                pass

            @staticmethod
            def vmap(info, in_dims, input):
                assert in_dims == (0,)
                r = torch.zeros(input.shape[1:], device=input.device)
                return (r, r), None

        B = 2
        x = torch.randn(B, 3)
        result = vmap(TwoZeros.apply)(x)

        self.assertTrue(isinstance(result, tuple))
        y, z = result
        self.assertEqual(y, torch.zeros_like(x))
        self.assertEqual(z, torch.zeros_like(x))

    def test_should_have_two_returns(self, device):
        class Zeros(torch.autograd.Function):
            @staticmethod
            def forward(input):
                r = torch.zeros(input.shape, device=input.device)
                return r

            @staticmethod
            def setup_context(ctx, inputs, output):
                pass

            @staticmethod
            def vmap(info, in_dims, input):
                r = torch.zeros(input.shape[1:], device=input.device)
                return r

        B = 2
        x = torch.randn(B, 3)
        with self.assertRaisesRegex(RuntimeError, "to have two returns"):
            result = vmap(Zeros.apply)(x)

        class TwoZeros(torch.autograd.Function):
            @staticmethod
            def forward(input):
                r = torch.zeros(input.shape, device=input.device)
                return r, r

            @staticmethod
            def setup_context(ctx, inputs, output):
                pass

            @staticmethod
            def vmap(info, in_dims, input):
                r = torch.zeros(input.shape[1:], device=input.device)
                return r, r, 0, 0

        B = 2
        x = torch.randn(B, 3)
        with self.assertRaisesRegex(RuntimeError, "to have two returns"):
            result = vmap(Zeros.apply)(x)

    def test_incompatible_out_dims_error_msg(self, device):
        class Zeros(torch.autograd.Function):
            @staticmethod
            def forward(input):
                r = torch.zeros(input.shape, device=input.device)
                return r

            @staticmethod
            def setup_context(ctx, inputs, output):
                pass

            @staticmethod
            def vmap(info, in_dims, input):
                r = torch.zeros(input.shape[1:], device=input.device)
                return r, (None,)

        B = 2
        x = torch.randn(B, 3)
        with self.assertRaisesRegex(RuntimeError, "returned an incompatible"):
            result = vmap(Zeros.apply)(x)

        class Zeros(torch.autograd.Function):
            @staticmethod
            def forward(input):
                r = torch.zeros(input.shape, device=input.device)
                return [r]

            @staticmethod
            def setup_context(ctx, inputs, output):
                pass

            @staticmethod
            def vmap(info, in_dims, input):
                r = torch.zeros(input.shape[1:], device=input.device)
                return [r], (None,)

        B = 2
        x = torch.randn(B, 3)
        with self.assertRaisesRegex(RuntimeError, "returned an incompatible"):
            result = vmap(Zeros.apply)(x)

    def test_kwarg_only_tensors(self, device):
        with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):

            class MyClass(torch.autograd.Function):
                @staticmethod
                def forward(x, *, y):
                    return x + y

                @staticmethod
                def setup_context(ctx, inputs, output):
                    pass

                @staticmethod
                def vmap(info, in_dims, x, *, y):
                    assert in_dims == (0,)
                    return x + y, 0

            x = torch.randn(3)
            y = torch.randn(3)

            vmap(MyClass.apply)(x, y=y)

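# TestVmapOfGrad: per-sample-gradient style tests that compose vmap over grad.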
@markDynamoStrictTest
class TestVmapOfGrad(TestCase):
    def test_per_sample_grads_inplace_view(self, device):
        def compute_loss(weight, x, t):
            x = x.mm(weight)
            y = x.squeeze_(0)
            return (y - t).sum()

        weight = torch.randn(16, 2, device=device)
        x = torch.randn(64, 1, 16, device=device)
        t = torch.randn(64, 2, device=device)
        result = vmap(partial(grad(compute_loss), weight))(x, t)
        expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
        expected = torch.stack(expected)
        # TODO: Check if the rtol is a problem
        self.assertEqual(result, expected, atol=0, rtol=5e-4)

    def test_new_zeros_materializes_tensor(self, device):
        N = 3
        C = 5

        def foo(y, x):
            result = x.new_zeros((C,))
            result.copy_(y)
            return result.sum()

        x = torch.randn(N, device=device)
        y = torch.randn(N, C, device=device)
        result = vmap(grad(foo))(y, x)
        self.assertEqual(result, torch.ones_like(y))

    def test_new_empty_materializes_tensor(self, device):
        N = 3
        C = 5

        def foo(y, x):
            result = x.new_empty((C,))
            result.copy_(y)
            return result.sum()

        x = torch.randn(N, device=device)
        y = torch.randn(N, C, device=device)
        result = vmap(grad(foo))(y, x)
        self.assertEqual(result, torch.ones_like(y))

    def test_per_sample_grads_simple(self, device):
        def compute_loss(weight, x, t):
            y = x @ weight
            return ((y - t) ** 2).sum()

        weight = torch.randn(16, 2, device=device)
        x = torch.randn(64, 16, device=device)
        t = torch.randn(64, 2, device=device)
        result = vmap(partial(grad(compute_loss), weight))(x, t)
        expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
        expected = torch.stack(expected)
        # TODO: Check if the rtol is a problem
        self.assertEqual(result, expected, atol=0, rtol=5e-4)

    def _compare_expected_and_result(self, expected, result, mechanism):
        if mechanism == "make_functional":
            expected = zip(*expected)
            expected = tuple(torch.stack(shards) for shards in expected)
            for r, e in zip(result, expected):
                self.assertEqual(r, e, atol=0, rtol=1.5e-3)
        else:
            assert mechanism == "functional_call"
            expected = {
                k: tuple(d[k] for d in expected) for k, v in expected[0].items()
            }
            expected = {k: torch.stack(shards) for k, shards in expected.items()}
            for key in result:
                self.assertEqual(result[key], expected[key], atol=0, rtol=1.5e-3)

    @tf32_on_and_off(0.005)
    @parametrize("mechanism", ["make_functional", "functional_call"])
    def test_per_sample_grads_embeddingnet(self, device, mechanism):
        class SampleNet(nn.Module):
            def __init__(self, vocab_size: int):
                super().__init__()
                self.emb = nn.Embedding(vocab_size, 16)
                self.fc1 = nn.Linear(16, 16)
                self.fc2 = nn.Linear(16, 2)

            def forward(self, x):
                x = self.emb(x)
                x = torch.transpose(x, -1, -2)
                x = torch.mean(x, -1)
                x = self.fc1(x)
                x = F.relu(x)
                x = self.fc2(x)
                return x

            def name(self):
                return "SampleNet"

        # Create our inputs...
        vocab_size = 1000
        batch_shape = [64]
        words_per_sentence = 5
        data = torch.randint(
            0, vocab_size, (*batch_shape, words_per_sentence), device=device
        )
        targets = torch.randint(0, 1, (*batch_shape,), device=device)

        # Construct our module
        net = SampleNet(vocab_size).to(device=device)
        criterion = nn.CrossEntropyLoss()

        net_func, weights = _get_weights_and_functional_call(net, mechanism)

        def compute_loss(weights, data, target):
            output = net_func(weights, data)
            result = criterion(output, target)
            return result

        expected = [grad(compute_loss)(weights, data[i], targets[i]) for i in range(64)]
        result = vmap(partial(grad(compute_loss), weights))(data, targets)
        self._compare_expected_and_result(expected, result, mechanism)

    def test_log_softmax(self, device):
        x = torch.randn(3, 5, device=device)
        v = torch.randn(5, device=device)

        def foo(x, v):
            _, vjp_fn = vjp(partial(torch.log_softmax, dim=-1), x)
            return vjp_fn(v)[0]

        result = vmap(foo, (0, None))(x, v)

        v = v.expand_as(x)
        x.requires_grad_()
        output = torch.log_softmax(x, dim=-1)
        output.backward(v)
        self.assertEqual(result, x.grad)

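# Each TestJac case below is parametrized to run under both jacrev and jacfwd
# via the helpers defined next.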
jacrev_and_jacfwd = parametrize(
    "jacapi", [subtest(jacrev, name="jacrev"), subtest(jacfwd, name="jacfwd")]
)

FIXME_jacrev_only = parametrize("jacapi", [subtest(jacrev, name="jacrev")])


@markDynamoStrictTest
class TestJac(VmapTearDownMixin, TestCase):
    @jacrev_and_jacfwd
    def test_simple(self, device, jacapi):
        x = torch.randn(3, device=device)
        y = jacapi(torch.sin)(x)
        expected = torch.diagflat(x.cos())
        assert torch.allclose(y, expected)

    @jacrev_and_jacfwd
    def test_simple_not_flat(self, device, jacapi):
        x = torch.randn(2, 3, device=device)
        y = jacapi(torch.sin)(x)
        expected = torch.diagflat(x.view(-1).cos())
        expected = expected.view(2, 3, 2, 3)
        assert torch.allclose(y, expected)

    @jacrev_and_jacfwd
    def test_take(self, device, jacapi):
        x = torch.rand(5)

        def func(x):
            y = torch.ones(3, dtype=torch.long)
            z = torch.take(x, y)
            return z

        self.assertEqual(jacrev(func)(x), torch.autograd.functional.jacobian(func, x))

    @jacrev_and_jacfwd
    def test_diff_numel(self, device, jacapi):
        x = torch.randn(2, 4, device=device)

        # Tensor[2, 4] -> Tensor[3, 1]
        def f(x):
            return x[0, 1:].unsqueeze(-1)

        y = jacapi(f)(x)
        self.assertEqual(y.shape, (3, 1, 2, 4))

        expected = x.new_zeros(3, 1, 2, 4)
        expected[0, 0, 0, 1] = 1
        expected[1, 0, 0, 2] = 1
        expected[2, 0, 0, 3] = 1
        self.assertEqual(y, expected)

    @jacrev_and_jacfwd
    def test_vmap_on_jac_simple(self, device, jacapi):
        x = torch.randn(2, 3, device=device)
        y = vmap(jacapi(torch.sin))(x)
        expected = torch.stack([torch.diagflat(x[i].cos()) for i in range(2)])
        assert torch.allclose(y, expected)

    @jacrev_and_jacfwd
    def test_nested_jac_simple(self, device, jacapi):
        def foo(x):
            return x.sin().sum()

        x = torch.randn(3, device=device)
        y = jacapi(jacapi(foo))(x)
        expected = torch.diagflat(-x.sin())
        assert torch.allclose(y, expected)

    @jacrev_and_jacfwd
    def test_multiple_args(self, device, jacapi):
        x = torch.randn(3, device=device)
        y = torch.randn(3, device=device)
        z = jacapi(torch.multiply, argnums=1)(x, y)
        expected = torch.diagflat(x)
        assert torch.allclose(z, expected)

    @jacrev_and_jacfwd
    def test_multiple_outputs_multiple_argnums(self, device, jacapi):
        def f(x, y):
            return 2 * x + 3 * y, 4 * x + 5 * y

        x = torch.randn(3, device=device)
        y = torch.randn(3, device=device)
        z = jacapi(f, argnums=(0, 1))(x, y)
        expected_out0_x = torch.diagflat(torch.full_like(x, 2))
        expected_out0_y = torch.diagflat(torch.full_like(y, 3))
        expected_out1_x = torch.diagflat(torch.full_like(x, 4))
        expected_out1_y = torch.diagflat(torch.full_like(y, 5))

        self.assertEqual(len(z), 2)
        self.assertTrue(isinstance(z, tuple))
        self.assertEqual(len(z[0]), 2)
        self.assertTrue(isinstance(z[0], tuple))
        self.assertEqual(z[0][0], expected_out0_x)
        self.assertEqual(z[0][1], expected_out0_y)
        self.assertEqual(z[1][0], expected_out1_x)
        self.assertEqual(z[1][1], expected_out1_y)

    @jacrev_and_jacfwd
    def test_multiple_outputs_single_argnums(self, device, jacapi):
        def f(x, y):
            return 2 * x + 3 * y, 4 * x + 5 * y

        x = torch.randn(3, device=device)
        y = torch.randn(3, device=device)
        expected_out0_x = torch.diagflat(torch.full_like(x, 2))
        expected_out1_x = torch.diagflat(torch.full_like(x, 4))

        z = jacapi(f, argnums=0)(x, y)
        self.assertEqual(len(z), 2)
        self.assertTrue(isinstance(z, tuple))
        self.assertEqual(z, (expected_out0_x, expected_out1_x))

        z = jacapi(f, argnums=(0,))(x, y)
        self.assertEqual(len(z), 2)
        self.assertTrue(isinstance(z, tuple))
        self.assertTrue(isinstance(z[0], tuple))
        self.assertEqual(z, ((expected_out0_x,), (expected_out1_x,)))

    @jacrev_and_jacfwd
    def test_multiple_outputs_pytree(self, device, jacapi):
        def f(x, y):
            return {"left": 2 * x + 3 * y, "right": 4 * x + 5 * y}

        x = torch.randn(3, device=device)
        y = torch.randn(3, device=device)
        z = jacapi(f, argnums=(0, 1))(x, y)
        expected_left_x = torch.diagflat(torch.full_like(x, 2))
        expected_left_y = torch.diagflat(torch.full_like(y, 3))
        expected_right_x = torch.diagflat(torch.full_like(x, 4))
        expected_right_y = torch.diagflat(torch.full_like(y, 5))
        expected = {
            "left": (expected_left_x, expected_left_y),
            "right": (expected_right_x, expected_right_y),
        }
        self.assertTrue(isinstance(z, dict))
        self.assertTrue(isinstance(z["left"], tuple))
        self.assertTrue(isinstance(z["right"], tuple))
        self.assertEqual(z, expected)

    @jacrev_and_jacfwd
    def test_multiple_inputs_pytree(self, device, jacapi):
        def f(a, b, c):
            a0, a1 = a
            return a0 + a1 * 2 + b * 3 + c * 4

        x = torch.randn([], device=device)
        args = ((x, x), x, x)

        result = jacapi(f, argnums=(0, 1, 2))(*args)
        expected = (
            (torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)),
            torch.tensor(3.0, device=device),
            torch.tensor(4.0, device=device),
        )
        self.assertEqual(result, expected)

        result = jacapi(f, argnums=(0,))(*args)
        expected = (
            (torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)),
        )
        self.assertEqual(result, expected)

        result = jacapi(f)(*args)
        expected = (torch.tensor(1.0, device=device), torch.tensor(2.0, device=device))
        self.assertEqual(result, expected)

    @jacrev_and_jacfwd
    def test_dimensionality(self, device, jacapi):
        def f(x):
            return x

        x = torch.randn([], device=device)
        result = jacapi(f)(x)
        self.assertEqual(result.dim(), 0)
        self.assertEqual(result, torch.ones_like(x))

        x = torch.randn([1], device=device)
        result = jacapi(f)(x)
        self.assertEqual(result.dim(), 2)
        self.assertEqual(result, x.new_ones(1, 1))

    @jacrev_and_jacfwd
    def test_aux_tensor(self, device, jacapi):
        def f(x):
            y = x.clone()
            return y, y.cos()

        x = torch.randn(3, device=device)
        result, aux = jacapi(f, has_aux=True)(x)

        self.assertEqual(result, torch.eye(3, 3, device=device))
        self.assertEqual(aux, x.cos())

    @jacrev_and_jacfwd
    def test_aux_pytree(self, device, jacapi):
        def f(x):
            y = x.clone()
            return y, {"a": y.cos(), "b": [y.tan()]}

        x = torch.randn(3, device=device)

        result, aux = jacapi(f, has_aux=True)(x)
        self.assertEqual(result, torch.eye(3, 3, device=device))
        _, expected_aux = f(x)
        self.assertEqual(aux, expected_aux)

        for aux in [1, 1.0, "abc"]:
            with self.assertRaisesRegex(
                RuntimeError, r"Expected tensors, got unsupported type"
            ):
                _ = jacapi(lambda x: (x, aux), has_aux=True)(x)
            with self.assertRaisesRegex(
                RuntimeError, r"Expected tensors, got unsupported type"
            ):
                _ = jacapi(lambda x: (x, [x, aux]), has_aux=True)(x)

    @jacrev_and_jacfwd
    def test_outputs_can_any_pytree(self, device, jacapi):
        x = torch.randn(2, 3, device=device)

        for output in [None, ()]:
            with self.assertRaisesRegex(
                RuntimeError,
                r"(vjp|jvp).+: Expected f to be a function that has non-empty output",
            ):
                jacapi(lambda _: output)(x)

        for output in [1, True, 12.2, "abc"]:
            with self.assertRaisesRegex(
                RuntimeError,
                r"(vjp|jvp).+: expected f\(\*primals\) to return only tensors",
            ):
                jacapi(lambda _: output)(x)

        # Check list output
        out = jacapi(lambda x: [x, x.sum()])(x)
        assert isinstance(out, list) and len(out) == 2

        # Check dict output
        out = jacapi(lambda x: {"x": x, "xsum": x.sum()})(x)
        assert isinstance(out, dict) and len(out) == 2 and "xsum" in out

        def composite_output(x):
            out = x.sum()
            return [
                (out, {"a": x, "out": [x, out]}),
            ]

        out = jacapi(composite_output)(x)
        assert isinstance(out, list)
        assert isinstance(out[0], tuple) and isinstance(out[0][1], dict)

    @jacrev_and_jacfwd
    def test_multiple_inputs_outputs_pytree(self, device, jacapi):
        def f(a, b, c):
            a0, a1 = a
            return a0 + a1 * 2, {"foo": b * 3 + c * 4}

        x = torch.randn([], device=device)
        zero = torch.zeros([], device=device)
        args = ((x, x), x, x)

        result = jacapi(f)(*args)
        expected = (
            (torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)),
            {"foo": (zero, zero)},
        )
        self.assertEqual(result, expected)

        result = jacapi(f, argnums=(0,))(*args)
        expected = (
            ((torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)),),
            {"foo": ((zero, zero),)},
        )
        self.assertEqual(result, expected)

        result = jacapi(f, argnums=(0, 1))(*args)
        expected = (
            (
                (torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)),
                zero,
            ),
            {"foo": ((zero, zero), torch.tensor(3.0, device=device))},
        )
        self.assertEqual(result, expected)

    @jacrev_and_jacfwd
2069 def test_multiple_inputs_outputs_pytree_multidim(self, device, jacapi): 2070 def f(dct): 2071 a = dct["a"] 2072 b = dct["b"] 2073 return {"c": a.sin(), "d": b.cos()} 2074 2075 x = torch.randn(3, device=device) 2076 args = ({"a": x, "b": x},) 2077 2078 result = jacapi(f)(*args) 2079 expected = { 2080 "c": {"a": x.cos().diagflat(), "b": x.new_zeros(3, 3)}, 2081 "d": {"a": x.new_zeros(3, 3), "b": -x.sin().diagflat()}, 2082 } 2083 self.assertEqual(result, expected) 2084 2085 @jacrev_and_jacfwd 2086 def test_unrelated_input(self, device, jacapi): 2087 def f(x, y): 2088 return x 2089 2090 x = torch.randn(2, 3, device=device) 2091 y = torch.randn(2, 3, device=device) 2092 2093 result = jacapi(f, argnums=(0, 1))(x, y) 2094 expected0 = torch.eye(6, 6, device=device).view(2, 3, 2, 3) 2095 expected1 = y.new_zeros(2, 3, 2, 3) 2096 expected = (expected0, expected1) 2097 self.assertTrue(isinstance(result, tuple)) 2098 self.assertEqual(result, expected) 2099 2100 @jacrev_and_jacfwd 2101 def test_unrelated_output(self, device, jacapi): 2102 y = torch.randn(2, 3, device=device) 2103 2104 def f(x): 2105 return y 2106 2107 x = torch.randn(2, 3, device=device) 2108 2109 result = jacapi(f)(x) 2110 expected = x.new_zeros(2, 3, 2, 3) 2111 self.assertEqual(result, expected) 2112 2113 @jacrev_and_jacfwd 2114 def test_empty_output(self, device, jacapi): 2115 x = torch.randn(3, device=device) 2116 y = torch.randn(3, device=device) 2117 2118 def f(x, y): 2119 return () 2120 2121 with self.assertRaisesRegex(RuntimeError, "xpected"): 2122 jacapi(f)(x, y) 2123 2124 @jacrev_and_jacfwd 2125 def test_argnums_tuple(self, device, jacapi): 2126 x = torch.randn(3, device=device) 2127 y = torch.randn(3, device=device) 2128 z = jacapi(torch.multiply, argnums=(0, 1))(x, y) 2129 expected0 = torch.diagflat(y) 2130 expected1 = torch.diagflat(x) 2131 assert len(z) == 2 2132 assert torch.allclose(z[0], expected0) 2133 assert torch.allclose(z[1], expected1) 2134 2135 @jacrev_and_jacfwd 2136 def test_argnums_effect_on_return(self, device, jacapi): 2137 x = torch.randn(3, device=device) 2138 y = torch.randn(3, device=device) 2139 z = jacapi(torch.multiply, argnums=(0,))(x, y) 2140 expected0 = torch.diagflat(y) 2141 assert isinstance(z, tuple) 2142 assert len(z) == 1 2143 assert torch.allclose(z[0], expected0) 2144 2145 x = torch.randn(3, device=device) 2146 y = torch.randn(3, device=device) 2147 z = jacapi(torch.multiply, argnums=0)(x, y) 2148 expected0 = torch.diagflat(y) 2149 assert isinstance(z, torch.Tensor) 2150 assert torch.allclose(z, expected0) 2151 2152 @jacrev_and_jacfwd 2153 def test_argnums_defaults_to_zero(self, device, jacapi): 2154 def f(x, y): 2155 return x * 2 + y * 3 2156 2157 x = torch.randn(3, device=device) 2158 y = torch.randn(3, device=device) 2159 z = jacapi(f)(x, y) 2160 expected = torch.diagflat(torch.full_like(x, 2)) 2161 self.assertEqual(z, expected) 2162 2163 @jacrev_and_jacfwd 2164 def test_empty_argnums(self, device, jacapi): 2165 x = torch.randn(3, device=device) 2166 with self.assertRaisesRegex(RuntimeError, "must be non-empty"): 2167 jacapi(torch.sin, argnums=())(x) 2168 2169 @jacrev_and_jacfwd 2170 def test_out_of_bounds_argnums(self, device, jacapi): 2171 x = torch.randn(3, device=device) 2172 with self.assertRaisesRegex(RuntimeError, "only 1 positional inputs"): 2173 jacapi(torch.sin, argnums=2)(x) 2174 2175 @jacrev_and_jacfwd 2176 def test_negative_argnums(self, device, jacapi): 2177 x = torch.randn(3, device=device) 2178 with self.assertRaisesRegex(RuntimeError, "only 1 positional inputs"): 
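        # NOTE: Illustrative sketch (not part of the test suite), summarizing the
        # argnums convention checked by the surrounding tests: an int argnums
        # returns a bare Jacobian, while a tuple argnums returns a tuple (even a
        # 1-tuple).
        #
        #   from torch.func import jacrev
        #   x, y = torch.randn(3), torch.randn(3)
        #   jacrev(torch.multiply, argnums=0)(x, y)       # Tensor of shape (3, 3)
        #   jacrev(torch.multiply, argnums=(0,))(x, y)    # 1-tuple of Tensors
        #   jacrev(torch.multiply, argnums=(0, 1))(x, y)  # 2-tuple of Tensors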
2179 jacapi(torch.sin, argnums=-2)(x) 2180 2181 @jacrev_and_jacfwd 2182 def test_repeated_argnums(self, device, jacapi): 2183 x = torch.randn(3, device=device) 2184 with self.assertRaisesRegex(RuntimeError, "must be unique"): 2185 jacapi(torch.sin, argnums=(0, 0))(x) 2186 2187 @jacrev_and_jacfwd 2188 def test_float_argnums(self, device, jacapi): 2189 x = torch.randn(3, device=device) 2190 with self.assertRaisesRegex(RuntimeError, "must be int or Tuple"): 2191 jacapi(torch.sin, argnums=0.0)(x) 2192 with self.assertRaisesRegex(RuntimeError, "must be int"): 2193 jacapi(torch.multiply, argnums=(1, 0.0))(x, x) 2194 2195 def test_hessian_simple(self, device): 2196 def f(x): 2197 return x.sin() 2198 2199 x = torch.randn(3, device=device) 2200 hessian(f)(x) 2201 2202 def _test_against_reference(self, f, inputs, jacapi): 2203 def foo(inputs): 2204 return f(*inputs) 2205 2206 expected = torch.autograd.functional.jacobian(f, inputs) 2207 result = jacapi(foo)(inputs) 2208 self.assertEqual(result, expected) 2209 2210 @jacrev_and_jacfwd 2211 def test_against_reference_simple(self, device, jacapi): 2212 def f(x): 2213 return 3 * x**2 2214 2215 x = torch.randn(2, 3, 5, device=device) 2216 self._test_against_reference(f, (x,), jacapi) 2217 2218 @jacrev_and_jacfwd 2219 def test_against_reference_multi_input(self, device, jacapi): 2220 def f(x, y): 2221 return (x.cos() * x) @ y.sin() 2222 2223 x = torch.randn(2, 3, device=device) 2224 y = torch.randn(3, 5, device=device) 2225 self._test_against_reference(f, (x, y), jacapi) 2226 2227 @jacrev_and_jacfwd 2228 def test_against_reference_multi_input_multi_output(self, device, jacapi): 2229 def f(x, y): 2230 return (x * x) @ y, x @ (x.sum(1) * y), y.sum() 2231 2232 x = torch.randn(5, 3, device=device) 2233 y = torch.randn(3, 5, device=device) 2234 self._test_against_reference(f, (x, y), jacapi) 2235 2236 @jacrev_and_jacfwd 2237 def test_against_reference_unrelated_outputs(self, device, jacapi): 2238 def f(x, y): 2239 return x, y, x, y 2240 2241 x = torch.randn(2, device=device) 2242 y = torch.randn(3, device=device) 2243 self._test_against_reference(f, (x, y), jacapi) 2244 2245 @jacrev_and_jacfwd 2246 def test_against_reference_zero_dim(self, device, jacapi): 2247 # zero-dim output 2248 def f(x, y): 2249 return x.sum(), y.sum(), x * y 2250 2251 x = torch.randn(3, device=device) 2252 y = torch.randn(3, device=device) 2253 self._test_against_reference(f, (x, y), jacapi) 2254 2255 # zero-dim input 2256 def g(x): 2257 return torch.stack([x, x, x]) 2258 2259 x = torch.randn([], device=device) 2260 self._test_against_reference(g, (x,), jacapi) 2261 2262 # Mixed zero-dim input / zero-dim output 2263 def h(x, y): 2264 return y.sum(), x * y 2265 2266 x = torch.randn([], device=device) 2267 y = torch.randn(1, device=device) 2268 self._test_against_reference(h, (x, y), jacapi) 2269 2270 @jacrev_and_jacfwd 2271 def test_against_reference_correctness_different_devices(self, device, jacapi): 2272 def f(x, y): 2273 return x * y, (x * y).to(device=device) 2274 2275 x = torch.randn(3) 2276 y = torch.randn(3) 2277 self._test_against_reference(f, (x, y), jacapi) 2278 2279 @jacrev_and_jacfwd 2280 def test_against_reference_default_arg(self, device, jacapi): 2281 def f(x, y, z=3.0): 2282 return x * y * z 2283 2284 x = torch.randn(3, device=device) 2285 y = torch.randn(3, device=device) 2286 self._test_against_reference(f, (x, y), jacapi) 2287 2288 @jacrev_and_jacfwd 2289 def test_inplace(self, device, jacapi): 2290 def f(x, y): 2291 y.copy_(x) 2292 return y 2293 2294 out = jacapi(f, 
        argnums=0)  # x is differentiable
        x, y = torch.randn(2, device=device), torch.randn(2, device=device)
        self.assertEqual(out(x, y), torch.eye(y.shape[0]))

        # testing tuple of argnums with the example that raised this issue originally
        def g(x, y, z):
            x[:2] = y
            return torch.vstack([(x**2).sum(), (z**3).sum()])

        out = jacapi(g, argnums=(1, 2))
        x, y, z = (
            torch.randn(3, device=device),
            torch.randn(2, device=device),
            torch.randn(2, device=device),
        )

        expected_out = (
            torch.zeros(2, 1, 2, device=device),
            torch.zeros(2, 1, 2, device=device),
        )
        expected_out[0][0][0] = 2 * y  # top left corner
        expected_out[1][1][0] = 3 * (z**2)  # bottom right corner

        out_val = out(x, y, z)
        self.assertEqual(out_val, expected_out)

    @parametrize("_preallocate_and_copy", (True, False))
    def test_chunk_jacrev(self, device, _preallocate_and_copy):
        x = torch.randn(10, 2, device=device)
        y = torch.randn(1, 2, device=device)

        def f(x, y):
            return (x.sin(), x + y), (x + 2, x.sum())

        for chunk_size in (1, 2, 3, 4, 7, 10, 1000):
            expected = jacrev(f, argnums=(0, 1))(x, y)
            actual = jacrev(
                f,
                argnums=(0, 1),
                chunk_size=chunk_size,
                _preallocate_and_copy=_preallocate_and_copy,
            )(x, y)
            self.assertEqual(actual, expected)

        err_msg = "jacrev: `chunk_size` should be greater than 0."
        with self.assertRaisesRegex(ValueError, err_msg):
            jacrev(f, argnums=(0,), chunk_size=0)(x, y)

        with self.assertRaisesRegex(ValueError, err_msg):
            jacrev(f, argnums=(0,), chunk_size=-2)(x, y)

    @parametrize("_preallocate_and_copy", (True, False))
    def test_chunk_jacrev_composition(self, device, _preallocate_and_copy):
        x = torch.randn(10, 2, device=device)
        chunk_size = 3

        def f(x):
            return (x.sin(), x), (x + 2, x.sum())

        expected = vmap(jacrev(jacrev(f)))(x)
        actual = vmap(
            jacrev(
                jacrev(
                    f,
                    chunk_size=chunk_size,
                    _preallocate_and_copy=_preallocate_and_copy,
                ),
                chunk_size=chunk_size,
            )
        )(x)
        self.assertEqual(actual, expected)

    # https://github.com/pytorch/pytorch/issues/127036
    @xfailIfTorchDynamo
    @parametrize("_preallocate_and_copy", (True, False))
    def test_chunk_jacrev_chunksize_one(self, device, _preallocate_and_copy):
        # With chunk_size=1, we don't use `vmap`, so we are not limited
        # by its constraints.
        x = torch.randn(3, 3, device=device)

        # Function with a dynamic op in backward.
        # This should cause jacrev/vmap(vjp) to fail.
        class IdentityWithDynamicBackwardOp(torch.autograd.Function):
            @staticmethod
            def forward(input):
                return input

            @staticmethod
            def setup_context(ctx, inputs, output):
                pass

            @staticmethod
            def backward(ctx, grad_output):
                # dynamic op in backward pass.
                grad_output.nonzero()
                return grad_output

        def f(x):
            return IdentityWithDynamicBackwardOp.apply(x)

        # With `chunk_size=1`, we don't use vmap. So the following should work.
        jacfn = jacrev(f, chunk_size=1, _preallocate_and_copy=_preallocate_and_copy)
        actual = jacfn(x)
        expected = torch.autograd.functional.jacobian(f, x, vectorize=False)
        self.assertEqual(actual, expected)

        # Should fail with `chunk_size=2`.
        msg = (
            r"vmap: We do not support batching operators that can output dynamic shape."
2403 ) 2404 with self.assertRaisesRegex(RuntimeError, msg): 2405 jacrev(f, chunk_size=2, _preallocate_and_copy=_preallocate_and_copy)(x) 2406 2407 def test_complex_error(self, device): 2408 # Verify complex input raises error 2409 # C -> C 2410 def fn(x): 2411 return x.conj() 2412 2413 x = torch.randn(1, device=device, dtype=torch.cfloat) 2414 2415 with self.assertRaisesRegex(RuntimeError, "jacrev: Expected all inputs"): 2416 jacrev(fn)(x) 2417 2418 with self.assertRaisesRegex(RuntimeError, "jacfwd: Expected all inputs"): 2419 jacfwd(fn)(x) 2420 2421 # Verify complex output raises error 2422 # R -> C 2423 def fn(x): 2424 return torch.conj(x * 0.5j) 2425 2426 x = torch.randn(1, device=device, dtype=torch.float) 2427 2428 with self.assertRaisesRegex(RuntimeError, "jacrev: Expected all outputs"): 2429 jacrev(fn)(x) 2430 2431 with self.assertRaisesRegex(RuntimeError, "jacfwd: Expected all outputs"): 2432 jacfwd(fn)(x) 2433 2434 @jacrev_and_jacfwd 2435 def test_jac_with_non_tensor_args(self, device, jacapi): 2436 def f(t, int_x): 2437 return t + int_x 2438 2439 t = torch.randn(3, 3, device=device) 2440 2441 actual = jacapi(f)(t, 3) 2442 expected = torch.autograd.functional.jacobian(partial(f, int_x=3), t) 2443 self.assertEqual(actual, expected) 2444 2445 2446@markDynamoStrictTest 2447class TestHessian(TestCase): 2448 def _test_against_reference(self, f, inputs): 2449 def foo(inputs): 2450 return f(*inputs) 2451 2452 expected = torch.autograd.functional.hessian(f, inputs) 2453 result = hessian(foo)(inputs) 2454 self.assertEqual(result, expected) 2455 2456 def test_hessian_vectorize_correctness_simple(self, device): 2457 def f(x): 2458 return (3 * x**2).sum() 2459 2460 x = torch.randn(2, 3, 5, device=device) 2461 self._test_against_reference(f, (x,)) 2462 2463 def test_hessian_vectorize_correctness_multi_input(self, device): 2464 def f(x, y, z): 2465 return ((x.relu() * x) @ y.sin() @ z).sum() 2466 2467 x = torch.randn(2, 3, device=device) 2468 y = torch.randn(3, 5, device=device) 2469 z = torch.randn(5, 5, device=device) 2470 self._test_against_reference(f, (x, y, z)) 2471 2472 def test_hessian_vectorize_correctness_unrelated_outputs(self, device): 2473 # output unrelated to one input 2474 def f(x, y): 2475 return (x**2).sum() 2476 2477 x = torch.randn(2, device=device) 2478 y = torch.randn(3, device=device) 2479 self._test_against_reference(f, (x, y)) 2480 2481 # output unrelated to all inputs 2482 def f(x, y): 2483 return torch.ones([]) 2484 2485 x = torch.randn(2, device=device) 2486 y = torch.randn(3, device=device) 2487 self._test_against_reference(f, (x, y)) 2488 2489 def test_jacfwd_different_levels(self, device): 2490 # Test case from: 2491 # https://github.com/pytorch/functorch/issues/597 2492 b = 8 2493 n = 100 2494 d = 2 2495 x1 = torch.randn(b, n, d, device=device) 2496 x2 = x1 2497 A = 0.1 * torch.randn(b, d, d, device=device) 2498 2499 def loss(A, x1, x2): 2500 x2_hat = (A @ (x1.T)).T 2501 res = x2 - x2_hat 2502 res_sqr = res**2 2503 return res_sqr.sum() 2504 2505 hess1 = vmap(jacrev(jacrev(loss)))(A, x1, x2) 2506 hess2 = vmap(hessian(loss))(A, x1, x2) 2507 self.assertEqual(hess2, hess1) 2508 2509 2510@markDynamoStrictTest 2511class TestJvp(TestCase): 2512 def test_inplace_on_captures(self, device): 2513 x = torch.tensor([1.0, 2.0, 3.0], device=device) 2514 captured = torch.randn(3, device=device) 2515 2516 def foo(x): 2517 captured.copy_(x) 2518 return (x * captured).sum() 2519 2520 with self.assertRaisesRegex(RuntimeError, "mutate a captured Tensor"): 2521 grad(foo)(x) 2522 2523 
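    # NOTE: Illustrative sketch (not part of the test suite).
    # torch.func.hessian is documented as the composition jacfwd(jacrev(f))
    # (forward-over-reverse); a minimal equivalence check, assuming torch.func:
    #
    #   from torch.func import hessian, jacfwd, jacrev
    #
    #   def f(x):
    #       return (x.sin() ** 2).sum()
    #
    #   x = torch.randn(3)
    #   torch.allclose(hessian(f)(x), jacfwd(jacrev(f))(x))  # True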
def test_simple(self, device): 2524 x = torch.randn(2, 3, device=device) 2525 t = torch.randn(2, 3, device=device) 2526 result = jvp(torch.sin, (x,), (t,)) 2527 expected = (x.sin(), x.cos() * t) 2528 self.assertTrue(isinstance(result, tuple)) 2529 self.assertEqual(result, expected) 2530 2531 def test_multiple_inputs(self, device): 2532 x = torch.randn(2, 3, device=device) 2533 y = torch.randn(2, 3, device=device) 2534 tx = torch.randn(2, 3, device=device) 2535 ty = torch.randn(2, 3, device=device) 2536 2537 def f(x, y): 2538 return x * y 2539 2540 result = jvp(f, (x, y), (tx, ty)) 2541 expected = (x * y, y * tx + x * ty) 2542 self.assertTrue(isinstance(result, tuple)) 2543 self.assertEqual(result, expected) 2544 2545 def test_pytree_inputs(self, device): 2546 def f(x, y, z): 2547 a, b = x 2548 return a + 2 * b + 3 * y + 4 * z 2549 2550 one = torch.tensor(1.0, device=device) 2551 primal_outs, tangent_outs = jvp( 2552 f, ((one, one), one, one), ((one, one), one, one) 2553 ) 2554 self.assertEqual(primal_outs, one * 10) 2555 self.assertEqual(tangent_outs, one * 10) 2556 2557 def test_pytree_inputs_error_cases(self, device): 2558 def f(x): 2559 return x 2560 2561 one = torch.tensor(1.0, device=device) 2562 2563 with self.assertRaisesRegex(RuntimeError, "Expected primals to be a tuple"): 2564 jvp(f, one, one) 2565 with self.assertRaisesRegex(RuntimeError, "same python structure"): 2566 jvp(f, ((one, one), one), (one, one)) 2567 with self.assertRaisesRegex(RuntimeError, "only contain Tensors"): 2568 jvp(f, ((one, one), 1), ((one, one), one)) 2569 with self.assertRaisesRegex(RuntimeError, "only contain Tensors"): 2570 jvp(f, ((one, one), 1), ((1, one), one)) 2571 with self.assertRaisesRegex(RuntimeError, "at least one Tensor"): 2572 jvp(f, ((),), ((),)) 2573 2574 def test_unrelated_input(self, device): 2575 def f(x, y): 2576 return x 2577 2578 x = torch.randn(2, 3, device=device) 2579 y = torch.randn(2, 3, device=device) 2580 tx = torch.randn(2, 3, device=device) 2581 ty = torch.randn(2, 3, device=device) 2582 2583 result = jvp(f, (x, y), (tx, ty)) 2584 expected = (x, tx) 2585 self.assertTrue(isinstance(result, tuple)) 2586 self.assertEqual(result, expected) 2587 2588 def test_unrelated_output(self, device): 2589 y = torch.randn(2, 3, device=device) 2590 2591 def f(x): 2592 return y 2593 2594 x = torch.randn(2, 3, device=device) 2595 tx = torch.randn(2, 3, device=device) 2596 2597 result = jvp(f, (x,), (tx,)) 2598 expected = (y, torch.zeros_like(y)) 2599 self.assertTrue(isinstance(result, tuple)) 2600 self.assertEqual(result, expected) 2601 2602 def test_strict_mode(self, device): 2603 y = torch.randn(2, 3, device=device) 2604 2605 def f(x): 2606 return x, y 2607 2608 x = torch.randn(2, 3, device=device) 2609 tx = torch.randn(2, 3, device=device) 2610 2611 with self.assertRaisesRegex(RuntimeError, "strict"): 2612 jvp(f, (x,), (tx,), strict=True) 2613 2614 def test_multiple_outputs(self, device): 2615 x = torch.randn(2, 3, device=device) 2616 t = torch.randn(2, 3, device=device) 2617 2618 def f(x): 2619 return torch.sin(x), torch.cos(x) 2620 2621 result = jvp(f, (x,), (t,)) 2622 expected = (f(x), (x.cos() * t, -x.sin() * t)) 2623 self.assertTrue(isinstance(result, tuple)) 2624 self.assertEqual(result, expected) 2625 2626 def test_multiple_inputs_outputs(self, device): 2627 x = torch.randn(2, 3, device=device) 2628 y = torch.randn(2, 3, device=device) 2629 tx = torch.randn(2, 3, device=device) 2630 ty = torch.randn(2, 3, device=device) 2631 2632 def f(x, y): 2633 return 2 * x + 3 * y, 4 * x + 5 * y 
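        # NOTE: Illustrative sketch (not part of the test suite). The jvp calling
        # convention used throughout this class: primals and tangents are passed
        # as tuples, and jvp returns (f(*primals), directional derivative at the
        # given tangents).
        #
        #   from torch.func import jvp
        #   x, y = torch.randn(3), torch.randn(3)
        #   tx, ty = torch.randn(3), torch.randn(3)
        #   out, tangent_out = jvp(lambda x, y: x * y, (x, y), (tx, ty))
        #   # out == x * y; tangent_out == y * tx + x * ty  (product rule)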
2634 2635 result = jvp(f, (x, y), (tx, ty)) 2636 expected = (f(x, y), f(tx, ty)) 2637 self.assertTrue(isinstance(result, tuple)) 2638 self.assertEqual(result, expected) 2639 2640 def test_jvp_new_tensor(self): 2641 def f(x): 2642 y = x.new_tensor(0.5) 2643 return x + y 2644 2645 x = torch.rand(10, 10) 2646 tangents = torch.zeros_like(x) 2647 actual = jvp(f, (x,), (tangents,)) 2648 expected = (f(x), torch.zeros_like(x)) 2649 self.assertEqual(actual, expected) 2650 2651 def test_primals_tangents_length_mismatch(self, device): 2652 x = torch.randn(2, 3, device=device) 2653 t = torch.randn(2, 3, device=device) 2654 2655 msg = "same python structure" 2656 with self.assertRaisesRegex(RuntimeError, msg): 2657 jvp(torch.sin, (x,), (t, t)) 2658 with self.assertRaisesRegex(RuntimeError, msg): 2659 jvp(torch.sin, (x, x), (t, t, t)) 2660 2661 def test_nonempty_primals_and_tangents(self, device): 2662 with self.assertRaisesRegex(RuntimeError, "at least one Tensor"): 2663 jvp(torch.sin, (), ()) 2664 2665 def test_inputs_are_tuples_of_tensors(self, device): 2666 x = torch.randn(2, 3, device=device) 2667 t = torch.randn(2, 3, device=device) 2668 2669 with self.assertRaisesRegex(RuntimeError, "be a tuple"): 2670 jvp(torch.sin, x, (t,)) 2671 with self.assertRaisesRegex(RuntimeError, "same python structure"): 2672 jvp(torch.sin, (x,), t) 2673 with self.assertRaisesRegex(RuntimeError, "same python structure"): 2674 jvp(torch.sin, (x,), [t]) 2675 with self.assertRaisesRegex(RuntimeError, "only contain Tensors"): 2676 jvp(torch.sin, (1.0,), (t,)) 2677 with self.assertRaisesRegex(RuntimeError, "only contain Tensors"): 2678 jvp(torch.sin, (x,), (1.0,)) 2679 2680 def test_outputs_can_any_pytree(self, device): 2681 x = torch.randn(2, 3, device=device) 2682 t = torch.randn(2, 3, device=device) 2683 2684 for output in [None, ()]: 2685 with self.assertRaisesRegex( 2686 RuntimeError, 2687 r"jvp\(f, primals, tangents\): Expected f to be a function that has non-empty output", 2688 ): 2689 jvp(lambda _: output, (x,), (t,)) 2690 2691 for output in [1, True, 12.2, "abc"]: 2692 with self.assertRaisesRegex( 2693 RuntimeError, 2694 r"jvp\(f, primals, tangents\): expected f\(\*primals\) to return only tensors", 2695 ): 2696 jvp(lambda _: output, (x,), (t,)) 2697 2698 # Check list output 2699 out = jvp(lambda x: [x, x.sum()], (x,), (t,)) 2700 for i in range(2): 2701 assert isinstance(out[i], list) and len(out[i]) == 2 2702 2703 # Check dict output 2704 out = jvp(lambda x: {"x": x, "xsum": x.sum()}, (x,), (t,)) 2705 for i in range(2): 2706 assert isinstance(out[i], dict) and len(out[i]) == 2 and "xsum" in out[i] 2707 2708 def composite_output(x): 2709 out = x.sum() 2710 return [ 2711 (out, {"a": x, "out": [x, out]}), 2712 ] 2713 2714 out = jvp(composite_output, (x,), (t,)) 2715 for i in range(2): 2716 assert isinstance(out[i], list) 2717 assert isinstance(out[i][0], tuple) and isinstance(out[i][0][1], dict) 2718 2719 def test_aux_tensor(self, device): 2720 x = torch.randn(3, device=device) 2721 t = torch.randn(3, device=device) 2722 2723 with self.assertRaisesRegex( 2724 RuntimeError, 2725 r"jvp\(f, primals, tangents\): output of function f should be a tuple", 2726 ): 2727 jvp(lambda t: [t, t], (x,), (t,), has_aux=True) 2728 2729 with self.assertRaisesRegex( 2730 RuntimeError, 2731 r"jvp\(f, primals, tangents\): output of function f should be a tuple", 2732 ): 2733 jvp(lambda t: (t, t + 2, t + 3), (x,), (t,), has_aux=True) 2734 2735 def f(z): 2736 y = z.sin() 2737 return y, z.cos() 2738 2739 out, jvp_out, aux = jvp(f, (x,), (t,), 
has_aux=True) 2740 self.assertEqual(aux, x.cos()) 2741 self.assertEqual(out, x.sin()) 2742 self.assertEqual(jvp_out, t * x.cos()) 2743 2744 def test_aux_pytree(self, device): 2745 def f(x): 2746 y = x.sin() 2747 return y, {"a": x.cos(), "b": [x.tan()]} 2748 2749 x = torch.randn(3, device=device) 2750 t = torch.randn(3, device=device) 2751 2752 out, jvp_out, aux = jvp(f, (x,), (t,), has_aux=True) 2753 expected_out, expected_aux = f(x) 2754 self.assertEqual(out, expected_out) 2755 self.assertEqual(aux, expected_aux) 2756 self.assertEqual(jvp_out, t * x.cos()) 2757 2758 for aux in [1, 1.0, "abc"]: 2759 with self.assertRaisesRegex( 2760 RuntimeError, r"Expected tensors, got unsupported type" 2761 ): 2762 _ = jvp(lambda x: (x, aux), (x,), (t,), has_aux=True) 2763 with self.assertRaisesRegex( 2764 RuntimeError, r"Expected tensors, got unsupported type" 2765 ): 2766 _ = jvp(lambda x: (x, [x, aux]), (x,), (t,), has_aux=True) 2767 2768 def test_autograd_function_disables_fwd_grad(self, device): 2769 # Sanity check. We don't really assume this anywhere so 2770 # it's fine if this breaks one day. 2771 class MySquare(torch.autograd.Function): 2772 @staticmethod 2773 def forward(ctx, x): 2774 enabled = fwAD._is_fwd_grad_enabled() 2775 self.assertFalse(enabled) 2776 return x * x 2777 2778 @staticmethod 2779 def backward(ctx, gx): 2780 return gx 2781 2782 x = torch.randn(3, requires_grad=True) 2783 MySquare.apply(x) 2784 2785 def test_disable_fwd_grad_outside(self, device): 2786 x = torch.randn([], device=device) 2787 t = torch.ones_like(x) 2788 with fwAD._set_fwd_grad_enabled(False): 2789 _, y = jvp(torch.sin, (x,), (t,)) 2790 self.assertEqual(y, x.cos()) 2791 2792 def test_disable_fwd_grad_inside(self, device): 2793 def f(x): 2794 with fwAD._set_fwd_grad_enabled(False): 2795 shift = x**2 2796 return x**2 - shift 2797 2798 x = torch.randn([], device=device) 2799 t = torch.ones_like(x) 2800 _, y = jvp(f, (x,), (t,)) 2801 self.assertEqual(y, 2 * x) 2802 _, y = jvp(lambda x: jvp(f, (x,), (t,))[1], (x,), (t,)) 2803 self.assertEqual(y, 2) 2804 2805 def test_disable_fwd_grad_mixed(self, device): 2806 def f(x): 2807 with fwAD._set_fwd_grad_enabled(False): 2808 shift = x**2 2809 return x**2 - shift 2810 2811 x = torch.randn([], device=device) 2812 t = torch.ones_like(x) 2813 with fwAD._set_fwd_grad_enabled(True): 2814 _, y = jvp(f, (x,), (t,)) 2815 2816 self.assertEqual(y, 2 * x) 2817 2818 def test_jvp_inside_autograd_function(self, device): 2819 class MySin(torch.autograd.Function): 2820 @staticmethod 2821 def forward(ctx, x): 2822 t = torch.ones_like(x) 2823 _, neg_sin_x = jvp(torch.cos, (x,), (t,)) 2824 ctx.save_for_backward(x) 2825 return -neg_sin_x 2826 2827 @staticmethod 2828 def backward(ctx, gx): 2829 (x,) = ctx.saved_tensors 2830 t = torch.ones_like(x) 2831 _, cos_x = jvp(torch.sin, (x,), (t,)) 2832 return gx * cos_x 2833 2834 x = torch.randn([], device=device, requires_grad=True) 2835 y = MySin.apply(x) 2836 self.assertEqual(y, x.sin()) 2837 2838 (gx,) = torch.autograd.grad(y, x) 2839 self.assertEqual(gx, x.cos()) 2840 2841 def test_zerotensor_vmapjvp_interaction(self, device): 2842 dummy = torch.ones(4, 1) 2843 x = torch.randn(4, 2) 2844 x_tangent = torch.randn(2) 2845 2846 def push_jvp(dummy, x): 2847 result = jvp(torch.cov, (x,), (x_tangent,)) 2848 return result 2849 2850 # Should not error 2851 vmap(vmap(push_jvp, (0, None)))(dummy, x) 2852 2853 2854@markDynamoStrictTest 2855class TestLinearize(TestCase): 2856 @dtypes(torch.float) 2857 def test_linearize_basic(self, device, dtype): 2858 x_p = 
make_tensor((3, 1), device=device, dtype=dtype) 2859 x_t = make_tensor((3, 1), device=device, dtype=dtype) 2860 2861 def fn(x): 2862 return x.cos() 2863 2864 actual_output, jvp_fn = linearize(fn, x_p) 2865 actual_jvp = jvp_fn(x_t) 2866 expected_output, expected_jvp = jvp(fn, (x_p,), (x_t,)) 2867 self.assertEqual(actual_output, expected_output) 2868 self.assertEqual(actual_jvp, expected_jvp) 2869 2870 @dtypes(torch.float) 2871 def test_linearize_return(self, device, dtype): 2872 x_p = make_tensor((3, 1), device=device, dtype=dtype) 2873 x_t = make_tensor((3, 1), device=device, dtype=dtype) 2874 2875 def fn(x): 2876 return (x.cos(), x.sum()) 2877 2878 actual_output, jvp_fn = linearize(fn, x_p) 2879 actual_jvp = jvp_fn(x_t) 2880 expected_output, expected_jvp = jvp(fn, (x_p,), (x_t,)) 2881 self.assertEqual(actual_output, expected_output) 2882 self.assertEqual(actual_jvp, expected_jvp) 2883 2884 @dtypes(torch.float) 2885 def test_linearize_composition_vmap(self, device, dtype): 2886 x_p = make_tensor((3, 1), device=device, dtype=dtype) 2887 x_t = make_tensor((3, 3, 1), device=device, dtype=dtype) 2888 2889 def fn(x): 2890 return (x.cos(), x.sum()) 2891 2892 _, jvp_fn = linearize(fn, x_p) 2893 actual_batched_jvp = vmap(jvp_fn)(x_t) 2894 2895 def jvp_fn(x_t): 2896 return jvp(fn, (x_p,), (x_t,))[1] 2897 2898 expected_batched_jvp = vmap(jvp_fn)(x_t) 2899 2900 self.assertEqual(actual_batched_jvp, expected_batched_jvp) 2901 2902 @dtypes(torch.float) 2903 def test_linearize_composition_grad(self, device, dtype): 2904 x_p = make_tensor((3,), device=device, dtype=dtype) 2905 x_t = make_tensor((3,), device=device, dtype=dtype) 2906 2907 def fn(x): 2908 z = torch.ones(3, device=device, dtype=dtype) 2909 return grad(lambda x: z @ x)(x) 2910 2911 _, jvp_fn = linearize(fn, x_p) 2912 actual_batched_jvp = jvp_fn(x_t) 2913 2914 def jvp_fn(x_t): 2915 return jvp(fn, (x_p,), (x_t,))[1] 2916 2917 expected_batched_jvp = jvp_fn(x_t) 2918 2919 self.assertEqual(actual_batched_jvp, expected_batched_jvp) 2920 2921 @dtypes(torch.float) 2922 def test_linearize_nested_input_nested_output(self, device, dtype): 2923 x_p = make_tensor((3, 1), device=device, dtype=dtype) 2924 x_t = make_tensor((3, 1), device=device, dtype=dtype) 2925 y_p = make_tensor((3, 1), device=device, dtype=dtype) 2926 y_t = make_tensor((3, 1), device=device, dtype=dtype) 2927 z_p = make_tensor((3, 1), device=device, dtype=dtype) 2928 z_t = make_tensor((3, 1), device=device, dtype=dtype) 2929 2930 def fn(arg): 2931 x = arg["x"] 2932 y = arg["yz"][0] 2933 z = arg["yz"][1] 2934 2935 return {"a": x.sum(), "b": {"c": y + z, "d": (x * z, y.exp())}} 2936 2937 inp_p = {"x": x_p, "yz": (y_p, z_p)} 2938 inp_t = {"x": x_t, "yz": (y_t, z_t)} 2939 actual_output, jvp_fn = linearize(fn, inp_p) 2940 actual_jvp = jvp_fn(inp_t) 2941 2942 expected_output, expected_jvp = jvp(fn, (inp_p,), (inp_t,)) 2943 2944 self.assertEqual(actual_output, expected_output) 2945 self.assertEqual(actual_jvp, expected_jvp) 2946 2947 @onlyCUDA 2948 def test_linearize_errors(self): 2949 dtype = torch.float 2950 device = torch.device("cpu") 2951 x_p = make_tensor((3, 1), device=device, dtype=dtype) 2952 x_t = make_tensor((3, 1), device=device, dtype=dtype) 2953 2954 def fn(x): 2955 return x.sin() 2956 2957 _, jvp_fn = linearize(fn, x_p) 2958 2959 with self.assertRaisesRegex( 2960 RuntimeError, "to have the same argspec as the primals" 2961 ): 2962 jvp_fn((x_t, x_t)) 2963 2964 with self.assertRaisesRegex( 2965 RuntimeError, "in flattened pytree doesn't match the shape" 2966 ): 2967 
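        # NOTE: Illustrative sketch (not part of the test suite). The linearize
        # tests above rely on the contract that linearize(f, x) returns
        # (f(x), jvp_fn), where jvp_fn can be reused for many tangents without
        # re-running the forward pass:
        #
        #   from torch.func import linearize, jvp
        #   x = torch.randn(3)
        #   out, jvp_fn = linearize(torch.sin, x)
        #   for t in (torch.randn(3), torch.randn(3)):
        #       assert torch.allclose(jvp_fn(t), jvp(torch.sin, (x,), (t,))[1])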
jvp_fn(x_t.unsqueeze(0)) 2968 2969 with self.assertRaisesRegex( 2970 RuntimeError, "in flattened pytree doesn't match the dtype" 2971 ): 2972 jvp_fn(x_t.to(torch.double)) 2973 2974 with self.assertRaisesRegex( 2975 RuntimeError, "in flattened pytree doesn't match the device" 2976 ): 2977 jvp_fn(x_t.to(torch.device("cuda"))) 2978 2979 2980# The tests here follow the cases in [Forward Grad View/inplace] 2981# https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/autograd_meta.cpp#L18-L43 2982@markDynamoStrictTest 2983class TestVmapJvpInplaceView(TestCase): 2984 # Case 1 in [Forward Grad View/inplace] 2985 def test_all_dual_no_view(self, device): 2986 B = 2 2987 2988 def push_jvp(f): 2989 def inner(x, xt, y, yt): 2990 return jvp(f, (x, y), (xt, yt)) 2991 2992 return inner 2993 2994 def f(x, y): 2995 x.copy_(y) 2996 return x 2997 2998 x = torch.randn(3, B, device=device) 2999 xt = torch.randn(3, B, device=device) 3000 y = torch.randn(3, B, device=device) 3001 yt = torch.randn(3, B, device=device) 3002 out, out_tangent = vmap(push_jvp(f), in_dims=1)(x, xt, y, yt) 3003 self.assertEqual(out, x.movedim(1, 0)) 3004 self.assertEqual(out_tangent, yt.movedim(1, 0)) 3005 3006 x = torch.randn(3, B, device=device) 3007 xt = torch.randn(3, B, device=device) 3008 y = torch.randn(3, 3, device=device)[:, 1] 3009 yt = torch.randn(6, device=device)[::2] 3010 out, out_tangent = vmap(push_jvp(f), in_dims=(1, 1, None, None))(x, xt, y, yt) 3011 self.assertEqual(out, x.movedim(1, 0)) 3012 self.assertEqual(out_tangent, yt.expand(B, 3)) 3013 3014 # Case 2 in [Forward Grad View/inplace] 3015 def test_all_dual_base_view_inplace(self, device): 3016 B = 2 3017 3018 def push_jvp(f): 3019 def inner(x, xt, y, yt): 3020 return jvp(f, (x, y), (xt, yt)) 3021 3022 return inner 3023 3024 # with view, propagate from view to base 3025 def f(x, y): 3026 view = x[:, ::2] 3027 view.copy_(y) 3028 return view, x 3029 3030 orig_x = torch.randn(2, 6, B, device=device) 3031 orig_xt = torch.randn(2, 6, B, device=device) 3032 x = orig_x.clone() 3033 xt = orig_xt.clone() 3034 y = torch.randn(2, B, 3, device=device) 3035 yt = torch.randn(2, B, 3, device=device) 3036 out, out_tangent = vmap(push_jvp(f), in_dims=(2, 2, 1, 1))(x, xt, y, yt) 3037 3038 expected_out = vmap(f, in_dims=(2, 1))(orig_x.clone(), y) 3039 self.assertEqual(out[0], expected_out[0]) 3040 self.assertEqual(out[1], expected_out[1]) 3041 3042 self.assertEqual(out_tangent[0], yt.movedim(1, 0)) 3043 3044 expected_x_tangent = orig_xt.movedim(-1, 0).clone() 3045 expected_x_tangent[:, :, ::2].copy_(yt.movedim(1, 0)) 3046 self.assertEqual(out_tangent[1], expected_x_tangent) 3047 3048 expected = orig_x.movedim(2, 0).clone() 3049 expected[:, :, ::2] = y.movedim(1, 0) 3050 self.assertEqual(x.movedim(2, 0), expected) 3051 3052 # Case 3 in [Forward Grad View/inplace] 3053 def test_all_dual_base_inplace(self, device): 3054 B = 2 3055 3056 def push_jvp(f): 3057 def inner(x, xt, y, yt): 3058 return jvp(f, (x, y), (xt, yt)) 3059 3060 return inner 3061 3062 # Case 3: with view, propagate from base to view 3063 def f(x, y): 3064 view = x[0, ::2] 3065 x.copy_(y) 3066 return x, view 3067 3068 x = torch.randn(2, B, 6, device=device) 3069 xt = torch.randn(2, 6, B, device=device) 3070 y = torch.randn(2, B, 6, device=device) 3071 yt = torch.randn(2, B, 6, device=device) 3072 out, out_tangent = vmap(push_jvp(f), in_dims=(1, 2, 1, 1))(x.clone(), xt, y, yt) 3073 3074 expected_out = vmap(f, in_dims=(1, 1))(x.clone(), y) 3075 self.assertEqual(out[0], expected_out[0]) 3076 
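        # NOTE: Illustrative sketch (not part of the test suite). The push_jvp
        # helpers in this class all follow the same pattern: wrap jvp in a
        # function of (primal, tangent) pairs and batch it with vmap.
        #
        #   from torch.func import jvp, vmap
        #   B = 2
        #   x, t = torch.randn(B, 3), torch.randn(B, 3)
        #   out, tangent = vmap(lambda x, t: jvp(torch.sin, (x,), (t,)))(x, t)
        #   # tangent == torch.cos(x) * t, computed independently per batch element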
self.assertEqual(out[1], expected_out[1]) 3077 3078 self.assertEqual(out_tangent[0], yt.movedim(1, 0)) 3079 self.assertEqual(out_tangent[1], yt.movedim(1, 0)[:, 0, ::2]) 3080 3081 # Case 4 in [Forward Grad View/inplace] 3082 def test_right_dual_view_prop(self, device): 3083 B = 2 3084 3085 # Changes on the view must propagate to its base. Also: 3086 # - x is a regular Tensor 3087 # - y is a dual tensor 3088 def f(x, y): 3089 x = x.clone() 3090 view = x[0] 3091 view.copy_(y) 3092 return view, x 3093 3094 def push_jvp(x, y, yt): 3095 return jvp(partial(f, x), (y,), (yt,)) 3096 3097 x = torch.randn(2, B, 6, device=device) 3098 y = torch.randn(6, B, device=device) 3099 yt = torch.randn(6, B, device=device) 3100 outs, tangents = vmap(push_jvp, in_dims=(1, 1, 1))(x, y, yt) 3101 3102 expected_out = vmap(f, in_dims=(1, 1))(x.clone(), y) 3103 self.assertEqual(outs[0], expected_out[0]) 3104 self.assertEqual(outs[1], expected_out[1]) 3105 3106 self.assertEqual(tangents[0], yt.movedim(1, 0)) 3107 3108 expected_tangent_1 = torch.zeros_like(x).movedim(1, 0) 3109 expected_tangent_1[:, 0].copy_(yt.movedim(1, 0)) 3110 self.assertEqual(tangents[1], expected_tangent_1) 3111 3112 # Case 5 in [Forward Grad View/inplace] 3113 def test_right_dual_base_prop(self, device): 3114 B = 2 3115 3116 # Changes on the base must propagate on all its views. Also: 3117 # - x is a regular Tensor 3118 # - y is a dual tensor 3119 def f(x, y): 3120 x = x.clone() 3121 view = x[0] 3122 x.copy_(y) 3123 return view, x 3124 3125 def push_jvp(x, y, yt): 3126 return jvp(partial(f, x), (y,), (yt,)) 3127 3128 x = torch.randn(2, B, 6) 3129 y = torch.randn(2, 6, B) 3130 yt = torch.randn(2, 6, B) 3131 outs, tangents = vmap(push_jvp, in_dims=(1, 2, 2))(x, y, yt) 3132 3133 expected_out = vmap(f, in_dims=(1, 2))(x, y) 3134 self.assertEqual(outs[0], expected_out[0]) 3135 self.assertEqual(outs[1], expected_out[1]) 3136 3137 self.assertEqual(tangents[0], yt.movedim(2, 0)[:, 0]) 3138 self.assertEqual(tangents[1], yt.movedim(2, 0)) 3139 3140 3141# Use for testing miscellaneous helper functions 3142@markDynamoStrictTest 3143class TestHelpers(TestCase): 3144 def test_CtxWithSavedTensors_error_if_name_collision(self, device): 3145 x = torch.randn([], device=device, requires_grad=True) 3146 y = torch.randn([], device=device, requires_grad=True) 3147 3148 class A(torch.autograd.Function): 3149 @staticmethod 3150 def forward(ctx, x): 3151 ctx._pt_inner_ctx = 1 3152 ctx.save_for_backward(x) 3153 return x 3154 3155 @staticmethod 3156 def backward(ctx, gy): 3157 wrapped = torch._functorch.autograd_function.CtxWithSavedTensors( 3158 ctx, (y,) 3159 ) 3160 return gy 3161 3162 class B(torch.autograd.Function): 3163 @staticmethod 3164 def forward(ctx, x): 3165 ctx._pt_new_saved_tensors = 1 3166 ctx.save_for_backward(x) 3167 return x 3168 3169 @staticmethod 3170 def backward(ctx, gy): 3171 wrapped = torch._functorch.autograd_function.CtxWithSavedTensors( 3172 ctx, (y,) 3173 ) 3174 return gy 3175 3176 out = A.apply(x) 3177 with self.assertRaisesRegex(RuntimeError, "name collision"): 3178 out.backward() 3179 out = B.apply(x) 3180 with self.assertRaisesRegex(RuntimeError, "name collision"): 3181 out.backward() 3182 3183 def test_CtxWithSavedTensors_nesting(self, device): 3184 CtxWithSavedTensors = torch._functorch.autograd_function.CtxWithSavedTensors 3185 x = torch.randn([], device=device, requires_grad=True) 3186 y = torch.randn([], device=device) 3187 z = torch.randn([], device=device) 3188 3189 class A(torch.autograd.Function): 3190 @staticmethod 3191 def 
forward(ctx, x): 3192 ctx.save_for_backward(x) 3193 return x 3194 3195 @staticmethod 3196 def backward(ctx, gy): 3197 ctx_y = CtxWithSavedTensors(ctx, (y,)) 3198 # Can't use self.assertEqual because that relies on TLS 3199 # that is not available in multithread autograd 3200 assert len(ctx_y.saved_tensors) == 1 3201 assert torch.allclose(ctx_y.saved_tensors[0], y) 3202 3203 wrapped = CtxWithSavedTensors(ctx_y, (z,)) 3204 3205 assert len(wrapped.saved_tensors) == 1 3206 assert torch.allclose(wrapped.saved_tensors[0], z) 3207 3208 assert len(ctx_y.saved_tensors) == 1 3209 assert torch.allclose(ctx_y.saved_tensors[0], y) 3210 3211 return gy * wrapped.saved_tensors[0] 3212 3213 out = A.apply(x) 3214 out.backward() 3215 self.assertEqual(x.grad, z) 3216 3217 def test_CtxWithSavedTensors_overrides_saved_tensors(self, device): 3218 x = torch.randn([], device=device, requires_grad=True) 3219 3220 class A(torch.autograd.Function): 3221 @staticmethod 3222 def forward(ctx, x): 3223 ctx.save_for_backward(x) 3224 return x 3225 3226 @staticmethod 3227 def backward(ctx, gy): 3228 # The override can be literally anything 3229 override = (1, 2, 3) 3230 wrapped = torch._functorch.autograd_function.CtxWithSavedTensors( 3231 ctx, override 3232 ) 3233 assert wrapped.saved_tensors == override 3234 return gy 3235 3236 out = A.apply(x) 3237 out.backward() 3238 3239 def test_CtxWithSavedTensors_passthrough(self, device): 3240 x = torch.randn([], device=device, requires_grad=True) 3241 y = torch.randn([], device=device) 3242 3243 class A(torch.autograd.Function): 3244 @staticmethod 3245 def forward(ctx, x, y): 3246 ctx.save_for_backward(x, y) 3247 return x * y 3248 3249 @staticmethod 3250 def backward(ctx, gz): 3251 # The override can be literally anything 3252 override = (1, 2, 3) 3253 wrapped = torch._functorch.autograd_function.CtxWithSavedTensors( 3254 ctx, override 3255 ) 3256 3257 assert wrapped.needs_input_grad[0] == ctx.needs_input_grad[0] 3258 assert wrapped.needs_input_grad[1] == ctx.needs_input_grad[1] 3259 wrapped.foo = "bar" 3260 assert wrapped.foo == "bar" 3261 assert ctx.foo == "bar" 3262 return gz, gz 3263 3264 out = A.apply(x, y) 3265 out.backward() 3266 3267 def test_reductify_leaf(self, device): 3268 reductify_leaf = torch._functorch.autograd_function.reductify_leaf 3269 B = 2 3270 3271 # grad_input None case 3272 output = reductify_leaf(None, None, 0, B) 3273 self.assertIsNone(output) 3274 output = reductify_leaf(None, None, None, B) 3275 self.assertIsNone(output) 3276 3277 # grad_input has bdim, input does not have bdim 3278 grad_input = torch.randn([B, 3, 4], device=device) 3279 output = reductify_leaf(grad_input, 0, None, B) 3280 self.assertEqual(output, grad_input.sum(0)) 3281 3282 grad_input = torch.randn([3, B, 4], device=device) 3283 output = reductify_leaf(grad_input, 1, None, B, (3,)) 3284 self.assertEqual(output, grad_input.sum(1)) 3285 3286 # grad_input does not have bdim, input has bdim 3287 # This can happen if the user returns a fresh Tensor from the backward pass 3288 # that is unrelated to the input 3289 grad_input = torch.randn([3, 4], device=device) 3290 output = reductify_leaf(grad_input, None, 1, B) 3291 self.assertEqual(output, grad_input.view(3, 1, 4).expand(3, B, 4)) 3292 3293 grad_input = torch.randn([3, 4], device=device) 3294 output = reductify_leaf(grad_input, None, 1, B, (4,)) 3295 self.assertEqual(output, grad_input.view(3, 4, 1).expand(3, 4, B).sum(0)) 3296 3297 # grad_input has bdim, input has bdim 3298 grad_input = torch.randn([B, 3, 4], device=device) 3299 output 
= reductify_leaf(grad_input, 0, 1, B) 3300 self.assertEqual(output, grad_input.movedim(0, 1)) 3301 3302 grad_input = torch.randn([3, 4, 5, B], device=device) 3303 output = reductify_leaf(grad_input, 3, 0, B, (5,)) 3304 self.assertEqual(output, grad_input.movedim(-1, 2).sum(0).sum(0)) 3305 3306 3307@markDynamoStrictTest 3308class TestComposability(TestCase): 3309 def test_deprecation_vmap(self, device): 3310 x = torch.randn(3, device=device) 3311 3312 # functorch version of the API is deprecated 3313 with self.assertWarnsRegex(FutureWarning, "Please use `torch.vmap`"): 3314 vmap(torch.sin) 3315 3316 # the non-functorch version is not deprecated 3317 with warnings.catch_warnings(): 3318 warnings.simplefilter("error") 3319 torch.vmap(torch.sin) 3320 3321 # Some of these pass, some of these don't 3322 @parametrize( 3323 "transform", 3324 ["grad", "jacrev", "jacfwd", "grad_and_value", "hessian", "functionalize"], 3325 ) 3326 def test_deprecation_transforms(self, device, transform): 3327 api = getattr(functorch, transform) 3328 new_api = getattr(torch.func, transform) 3329 3330 # functorch version of the API is deprecated 3331 with self.assertWarnsRegex( 3332 FutureWarning, f"Please use `torch.func.{transform}`" 3333 ): 3334 api(torch.sin) 3335 3336 # the non-functorch version is not deprecated 3337 with warnings.catch_warnings(): 3338 warnings.simplefilter("error") 3339 new_api(torch.sin) 3340 3341 def test_grad_grad(self, device): 3342 x = torch.randn([], device=device) 3343 y = grad(grad(torch.sin))(x) 3344 self.assertEqual(y, -x.sin()) 3345 3346 def test_grad_vmap(self, device): 3347 def foo(x): 3348 y = vmap(torch.sin)(x) 3349 return y.sum() 3350 3351 x = torch.randn(3, device=device) 3352 y = grad(foo)(x) 3353 self.assertEqual(y, x.cos()) 3354 3355 def test_grad_vjp(self, device): 3356 x = torch.randn(3, device=device) 3357 3358 def foo(x): 3359 _, vjp_fn = vjp(torch.sin, x) 3360 return vjp_fn(x)[0].sum() 3361 3362 y = grad(foo)(x) 3363 expected = grad(lambda x: (x * x.cos()).sum())(x) 3364 self.assertEqual(y, expected) 3365 3366 def test_vmap_grad(self, device): 3367 x = torch.randn(3, device=device) 3368 y = vmap(grad(torch.sin))(x) 3369 self.assertEqual(y, x.cos()) 3370 3371 def test_vmap_vmap(self, device): 3372 x = torch.randn(2, 3, device=device) 3373 y = vmap(vmap(torch.sin))(x) 3374 self.assertEqual(y, x.sin()) 3375 3376 def test_vmap_vjp(self, device): 3377 x = torch.randn(3, device=device) 3378 _, vjp_fn = vjp(torch.sin, x) 3379 3380 def foo(x): 3381 _, vjp_fn = vjp(torch.sin, x) 3382 return vjp_fn(x) 3383 3384 y = vmap(foo)(x) 3385 self.assertEqual(y, vjp_fn(x)) 3386 3387 # TODO: there's a very interesting error message when the following 3388 # is on CPU 3389 xs = torch.randn(5, 3, device=device) 3390 expected = torch.stack([vjp_fn(x)[0] for x in xs]) 3391 result = vmap(lambda x: vjp_fn(x)[0])(xs) 3392 self.assertEqual(result, expected) 3393 3394 def test_vjp_grad(self, device): 3395 x = torch.randn([], device=device) 3396 y, vjp_fn = vjp(grad(torch.sin), x) 3397 self.assertEqual(y, x.cos()) 3398 3399 v = torch.randn([]) 3400 self.assertEqual(vjp_fn(v)[0], -x.sin() * v) 3401 3402 def test_vjp_vmap(self, device): 3403 x = torch.randn(3, device=device) 3404 y, vjp_fn = vjp(vmap(torch.sin), x) 3405 self.assertEqual(y, x.sin()) 3406 3407 v = torch.randn(3, device=device) 3408 self.assertEqual(vjp_fn(v)[0], x.cos() * v) 3409 3410 def test_vjp_vjp(self, device): 3411 x = torch.randn(3, device=device) 3412 y, vjp_fn = vjp(torch.sin, x) 3413 self.assertEqual(y, x.sin()) 3414 3415 y, 
vjp_fn = vjp(lambda x: vjp_fn(x)[0], x) 3416 self.assertEqual(y, x * x.cos()) 3417 3418 y = vjp_fn(x)[0] 3419 # Honestly IDK what the result here is... but at least it runs 3420 3421 def test_make_fx_vmap(self, device): 3422 def f(x): 3423 return torch.sin(x) 3424 3425 inp = torch.randn(5, 3) 3426 f = vmap(f) 3427 fx_f = make_fx(f)(inp) 3428 new_inp = torch.randn(5, 3) 3429 self.assertEqual(fx_f(new_inp), f(new_inp)) 3430 3431 def test_make_fx_jacrev(self, device): 3432 def f(x): 3433 return x.sin().sum() 3434 3435 inp = torch.randn(3) 3436 f = jacrev(jacrev(f)) 3437 fx_f = make_fx(f)(inp) 3438 new_inp = torch.randn(3) 3439 self.assertEqual(fx_f(new_inp), f(new_inp)) 3440 3441 def test_make_fx_vjp(self, device): 3442 def f(x): 3443 return torch.sin(x).sum() 3444 3445 primals = torch.randn(3) 3446 _, vjp_fn = vjp(f, primals) 3447 cotangent = torch.randn(()) 3448 fx_f = make_fx(vjp_fn)(cotangent, True, True) 3449 new_cotangent = torch.randn(()) 3450 self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent)) 3451 3452 # FIXME: test fails in Windows 3453 @unittest.skipIf(IS_WINDOWS, "fails in Windows; needs investigation") 3454 @unittest.skipIf(IS_FBCODE, "can't subprocess in fbcode") 3455 # it is redundant to run this test twice on a machine that has GPUs 3456 @onlyCPU 3457 def test_no_warning_on_import_functorch(self, device): 3458 out = subprocess.check_output( 3459 [sys.executable, "-W", "always", "-c", "import functorch"], 3460 stderr=subprocess.STDOUT, 3461 cwd=os.path.dirname(os.path.realpath(__file__)), 3462 ).decode("utf-8") 3463 self.assertEqual(out, "") 3464 3465 def test_requires_grad_inside_transform(self, device): 3466 def f(x): 3467 x.requires_grad_() 3468 return x.sin().sum() 3469 3470 x = torch.randn(3) 3471 3472 with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"): 3473 vmap(f)(x) 3474 with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"): 3475 grad(f)(x) 3476 with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"): 3477 vmap(grad(f))(x) 3478 3479 x = torch.randn([]) 3480 with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"): 3481 grad(grad(f))(x) 3482 3483 def test_retain_grad_inside_transform(self, device): 3484 def f(x): 3485 y = x.sin() 3486 y.retain_grad() 3487 return y.sum() 3488 3489 x = torch.randn(3) 3490 3491 with self.assertRaisesRegex(RuntimeError, "Tensor.retain_grad()"): 3492 grad(f)(x) 3493 3494 def test_autograd_functional_jacrev_inside_transform(self, device): 3495 def f(x): 3496 y = torch.autograd.functional.jacobian(lambda x: x.sin().sum(), x) 3497 return y 3498 3499 B = 5 3500 x = torch.randn(B, 3) 3501 with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"): 3502 vmap(f)(x) 3503 3504 x = torch.randn([]) 3505 with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"): 3506 grad(f)(x) 3507 3508 def test_autograd_functional_vjp_inside_transform(self, device): 3509 def f(x): 3510 y = torch.autograd.functional.vjp(lambda x: x.sin().sum(), x) 3511 return y 3512 3513 B = 5 3514 x = torch.randn(B, 3) 3515 with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"): 3516 vmap(f)(x) 3517 3518 x = torch.randn([]) 3519 with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"): 3520 grad(f)(x) 3521 3522 def test_autograd_functional_jvp_inside_transform(self, device): 3523 def f(x): 3524 t = torch.ones_like(x) 3525 y = torch.autograd.functional.jvp(lambda x: x.sin().sum(), (x,), (t,)) 3526 return y 3527 3528 B = 5 3529 x = torch.randn(B, 3) 3530 
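        # NOTE: Illustrative sketch (not part of the test suite). The make_fx
        # tests above capture an already-transformed function into an FX graph;
        # the resulting GraphModule can then be re-run on fresh inputs:
        #
        #   from functorch import make_fx, vmap
        #   gm = make_fx(vmap(torch.sin))(torch.randn(5, 3))
        #   x = torch.randn(5, 3)
        #   assert torch.allclose(gm(x), torch.sin(x))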
with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"): 3531 vmap(f)(x) 3532 3533 x = torch.randn([]) 3534 with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"): 3535 grad(f)(x) 3536 3537 def test_autograd_functional_jacfwd_inside_transform(self, device): 3538 def f(x): 3539 y = torch.autograd.functional.jacobian( 3540 lambda x: x.sin().sum(), x, strategy="forward-mode", vectorize=True 3541 ) 3542 return y 3543 3544 B = 5 3545 x = torch.randn(B, 3) 3546 with self.assertRaisesRegex( 3547 RuntimeError, "Batching rule not implemented for aten::_make_dual" 3548 ): 3549 vmap(f)(x) 3550 3551 @parametrize( 3552 "transform", 3553 [ 3554 "vmap", 3555 "grad", 3556 "jacrev", 3557 "jacfwd", 3558 "grad_and_value", 3559 "hessian", 3560 "functionalize", 3561 ], 3562 ) 3563 def test_autograd_function_no_setup_context(self, device, transform): 3564 class MySin(torch.autograd.Function): 3565 @staticmethod 3566 def forward(ctx, x): 3567 ctx.save_for_backward(x) 3568 return x.sin() 3569 3570 @staticmethod 3571 def backward(ctx, gy): 3572 (x,) = ctx.saved_tensors 3573 return gy * x.cos() 3574 3575 x = torch.randn(3, device=device) 3576 transform = getattr(functorch, transform) 3577 with self.assertRaisesRegex(RuntimeError, "must override the setup_context"): 3578 transform(MySin.apply)(x) 3579 3580 # Some of these pass, some of these don't 3581 @parametrize( 3582 "transform", 3583 [ 3584 "grad", 3585 "jacrev", 3586 "grad_and_value", 3587 "hessian", 3588 ], 3589 ) 3590 def test_transforms_dont_support_saved_tensor_hooks(self, device, transform): 3591 def f(x): 3592 return torch.sin(x).sum() 3593 3594 def g(x): 3595 with torch.autograd.graph.save_on_cpu(): 3596 return f(x) 3597 3598 x = torch.randn(3, device=device) 3599 3600 if transform == "functionalize": 3601 transform = functorch.experimental.functionalize 3602 else: 3603 transform = getattr(functorch, transform) 3604 with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"): 3605 with torch.autograd.graph.save_on_cpu(): 3606 transform(f)(x) 3607 3608 with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"): 3609 transform(g)(x) 3610 3611 def test_vjp_doesnt_support_saved_tensor_hooks(self, device): 3612 def f(x): 3613 return torch.sin(x).sum() 3614 3615 def g(x): 3616 with torch.autograd.graph.save_on_cpu(): 3617 return f(x) 3618 3619 x = torch.randn(3, device=device) 3620 with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"): 3621 with torch.autograd.graph.save_on_cpu(): 3622 vjp(f, x) 3623 3624 with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"): 3625 vjp(g, x) 3626 3627 def test_jvp_supports_saved_tensor_hooks(self, device): 3628 def f(x): 3629 return torch.sin(x).sum() 3630 3631 def g(x): 3632 with torch.autograd.graph.save_on_cpu(): 3633 return f(x) 3634 3635 x = torch.randn(3, device=device) 3636 t = torch.randn(3, device=device) 3637 3638 # smoke tests 3639 with torch.autograd.graph.save_on_cpu(): 3640 jvp(f, (x,), (t,)) 3641 3642 # smoke tests 3643 jvp(g, (x,), (t,)) 3644 3645 def test_can_use_functionalize_when_key_is_excluded(self, device): 3646 def f(x): 3647 y = x.clone() 3648 y.sin_() 3649 return y 3650 3651 x = torch.randn([], device=device) 3652 expected = f(x) 3653 3654 with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)): 3655 gm = make_fx(functorch.functionalize(f))(x) 3656 self.assertTrue("sin_" not in gm.code) 3657 self.assertEqual(gm(x), expected) 3658 3659 local_exclude_set = torch._C._dispatch_tls_local_exclude_set() 3660 
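        # NOTE: Illustrative sketch (not part of the test suite). Summary of the
        # saved-tensor-hook behavior checked above: reverse-mode transforms
        # (grad, vjp, jacrev, ...) reject torch.autograd.graph.save_on_cpu(),
        # while jvp is expected to work, since forward-mode AD does not go
        # through the saved-tensor machinery.
        #
        #   from torch.func import grad, jvp
        #   x, t = torch.randn(3), torch.randn(3)
        #   with torch.autograd.graph.save_on_cpu():
        #       jvp(torch.sin, (x,), (t,))            # ok
        #       # grad(lambda x: x.sin().sum())(x)    # raises: saved tensor hooks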
self.assertTrue(local_exclude_set.has(DispatchKey.Functionalize)) 3661 3662 def test_can_use_vmap_when_key_is_excluded(self, device): 3663 def f(x): 3664 return x.sum(0) 3665 3666 x = torch.randn(3, device=device) 3667 expected = vmap(f)(x) 3668 3669 with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.FuncTorchBatched)): 3670 result = vmap(f)(x) 3671 self.assertEqual(result, expected) 3672 local_exclude_set = torch._C._dispatch_tls_local_exclude_set() 3673 self.assertTrue(local_exclude_set.has(DispatchKey.FuncTorchBatched)) 3674 3675 def test_can_use_grad_when_key_is_excluded(self, device): 3676 def f(x): 3677 return x.sin() 3678 3679 x = torch.randn([], device=device) 3680 expected = grad(f)(x) 3681 3682 with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Autograd)): 3683 result = grad(f)(x) 3684 self.assertEqual(result, expected) 3685 local_exclude_set = torch._C._dispatch_tls_local_exclude_set() 3686 self.assertTrue(local_exclude_set.has(DispatchKey.Autograd)) 3687 3688 3689@markDynamoStrictTest 3690class TestMakeFunctional(TestCase): 3691 @parametrize("disable_autograd_tracking", [True, False]) 3692 def test_disable_autograd_tracking(self, disable_autograd_tracking): 3693 class Foo(nn.Module): 3694 def __init__(self) -> None: 3695 super().__init__() 3696 self.linear = nn.Linear(3, 3) 3697 3698 def forward(self, x): 3699 x = self.linear(x) 3700 return x 3701 3702 mod = Foo() 3703 _, params = make_functional( 3704 mod, disable_autograd_tracking=disable_autograd_tracking 3705 ) 3706 self.assertEqual(len(params), 2) 3707 for param in params: 3708 self.assertEqual(param.requires_grad, not disable_autograd_tracking) 3709 3710 def test_parameter_tying(self): 3711 class Foo(nn.Module): 3712 def __init__(self) -> None: 3713 super().__init__() 3714 self.bias = nn.Parameter(torch.randn(3)) 3715 self.linear = nn.Linear(3, 3) 3716 self.linear.bias = self.bias 3717 self.linear_tied = self.linear 3718 3719 def forward(self, x): 3720 x = self.linear(x) 3721 x = self.linear_tied(x) 3722 x = x + self.bias 3723 return x 3724 3725 torch.manual_seed(1) 3726 mod = Foo() 3727 func, _ = make_functional(mod) 3728 3729 torch.manual_seed(0) 3730 mod = Foo() 3731 _, params = make_functional(mod) 3732 self.assertEqual(len(params), 2) 3733 3734 x = torch.randn(2, 3) 3735 result = func(params, x) 3736 expected = mod(x) 3737 self.assertEqual(result, expected) 3738 3739 def test_buffer_tying(self): 3740 class Foo(nn.Module): 3741 def __init__(self) -> None: 3742 super().__init__() 3743 self.bias = nn.Parameter(torch.randn(3)) 3744 self.linear = nn.Linear(3, 3) 3745 self.buffer = nn.Buffer(torch.randn(3)) 3746 self.buffer_tied = self.buffer 3747 3748 def forward(self, x): 3749 x = self.linear(x) 3750 x = x + self.bias 3751 x = x + self.buffer 3752 x = x + self.buffer_tied 3753 return x 3754 3755 torch.manual_seed(1) 3756 mod = Foo() 3757 func, _, _ = make_functional_with_buffers(mod) 3758 3759 torch.manual_seed(0) 3760 mod = Foo() 3761 _, params, buffers = make_functional_with_buffers(mod) 3762 self.assertEqual(len(params), 3) 3763 self.assertEqual(len(buffers), 1) 3764 3765 x = torch.randn(2, 3) 3766 result = func(params, buffers, x) 3767 expected = mod(x) 3768 self.assertEqual(result, expected) 3769 3770 @parametrize("disable_autograd_tracking", [True, False]) 3771 def test_with_buffers_disable_autograd_tracking(self, disable_autograd_tracking): 3772 class Foo(nn.Module): 3773 def __init__(self) -> None: 3774 super().__init__() 3775 self.linear = nn.Linear(3, 3) 3776 self.buffer = nn.Buffer(torch.randn(3)) 
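        # NOTE: Illustrative sketch (not part of the test suite). The
        # "functional_call" mechanism used throughout these tests runs a module
        # with externally supplied parameters/buffers instead of its own state:
        #
        #   from torch.func import functional_call
        #   mod = nn.Linear(3, 3)
        #   params = dict(mod.named_parameters())
        #   x = torch.randn(2, 3)
        #   out = functional_call(mod, params, (x,))  # same result as mod(x)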
3777 3778 def forward(self, x): 3779 x = self.linear(x) 3780 x = x + self.buffer 3781 return x 3782 3783 mod = Foo() 3784 _, params, buffers = make_functional_with_buffers( 3785 mod, disable_autograd_tracking=disable_autograd_tracking 3786 ) 3787 self.assertEqual(len(params), 2) 3788 self.assertEqual(len(buffers), 1) 3789 for param in params: 3790 self.assertEqual(param.requires_grad, not disable_autograd_tracking) 3791 3792 @parametrize("detach_params", [True, False]) 3793 def test_using_detach_functional_call(self, detach_params): 3794 class Foo(nn.Module): 3795 def __init__(self) -> None: 3796 super().__init__() 3797 self.linear = nn.Linear(3, 3) 3798 self.buffer = nn.Buffer(torch.randn(3)) 3799 3800 def forward(self, x): 3801 x = self.linear(x) 3802 x = x + self.buffer 3803 return x 3804 3805 def params_dict(mod): 3806 named_params = mod.named_parameters() 3807 return ( 3808 {k: v.detach() for k, v in named_params} 3809 if detach_params 3810 else dict(named_params) 3811 ) 3812 3813 mod = Foo() 3814 x = torch.randn(3, 3) 3815 d = (params_dict(mod), dict(mod.named_buffers())) 3816 out = functional_call(mod, d, x) 3817 self.assertEqual(out.grad_fn is None, detach_params) 3818 3819 def test_parameter_tying_grad(self): 3820 class Foo(nn.Module): 3821 def __init__(self) -> None: 3822 super().__init__() 3823 self.linear = nn.Linear(3, 3) 3824 self.weight = self.linear.weight 3825 self.bias = self.linear.bias 3826 3827 def forward(self, x): 3828 x = self.linear(x) 3829 x = F.linear(x, self.weight, self.bias) 3830 return x 3831 3832 x = torch.randn(2, 3) 3833 torch.manual_seed(0) 3834 mod = Foo() 3835 loss = mod(x).sum() 3836 expected = torch.autograd.grad(loss, mod.parameters()) 3837 3838 mod = Foo() 3839 fmod, _, _ = make_functional_with_buffers(mod) 3840 torch.manual_seed(0) 3841 mod = Foo() 3842 _, params, buffers = make_functional_with_buffers(mod) 3843 3844 def compute_loss(params, buffers, x): 3845 return fmod(params, buffers, x).sum() 3846 3847 result = grad(compute_loss)(params, buffers, x) 3848 3849 self.assertEqual(result, expected) 3850 3851 def test_parameter_tying_ensemble(self): 3852 class Foo(nn.Module): 3853 def __init__(self) -> None: 3854 super().__init__() 3855 self.linear = nn.Linear(3, 3) 3856 self.weight = self.linear.weight 3857 self.bias = self.linear.bias 3858 self.buffer = nn.Buffer(torch.randn(3)) 3859 self.buffer_tied = self.buffer 3860 3861 def forward(self, x): 3862 x = self.linear(x) 3863 x = F.linear(x, self.weight, self.bias) 3864 x = x + self.buffer 3865 x = x + self.buffer_tied 3866 return x 3867 3868 num_models = 2 3869 xs = torch.randn(num_models, 64, 3) 3870 models = [Foo() for _ in range(num_models)] 3871 fmodel, _, _ = combine_state_for_ensemble(models) 3872 3873 torch.manual_seed(0) 3874 models = [Foo() for _ in range(num_models)] 3875 _, params, buffers = combine_state_for_ensemble(models) 3876 result = vmap(fmodel)(params, buffers, xs) 3877 3878 torch.manual_seed(0) 3879 models = [Foo() for _ in range(num_models)] 3880 expected = torch.stack([model(x) for model, x in zip(models, xs)]) 3881 3882 self.assertEqual(result, expected) 3883 3884 @parametrize("mechanism", ["make_functional", "functional_call"]) 3885 def test_correctness_mnist(self, mechanism): 3886 class Net(nn.Module): 3887 def __init__(self) -> None: 3888 super().__init__() 3889 self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 3890 self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 3891 self.conv2_drop = nn.Dropout2d() 3892 self.fc1 = nn.Linear(320, 50) 3893 self.fc2 = nn.Linear(50, 10) 3894 3895 
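        # NOTE: Illustrative sketch (not part of the test suite). The ensembling
        # tests in this class stack per-model state and vmap over it; the
        # torch.func recipe looks roughly like this (the "meta"-device base copy
        # avoids allocating a third set of real weights):
        #
        #   from torch.func import stack_module_state, functional_call, vmap
        #   models = [nn.Linear(3, 3) for _ in range(2)]
        #   params, buffers = stack_module_state(models)
        #   base = copy.deepcopy(models[0]).to("meta")
        #
        #   def fmodel(params, buffers, x):
        #       return functional_call(base, (params, buffers), (x,))
        #
        #   xs = torch.randn(2, 64, 3)
        #   ys = vmap(fmodel)(params, buffers, xs)  # one forward pass per model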
def forward(self, x): 3896 x = F.relu(F.max_pool2d(self.conv1(x), 2)) 3897 x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 3898 x = x.view(-1, 320) 3899 x = F.relu(self.fc1(x)) 3900 x = F.dropout(x, training=self.training) 3901 x = self.fc2(x) 3902 return F.log_softmax(x) 3903 3904 x = torch.randn(64, 1, 32, 32) 3905 torch.manual_seed(301) 3906 fnet, _ = _get_weights_and_functional_call(Net(), mechanism) 3907 3908 torch.manual_seed(0) 3909 _, params = _get_weights_and_functional_call(Net(), mechanism) 3910 result = fnet(params, x) 3911 3912 torch.manual_seed(0) 3913 net = Net() 3914 expected = net(x) 3915 3916 self.assertEqual(result, expected) 3917 3918 def test_combine_state_for_ensemble_error(self): 3919 in_features = 2 3920 out_features = 2 3921 3922 models = [] 3923 with self.assertRaisesRegex(RuntimeError, "Expected at least one model"): 3924 _ = combine_state_for_ensemble(models) 3925 3926 num_models = 3 3927 models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] 3928 models[1].eval() 3929 with self.assertRaisesRegex(RuntimeError, "same training/eval mode"): 3930 _ = combine_state_for_ensemble(models) 3931 3932 models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] 3933 models[1] = torch.nn.Conv2d(3, 3, (3, 3)) 3934 with self.assertRaisesRegex(RuntimeError, "models to be of the same class"): 3935 _ = combine_state_for_ensemble(models) 3936 3937 def test_combine_state_for_ensemble_smoke(self): 3938 in_features = 2 3939 out_features = 2 3940 num_models = 3 3941 models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] 3942 _ = combine_state_for_ensemble(models) 3943 3944 def test_stack_module_state_smoke(self): 3945 in_features = 2 3946 out_features = 2 3947 num_models = 3 3948 models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] 3949 _ = stack_module_state(models) 3950 3951 def test_stack_module_state_leaf(self): 3952 in_features = 2 3953 out_features = 2 3954 num_models = 3 3955 models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] 3956 params, buffers = stack_module_state(models) 3957 for param in params.values(): 3958 self.assertTrue(param.requires_grad) 3959 self.assertTrue(param.is_leaf) 3960 3961 def test_stack_module_state_mismatch_error(self): 3962 in_features = 2 3963 out_features = 2 3964 num_models = 3 3965 models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] 3966 models[0].weight.requires_grad_(False) 3967 with self.assertRaisesRegex(RuntimeError, "same .requires_grad"): 3968 params, buffers = stack_module_state(models) 3969 3970 def test_stack_module_state_error(self): 3971 in_features = 2 3972 out_features = 2 3973 3974 models = [] 3975 with self.assertRaisesRegex( 3976 RuntimeError, "stack_module_state:.* Expected at least one model" 3977 ): 3978 _ = stack_module_state(models) 3979 3980 num_models = 3 3981 models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] 3982 models[1].eval() 3983 with self.assertRaisesRegex( 3984 RuntimeError, "stack_module_state:.* same training/eval mode." 
3985 ): 3986 _ = stack_module_state(models) 3987 3988 models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] 3989 models[1] = torch.nn.Conv2d(3, 3, (3, 3)) 3990 with self.assertRaisesRegex( 3991 RuntimeError, "stack_module_state:.* models to be of the same class" 3992 ): 3993 _ = stack_module_state(models) 3994 3995 @parametrize("mechanism", ["make_functional", "functional_call"]) 3996 def test_make_functional_state_correctly_returned_after_forward(self, mechanism): 3997 class Net(nn.Module): 3998 def __init__(self) -> None: 3999 super().__init__() 4000 self.linear = nn.Linear(3, 3) 4001 4002 def forward(self, x): 4003 x = self.linear(x) 4004 return x 4005 4006 def get_module_info(mod): 4007 if mechanism == "make_functional": 4008 return make_functional(mod) 4009 else: 4010 assert mechanism == "functional_call" 4011 return mod, dict(mod.named_parameters()) 4012 4013 mod = Net() 4014 func_mod, params = get_module_info(mod) 4015 4016 # state in func.names_map 4017 mod = func_mod.stateless_model if mechanism == "make_functional" else func_mod 4018 old_state_linear_weight = mod.linear.weight 4019 old_state_linear_bias = mod.linear.bias 4020 4021 self.assertIsNotNone(old_state_linear_weight) 4022 self.assertIsNotNone(old_state_linear_bias) 4023 4024 x = torch.randn(4, 3) 4025 if mechanism == "make_functional": 4026 func_mod(params, x) 4027 else: 4028 assert mechanism == "functional_call" 4029 functional_call(func_mod, params, x) 4030 4031 mod = func_mod.stateless_model if mechanism == "make_functional" else func_mod 4032 new_state_linear_weight = mod.linear.weight 4033 new_state_linear_bias = mod.linear.bias 4034 4035 self.assertIsNotNone(new_state_linear_weight) 4036 self.assertIsNotNone(new_state_linear_bias) 4037 4038 self.assertEqual(old_state_linear_weight, new_state_linear_weight) 4039 self.assertEqual(old_state_linear_bias, new_state_linear_bias) 4040 4041 4042@markDynamoStrictTest 4043class TestExamplesCorrectness(TestCase): 4044 def _update_params(self, params, grads, alpha, mechanism): 4045 if mechanism == "make_functional": 4046 return [(params[i] - alpha * grads[i]) for i in range(len(params))] 4047 else: 4048 assert mechanism == "functional_call" 4049 return {k: params[k] - alpha * grads[k] for k in params} 4050 4051 @parametrize("mechanism", ["make_functional", "functional_call"]) 4052 def test_maml_regression(self, device, mechanism): 4053 class ThreeLayerNet(nn.Module): 4054 def __init__(self) -> None: 4055 super().__init__() 4056 self.fc1 = nn.Linear(1, 40) 4057 self.relu1 = nn.ReLU() 4058 self.fc2 = nn.Linear(40, 40) 4059 self.relu2 = nn.ReLU() 4060 self.fc3 = nn.Linear(40, 1) 4061 4062 def forward(self, x): 4063 x = self.fc1(x) 4064 x = self.relu1(x) 4065 x = self.fc2(x) 4066 x = self.relu2(x) 4067 x = self.fc3(x) 4068 return x 4069 4070 # TODO: should replace with F.mse_loss 4071 def mse_loss(x, y): 4072 return torch.mean((x - y) ** 2) 4073 4074 net, params = _get_weights_and_functional_call( 4075 ThreeLayerNet().to(device), mechanism 4076 ) 4077 K = 20 4078 num_tasks = 4 4079 alpha = 0.1 4080 4081 def sample_tasks(outer_batch_size, inner_batch_size): 4082 # Select amplitude and phase for the task 4083 As = [] 4084 phases = [] 4085 for _ in range(outer_batch_size): 4086 As.append(np.random.uniform(low=0.1, high=0.5)) 4087 phases.append(np.random.uniform(low=0.0, high=np.pi)) 4088 4089 def get_batch(): 4090 xs, ys = [], [] 4091 for A, phase in zip(As, phases): 4092 x = np.random.uniform( 4093 low=-5.0, high=5.0, size=(inner_batch_size, 1) 4094 ) 4095 
y = A * np.sin(x + phase) 4096 xs.append(x) 4097 ys.append(y) 4098 return torch.tensor(xs, dtype=torch.float, device=device), torch.tensor( 4099 ys, dtype=torch.float, device=device 4100 ) 4101 4102 x1, y1 = get_batch() 4103 x2, y2 = get_batch() 4104 return x1, y1, x2, y2 4105 4106 def get_loss_for_task(use_transform, x1, y1, x2, y2): 4107 def inner_loss(params, x1, y1): 4108 f = net(params, x1) 4109 loss = mse_loss(f, y1) 4110 return loss 4111 4112 if use_transform: 4113 grads = grad(inner_loss)(params, x1, y1) 4114 else: 4115 loss = inner_loss(params, x1, y1) 4116 grad_params, spec = tree_flatten(params) 4117 grads = torch.autograd.grad(loss, grad_params, create_graph=True) 4118 grads = tree_unflatten(grads, spec) 4119 4120 new_params = self._update_params(params, grads, alpha, mechanism) 4121 4122 v_f = net(new_params, x2) 4123 return mse_loss(v_f, y2) 4124 4125 task = sample_tasks(num_tasks, K) 4126 list_params = ( 4127 params if mechanism == "make_functional" else list(params.values()) 4128 ) 4129 4130 # Compute with vmap+grad 4131 inner_losses = vmap(partial(get_loss_for_task, True))( 4132 task[0], task[1], task[2], task[3] 4133 ) 4134 loss2 = sum(inner_losses) / len(inner_losses) 4135 result_grads = torch.autograd.grad(loss2, list_params) 4136 4137 # Compute without vmap+grad 4138 inner_losses = [ 4139 get_loss_for_task(False, task[0][i], task[1][i], task[2][i], task[3][i]) 4140 for i in range(num_tasks) 4141 ] 4142 loss2 = sum(inner_losses) / len(inner_losses) 4143 expected_grads = torch.autograd.grad(loss2, list_params) 4144 4145 self.assertEqual(result_grads, expected_grads) 4146 4147 @parametrize("mechanism", ["make_functional", "functional_call"]) 4148 def test_maml_omniglot(self, device, mechanism): 4149 # TODO: there appears to be precision issues for float32 4150 dtype = torch.double 4151 4152 # TODO: We don't support inplace relu? 4153 inplace_relu = False 4154 n_way = 5 4155 n_inner_iter = 2 4156 num_tasks = 2 4157 4158 # real example uses batch norm but it's numerically unstable in the first 4159 # iteration, when near 0, and won't produce same gradients. 
Uses group norm instead 4160 net = ( 4161 nn.Sequential( 4162 nn.Conv2d(1, 64, 3), 4163 nn.GroupNorm(64, 64, affine=True), 4164 nn.ReLU(inplace=inplace_relu), 4165 nn.MaxPool2d(2, 2), 4166 nn.Conv2d(64, 64, 3), 4167 nn.GroupNorm(64, 64, affine=True), 4168 nn.ReLU(inplace=inplace_relu), 4169 nn.MaxPool2d(2, 2), 4170 nn.Conv2d(64, 64, 3), 4171 nn.GroupNorm(64, 64, affine=True), 4172 nn.ReLU(inplace=inplace_relu), 4173 nn.MaxPool2d(2, 2), 4174 nn.Flatten(), 4175 nn.Linear(64, n_way), 4176 ) 4177 .to(device) 4178 .to(dtype) 4179 ) 4180 4181 fnet, params, buffers = _get_weights_and_functional_call_with_buffers( 4182 net, mechanism 4183 ) 4184 net = (params, buffers, fnet) 4185 4186 def loss_for_task(net, n_inner_iter, use_transform, x_spt, y_spt, x_qry, y_qry): 4187 params, buffers, fnet = net 4188 querysz = x_qry.size(0) 4189 4190 def compute_loss(new_params, buffers, x, y): 4191 logits = fnet(new_params, buffers, x) 4192 loss = F.cross_entropy(logits, y) 4193 return loss 4194 4195 new_params = params 4196 for _ in range(n_inner_iter): 4197 if use_transform: 4198 grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt) 4199 else: 4200 res = compute_loss(new_params, buffers, x_spt, y_spt) 4201 grad_params, spec = tree_flatten(new_params) 4202 grads = torch.autograd.grad(res, grad_params, create_graph=True) 4203 grads = tree_unflatten(grads, spec) 4204 4205 new_params = self._update_params(new_params, grads, 1e-1, mechanism) 4206 4207 qry_logits = fnet(new_params, buffers, x_qry) 4208 qry_loss = F.cross_entropy(qry_logits, y_qry) 4209 qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz 4210 4211 return qry_loss, qry_acc 4212 4213 # Get some sample inputs... 4214 x_spt = torch.randn(num_tasks, 25, 1, 28, 28, dtype=dtype, device=device) 4215 y_spt = torch.randint(0, 5, (num_tasks, 25), device=device) 4216 x_qry = torch.randn(num_tasks, 75, 1, 28, 28, dtype=dtype, device=device) 4217 y_qry = torch.randint(0, 5, (num_tasks, 75), device=device) 4218 4219 # compute with vmap + grad 4220 compute_loss = partial(loss_for_task, net, n_inner_iter, True) 4221 qry_losses, _ = vmap(compute_loss)(x_spt, y_spt, x_qry, y_qry) 4222 list_params = ( 4223 params if mechanism == "make_functional" else list(params.values()) 4224 ) 4225 result_grads = torch.autograd.grad(qry_losses.sum(), list_params) 4226 4227 # compute without vmap + grad 4228 compute_loss = partial(loss_for_task, net, n_inner_iter, False) 4229 losses = [ 4230 compute_loss(x_spt[i], y_spt[i], x_qry[i], y_qry[i])[0] 4231 for i in range(num_tasks) 4232 ] 4233 expected_grads = torch.autograd.grad(sum(losses), list_params) 4234 4235 self.assertEqual(result_grads, expected_grads) 4236 4237 @parametrize("mechanism", ["make_functional", "functional_call"]) 4238 @parametrize("originally_track_running_stats", [True, False]) 4239 def test_update_batch_norm(self, device, originally_track_running_stats, mechanism): 4240 dtype = torch.double 4241 inplace_relu = False 4242 classes = 5 4243 num_batches = 2 4244 net = ( 4245 nn.Sequential( 4246 nn.Conv2d(64, 64, 3), 4247 nn.BatchNorm2d( 4248 64, affine=True, track_running_stats=originally_track_running_stats 4249 ), 4250 nn.ReLU(inplace=inplace_relu), 4251 nn.Flatten(), 4252 nn.Linear(43264, classes), 4253 ) 4254 .to(device) 4255 .to(dtype) 4256 ) 4257 4258 replace_all_batch_norm_modules_(net) 4259 transformed_net = net 4260 fnet, params, buffers = _get_weights_and_functional_call_with_buffers( 4261 transformed_net, mechanism 4262 ) 4263 criterion = nn.CrossEntropyLoss() 4264 4265 def compute_loss(x, 
y, params, buffers): 4266 return criterion(fnet(params, buffers, x), y) 4267 4268 # Get some sample inputs... 4269 x = torch.randn(num_batches, 1, 64, 28, 28, device=device, dtype=dtype) 4270 y = torch.randint(0, classes, (num_batches, 1), device=device) 4271 4272 # compute some per sample grads with vmap + grad 4273 result_grads = vmap(grad(compute_loss, argnums=2), in_dims=(0, 0, None, None))( 4274 x, y, params, buffers 4275 ) 4276 4277 # compute some per sample grads without vmap + grad 4278 fnet, params, buffers = _get_weights_and_functional_call_with_buffers( 4279 transformed_net, mechanism 4280 ) 4281 flat_params, spec = tree_flatten(params) 4282 expected_grads = [ 4283 torch.autograd.grad(compute_loss(x[i], y[i], params, buffers), flat_params) 4284 for i in range(num_batches) 4285 ] 4286 expected_grads = [torch.stack(shards) for shards in zip(*expected_grads)] 4287 expected_grads = tree_unflatten(expected_grads, spec) 4288 4289 self.assertEqual(result_grads, expected_grads) 4290 4291 @parametrize("jac", ["jacfwd", "jacrev"]) 4292 def test_lennard_jones_batched_jac(self, device, jac): 4293 sigma = 0.5 4294 epsilon = 4.0 4295 4296 jac = getattr(functorch, jac) 4297 4298 def lennard_jones(r): 4299 return epsilon * ((sigma / r) ** 12 - (sigma / r) ** 6) 4300 4301 def lennard_jones_force(r): 4302 """Get magnitude of LJ force""" 4303 return -epsilon * ( 4304 (-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7) 4305 ) 4306 4307 r = torch.linspace(0.5, 2 * sigma, steps=100, requires_grad=True, device=device) 4308 drs = torch.outer(r, torch.tensor([1.0, 0, 0], device=device)) 4309 norms = torch.norm(drs, dim=1).reshape(-1, 1) 4310 training_energies = torch.stack(list(map(lennard_jones, norms))).reshape(-1, 1) 4311 training_forces = torch.stack( 4312 [force * dr for force, dr in zip(map(lennard_jones_force, norms), drs)] 4313 ) 4314 4315 model = nn.Sequential( 4316 nn.Linear(1, 16), 4317 nn.Tanh(), 4318 nn.Linear(16, 16), 4319 nn.Tanh(), 4320 nn.Linear(16, 16), 4321 nn.Tanh(), 4322 nn.Linear(16, 16), 4323 nn.Tanh(), 4324 nn.Linear(16, 1), 4325 ).to(device) 4326 4327 def make_prediction(model, drs, use_functorch): 4328 norms = torch.norm(drs, dim=1).reshape(-1, 1) 4329 energies = model(norms) 4330 4331 if use_functorch: 4332 network_derivs = vmap(jac(model))(norms).squeeze(-1) 4333 forces = -network_derivs * drs / norms 4334 else: 4335 forces = [] 4336 for r, dr in zip(norms, drs): 4337 network_deriv = torch.autograd.functional.jacobian( 4338 model, r, create_graph=True 4339 ) 4340 force = -network_deriv * dr / r 4341 forces.append(force) 4342 forces = torch.cat(forces) 4343 return energies, forces 4344 4345 def loss_fn(energies, forces, predicted_energies, predicted_forces): 4346 return ( 4347 F.mse_loss(energies, predicted_energies) 4348 + 0.01 * F.mse_loss(forces, predicted_forces) / 3 4349 ) 4350 4351 energies, forces = make_prediction(model, drs, use_functorch=True) 4352 loss = loss_fn(training_energies, training_forces, energies, forces) 4353 result = torch.autograd.grad(loss, model.parameters()) 4354 4355 energies, forces = make_prediction(model, drs, use_functorch=False) 4356 loss = loss_fn(training_energies, training_forces, energies, forces) 4357 expected = torch.autograd.grad(loss, model.parameters()) 4358 4359 self.assertEqual(result, expected) 4360 4361 @parametrize("mechanism", ["make_functional", "functional_call"]) 4362 def test_ensemble_regression(self, device, mechanism): 4363 def make_spirals(n_samples, noise_std=0.0, rotations=1.0): 4364 ts = torch.linspace(0, 1, n_samples) 
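            # Two interleaved spirals: the radius grows as sqrt(t), a random
            # sign flips each point onto the opposite branch, and that same
            # sign is reused as the class label below.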
4365 rs = ts**0.5 4366 thetas = rs * rotations * 2 * math.pi 4367 signs = torch.randint(0, 2, (n_samples,)) * 2 - 1 4368 labels = (signs > 0).to(torch.long) 4369 4370 xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples) * noise_std 4371 ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples) * noise_std 4372 points = torch.stack([xs, ys], dim=1) 4373 return points.to(device), labels.to(device) 4374 4375 points, labels = make_spirals(100, noise_std=0.05) 4376 4377 class MLPClassifier(nn.Module): 4378 def __init__(self, hidden_dim=32, n_classes=2): 4379 super().__init__() 4380 self.hidden_dim = hidden_dim 4381 self.n_classes = n_classes 4382 4383 self.fc1 = nn.Linear(2, self.hidden_dim) 4384 self.fc2 = nn.Linear(self.hidden_dim, self.n_classes) 4385 4386 def forward(self, x): 4387 x = self.fc1(x) 4388 x = F.relu(x) 4389 x = self.fc2(x) 4390 x = F.log_softmax(x, -1) 4391 return x 4392 4393 loss_fn = nn.NLLLoss() 4394 4395 func_model, weights = _get_weights_and_functional_call( 4396 MLPClassifier().to(device), mechanism 4397 ) 4398 4399 def train_step_fn(use_transform, weights, batch, targets, lr=0.2): 4400 def compute_loss(weights, batch, targets): 4401 output = func_model(weights, batch) 4402 loss = loss_fn(output, targets) 4403 return loss 4404 4405 if use_transform: 4406 grad_weights, loss = grad_and_value(compute_loss)( 4407 weights, batch, targets 4408 ) 4409 else: 4410 loss = compute_loss(weights, batch, targets) 4411 flat_weights, spec = tree_flatten(weights) 4412 flat_grad_weights = torch.autograd.grad(loss, flat_weights) 4413 grad_weights = tree_unflatten(flat_grad_weights, spec) 4414 4415 new_weights = self._update_params(weights, grad_weights, lr, mechanism) 4416 return (loss, new_weights) 4417 4418 def unpack(train_result): 4419 return train_result[0], train_result[1] 4420 4421 def init_fn(num_models): 4422 models = tuple(MLPClassifier().to(device) for _ in range(num_models)) 4423 if mechanism == "make_functional": 4424 return combine_state_for_ensemble(models)[1] 4425 else: 4426 return stack_module_state(models)[0] 4427 4428 def slice_weights(batched_weights, index): 4429 return tree_map( 4430 lambda weight: weight[index].detach().requires_grad_(), batched_weights 4431 ) 4432 4433 batched_weights = init_fn(num_models=2) 4434 parallel_train_step_fn = vmap( 4435 partial(train_step_fn, True), in_dims=(0, None, None) 4436 ) 4437 4438 result_loss, result_weights = unpack( 4439 parallel_train_step_fn(batched_weights, points, labels) 4440 ) 4441 4442 loss0, weights0 = unpack( 4443 train_step_fn(False, slice_weights(batched_weights, 0), points, labels) 4444 ) 4445 loss1, weights1 = unpack( 4446 train_step_fn(False, slice_weights(batched_weights, 1), points, labels) 4447 ) 4448 expected_loss = torch.stack([loss0, loss1]) 4449 4450 weights0, spec0 = tree_flatten(weights0) 4451 weights1, spec1 = tree_flatten(weights1) 4452 assert spec0 == spec1 4453 expected_weights = tuple( 4454 torch.stack([w0, w1]) for w0, w1 in zip(weights0, weights1) 4455 ) 4456 expected_weights = tree_unflatten(expected_weights, spec0) 4457 4458 self.assertEqual(result_loss, expected_loss) 4459 self.assertEqual(result_weights, expected_weights) 4460 4461 @parametrize( 4462 "dropout_layer", 4463 [ 4464 subtest(nn.Dropout, "Dropout"), 4465 subtest(nn.AlphaDropout, "AlphaDropout"), 4466 subtest(nn.FeatureAlphaDropout, "FeatureAlphaDropout"), 4467 ], 4468 ) 4469 @parametrize("mechanism", ["make_functional", "functional_call"]) 4470 def test_find_learning_rate_ensembling(self, device, dropout_layer, 
mechanism): 4471 # This example mimics what a user might do when trying to find the optimal learning rate. They would 4472 # want to run a bunch of models with the same behavior (including the same dropout!) and have them 4473 # each run with different learning rates. Specifically, this is an example of using same randomness with vmap 4474 points, labels = torch.randn(100, 2, 2, 2, 2, device=device), torch.randint( 4475 0, 2, (100,), device=device 4476 ) 4477 4478 class MLPClassifier(nn.Module): 4479 def __init__(self, hidden_dim=32, n_classes=2): 4480 super().__init__() 4481 self.hidden_dim = hidden_dim 4482 self.n_classes = n_classes 4483 4484 self.dropout = dropout_layer() 4485 self.fc1 = nn.Linear(16, self.hidden_dim) 4486 self.fc2 = nn.Linear(self.hidden_dim, self.n_classes) 4487 4488 def forward(self, x): 4489 x = self.dropout(x) 4490 x = torch.flatten(x, start_dim=1) 4491 x = self.fc1(x) 4492 x = F.relu(x) 4493 x = self.fc2(x) 4494 x = F.log_softmax(x, -1) 4495 return x 4496 4497 loss_fn = nn.NLLLoss() 4498 4499 func_model, weights = _get_weights_and_functional_call( 4500 MLPClassifier().to(device), mechanism 4501 ) 4502 4503 def train_step_fn(weights, batch, targets, lr): 4504 def compute_loss(weights, batch, targets): 4505 output = func_model(weights, batch) 4506 loss = loss_fn(output, targets) 4507 return loss 4508 4509 grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets) 4510 new_weights = self._update_params(weights, grad_weights, lr, mechanism) 4511 if mechanism != "make_functional": 4512 new_weights = list(new_weights.values()) 4513 # NB: return looks weird because torch.vmap must return Tensors 4514 return (loss, *new_weights) 4515 4516 def unpack(train_result): 4517 return train_result[0], train_result[1:] 4518 4519 def init_fn(num_models): 4520 og_model = MLPClassifier().to(device) 4521 models = tuple( 4522 copy.deepcopy(og_model) for _ in range(num_models) 4523 ) # have same initialization 4524 if mechanism == "make_functional": 4525 return combine_state_for_ensemble(models)[1] 4526 else: 4527 return stack_module_state(models)[0] 4528 4529 batched_weights = init_fn(num_models=2) 4530 parallel_train_step_fn = vmap( 4531 train_step_fn, in_dims=(0, None, None, 0), randomness="same" 4532 ) 4533 4534 lrs = torch.tensor([0.2, 0.4], device=device) 4535 result_loss, result_weights = unpack( 4536 parallel_train_step_fn(batched_weights, points, labels, lrs) 4537 ) 4538 4539 self.assertEqual(result_loss[0], result_loss[1]) 4540 self.assertNotEqual( 4541 tuple(weight[0] for weight in result_weights), 4542 tuple(weight[1] for weight in result_weights), 4543 ) 4544 4545 @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 4546 @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") 4547 @parametrize("mechanism", ["make_functional", "functional_call"]) 4548 def test_resnet18_per_sample_grads(self, device, mechanism): 4549 import torchvision.models as models 4550 4551 model = models.__dict__["resnet18"]( 4552 pretrained=False, norm_layer=(lambda c: nn.GroupNorm(min(32, c), c)) 4553 ).to(device) 4554 criterion = nn.CrossEntropyLoss( 4555 reduction="sum" 4556 ) # avoid cross batch reductions for for loop comparison 4557 4558 func_model, weights = _get_weights_and_functional_call(model, mechanism) 4559 4560 def compute_loss(weights, image, target): 4561 image = image.unsqueeze(0) 4562 target = target.unsqueeze(0) 4563 output = func_model(weights, image) 4564 loss = criterion(output, target) 4565 return loss 4566 4567 batch_size = 3 4568 images = 
torch.randn(batch_size, 3, 32, 32, device=device) 4569 targets = torch.randint(0, 10, (batch_size,), device=device) 4570 4571 result_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))( 4572 weights, images, targets 4573 ) 4574 4575 flat_weights, spec = tree_flatten(weights) 4576 expected_grads = [ 4577 torch.autograd.grad( 4578 compute_loss(weights, images[i], targets[i]), flat_weights 4579 ) 4580 for i in range(batch_size) 4581 ] 4582 expected_grads = [torch.stack(shards) for shards in zip(*expected_grads)] 4583 expected_grads = tree_unflatten(expected_grads, spec) 4584 4585 self.assertEqual(result_grads, expected_grads, atol=1e-3, rtol=1.0) 4586 4587 4588def normalize_devices(fx_g): 4589 for node in fx_g.graph.nodes: 4590 args = list(node.args) 4591 for idx, arg in enumerate(args): 4592 if isinstance(arg, torch.device): 4593 args[idx] = "cpu" 4594 node.args = tuple(args) 4595 new_kwargs = {} 4596 for k, v in node.kwargs.items(): 4597 if isinstance(v, torch.device): 4598 v = "cpu" 4599 new_kwargs[k] = v 4600 node.kwargs = new_kwargs 4601 fx_g.recompile() 4602 return fx_g 4603 4604 4605@markDynamoStrictTest 4606class TestFunctionalize(TestCase): 4607 def _check_functionalize_correctness(self, f, inpt, *, skip_vmap=False): 4608 inpt1 = inpt.clone() 4609 inpt2 = inpt.clone() 4610 inpt3 = inpt.clone() 4611 4612 expected_outputs = f(inpt1) 4613 if skip_vmap: 4614 actual_outputs = functionalize(f)(inpt2) 4615 else: 4616 actual_outputs = vmap(functionalize(f))(inpt2.unsqueeze(0))[0].squeeze() 4617 # Right now the flavor of functionalize that also removes view ops 4618 # isn't being used with vmap 4619 # That's because {view}_copy ops don't have batching rules yet 4620 # (although we should probably fix that) 4621 actual_outputs_view_copy = functionalize(f, remove="mutations_and_views")(inpt3) 4622 # Check that outputs are the same 4623 self.assertEqual(actual_outputs, expected_outputs) 4624 self.assertEqual(actual_outputs_view_copy, expected_outputs) 4625 4626 # Inputs might have been mutated by f: check that they were mutated properly 4627 self.assertEqual(inpt1, inpt2) 4628 self.assertEqual(inpt1, inpt3) 4629 4630 def test_simple_view(self, device): 4631 def f(x: torch.Tensor) -> torch.Tensor: 4632 tmp = torch.ones(2, device=device) 4633 y = x.view(4, 2) 4634 y.add_(tmp) 4635 return x 4636 4637 self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device)) 4638 4639 def test_multioutput_view(self, device): 4640 def f(x: torch.Tensor) -> torch.Tensor: 4641 tmp = torch.ones(2, device=device) 4642 y1, y2 = x.split(2) 4643 y1_view = y1.diagonal() 4644 y1_view.add_(tmp) 4645 return x 4646 4647 self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device)) 4648 4649 def test_inplace_view(self, device): 4650 def f(x: torch.Tensor) -> torch.Tensor: 4651 tmp = torch.ones(4, device=device) 4652 y = x + x 4653 y2 = y.transpose(1, 0) 4654 z = y2[0] 4655 z.add_(tmp) 4656 return y 4657 4658 self._check_functionalize_correctness( 4659 f, torch.zeros(4, 2, device=device), skip_vmap=True 4660 ) 4661 4662 # See https://github.com/pytorch/functorch/issues/780 4663 def test_linear(self, device): 4664 def f(x, y, z) -> torch.Tensor: 4665 return torch._C._nn.linear(x, y, z) 4666 4667 x = torch.randn(14, 1, 384, device=device) 4668 y = torch.randn(96, 384, device=device) 4669 z = torch.randn(96, device=device) 4670 4671 out_expected = f(x, y, z) 4672 out_actual = functionalize(f)(x, y, z) 4673 self.assertEqual(out_expected, out_actual) 4674 4675 def 
test_multioutput_inplace_slice_view(self, device): 4676 def f(x: torch.Tensor) -> torch.Tensor: 4677 tmp = torch.ones(2, 2, device=device) 4678 y = x.view(8) 4679 z0 = y.reshape(2, 4) 4680 z1 = z0.transpose(1, 0) 4681 z1.unsqueeze_(0) 4682 z1.squeeze_() 4683 z2, z3 = z1.split(2) 4684 z2.add_(tmp) 4685 return x 4686 4687 # See Note [Fix vmap slice_scatter] 4688 self._check_functionalize_correctness( 4689 f, torch.zeros(4, 2, device=device), skip_vmap=True 4690 ) 4691 4692 # Ensure functionalize works with List[Optional[Tensor]] arguments. 4693 # See the fix / discussion at https://github.com/pytorch/pytorch/pull/76085 4694 def test_functionalize_opt_tensor_list(self, device): 4695 def f(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: 4696 return x[indices] 4697 4698 inpta = torch.ones(4, device=device) 4699 inptb = torch.arange(2, device=device) 4700 out1 = f(inpta, inptb) 4701 out2 = functionalize(f)(inpta, inptb) 4702 self.assertEqual(out1, out2) 4703 out = make_fx(functionalize(f))(inpta, inptb) 4704 self.assertExpectedInline( 4705 (out.code), 4706 """\ 4707 4708 4709 4710def forward(self, x_1, indices_1) -> torch.Tensor: 4711 index = torch.ops.aten.index.Tensor(x_1, [indices_1]); x_1 = indices_1 = None 4712 return index 4713 """, 4714 ) 4715 4716 # Ensure grad(functionalize(f)) works 4717 def test_functionalize_grad(self, device): 4718 def f(x: torch.Tensor) -> torch.Tensor: 4719 tmp = torch.ones(2, device=device) 4720 y = x + x 4721 z = y.view(4, 2) 4722 y.add_(tmp) 4723 return z.sum() 4724 4725 inpt1 = torch.ones(4, 2, device=device) 4726 inpt2 = torch.ones(4, 2, device=device) 4727 out1 = grad(f)(inpt1) 4728 out2 = grad(functionalize(f))(inpt2) 4729 self.assertEqual(out1, out2) 4730 self.assertEqual(inpt1, inpt2) 4731 4732 @unittest.skipIf(IS_FBCODE, "fails in fbcode") 4733 def test_vmap_functionalize_jvp(self, device): 4734 def f(x: torch.Tensor) -> torch.Tensor: 4735 y = x + x 4736 z = y.view(-1) 4737 y.add_(1) 4738 return z 4739 4740 def jvp_wrapper(x, t): 4741 return jvp( 4742 f, 4743 (x,), 4744 (t,), 4745 ) 4746 4747 x = torch.randn(2, 3, device=device) 4748 t = torch.randn(2, 3, device=device) 4749 4750 out1 = vmap(jvp_wrapper)(x, t) 4751 out2 = vmap(functionalize(jvp_wrapper))(x, t) 4752 self.assertEqual(out1, out2) 4753 4754 # TODO: move this test into test_fake_tensor.py 4755 # once functionalize() can be used in core tests. 4756 def test_functionalize_fake_tensors(self, device): 4757 def f(x: torch.Tensor) -> torch.Tensor: 4758 y = x.detach() 4759 return y + y 4760 4761 with FakeTensorMode() as mode: 4762 x = torch.ones(2, device=device, requires_grad=True) 4763 out = functionalize(f)(x) 4764 self.assertEqual(x.size(), (2,)) 4765 4766 def test_functionalize_fx_simple(self, device): 4767 def f(x: torch.Tensor) -> torch.Tensor: 4768 tmp = torch.ones(2, device=device) 4769 y = x.view(4, 2) 4770 y.add_(tmp) 4771 return x 4772 4773 # There's a copy_ in the graph, because the input (x) was mutated. 4774 # To preserve semantics, functionalize() needs to propagate the mutation. 
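        # With remove="mutations_and_views", the traced graph should contain only
        # out-of-place {view}_copy variants (view_copy, add.Tensor, ...) plus a
        # single trailing copy_ that writes the result back into the input x_1.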
4775 fn = make_fx(functionalize(f, remove="mutations_and_views")) 4776 out = fn(torch.zeros(4, 2, device=device)) 4777 out = normalize_devices(out) 4778 self.assertExpectedInline( 4779 (out.code), 4780 """\ 4781 4782 4783 4784def forward(self, x_1) -> torch.Tensor: 4785 ones = torch.ops.aten.ones.default([2], device = 'cpu', pin_memory = False) 4786 view_copy = torch.ops.aten.view_copy.default(x_1, [4, 2]) 4787 add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None 4788 view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None 4789 view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2]); view_copy_2 = None 4790 copy_ = torch.ops.aten.copy_.default(x_1, view_copy_1); x_1 = copy_ = None 4791 return view_copy_1 4792 """, 4793 ) 4794 4795 def test_functionalize_fx_transpose_simple(self, device): 4796 def f(x: torch.Tensor) -> torch.Tensor: 4797 return x.transpose(1, 0) 4798 4799 fn = make_fx(functionalize(f, remove="mutations_and_views")) 4800 out = fn(torch.zeros(4, 2, device=device)) 4801 out = normalize_devices(out) 4802 self.assertExpectedInline( 4803 out.code, 4804 """\ 4805 4806 4807 4808def forward(self, x_1) -> torch.Tensor: 4809 transpose_copy = torch.ops.aten.transpose_copy.int(x_1, 1, 0); x_1 = None 4810 return transpose_copy 4811 """, 4812 ) 4813 4814 def test_functionalize_fx_out_op(self, device): 4815 def f(inpt: torch.Tensor) -> torch.Tensor: 4816 out = torch.empty((), dtype=torch.float32) 4817 torch.add(inpt, inpt, out=out) 4818 out_view = out.view(4) 4819 out_view.add_(1) 4820 return out 4821 4822 fn = make_fx(functionalize(f, remove="mutations_and_views")) 4823 out = fn(torch.arange(4, device=device, dtype=torch.float32)) 4824 out = normalize_devices(out) 4825 self.assertExpectedInline( 4826 out.code, 4827 """\ 4828 4829 4830 4831def forward(self, inpt_1) -> torch.Tensor: 4832 empty = torch.ops.aten.empty.memory_format([], dtype = torch.float32, device = 'cpu', pin_memory = False); empty = None 4833 add = torch.ops.aten.add.Tensor(inpt_1, inpt_1); inpt_1 = None 4834 view_copy = torch.ops.aten.view_copy.default(add, [4]); view_copy = None 4835 view_copy_1 = torch.ops.aten.view_copy.default(add, [4]); add = None 4836 add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1); view_copy_1 = None 4837 view_copy_2 = torch.ops.aten.view_copy.default(add_1, [4]); add_1 = None 4838 view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [4]); view_copy_3 = None 4839 return view_copy_2 4840 """, 4841 ) 4842 4843 def test_functionalize_fx_multi_out_op(self, device): 4844 def f(inpt: torch.Tensor) -> torch.Tensor: 4845 mins = torch.empty(4, dtype=torch.float32) 4846 maxs = torch.empty(2, 2, dtype=torch.float32) 4847 maxs_view = maxs.view(4) 4848 inpt_view = inpt.view(2, 4) 4849 torch.aminmax(inpt_view, dim=0, out=(mins, maxs_view)) 4850 return (maxs, mins) 4851 4852 fn = make_fx(functionalize(f, remove="mutations_and_views")) 4853 out = fn(torch.arange(8, device=device, dtype=torch.float32)) 4854 out = normalize_devices(out) 4855 self.assertExpectedInline( 4856 out.code, 4857 """\ 4858 4859 4860 4861def forward(self, inpt_1) -> torch.Tensor: 4862 empty = torch.ops.aten.empty.memory_format([4], dtype = torch.float32, device = 'cpu', pin_memory = False); empty = None 4863 empty_1 = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = 'cpu', pin_memory = False) 4864 view_copy = torch.ops.aten.view_copy.default(empty_1, [4]); empty_1 = view_copy = None 4865 view_copy_1 = torch.ops.aten.view_copy.default(inpt_1, [2, 4]); inpt_1 = 
None 4866 aminmax = torch.ops.aten.aminmax.default(view_copy_1, dim = 0); view_copy_1 = None 4867 getitem = aminmax[0] 4868 getitem_1 = aminmax[1]; aminmax = None 4869 view_copy_2 = torch.ops.aten.view_copy.default(getitem_1, [2, 2]); getitem_1 = None 4870 view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [4]); view_copy_3 = None 4871 return (view_copy_2, getitem) 4872 """, 4873 ) 4874 4875 def test_functionalize_fx_reapply_views_simple(self, device): 4876 def f(x: torch.Tensor) -> torch.Tensor: 4877 tmp = torch.ones(2, device=device) 4878 y = x.view(4, 2) 4879 y.add_(tmp) 4880 return x 4881 4882 out = make_fx(functionalize(f))(torch.zeros(4, 2, device=device)) 4883 out = normalize_devices(out) 4884 self.assertExpectedInline( 4885 out.code, 4886 """\ 4887 4888 4889 4890def forward(self, x_1) -> torch.Tensor: 4891 ones = torch.ops.aten.ones.default([2], device = 'cpu', pin_memory = False) 4892 view = torch.ops.aten.view.default(x_1, [4, 2]) 4893 add = torch.ops.aten.add.Tensor(view, ones); view = ones = None 4894 view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None 4895 view_2 = torch.ops.aten.view.default(view_1, [4, 2]); view_2 = None 4896 copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = copy_ = None 4897 return view_1 4898 """, 4899 ) 4900 4901 def test_functionalize_nonfunctional_output(self, device): 4902 global_out = torch.ones(2, device=device) 4903 4904 def f() -> torch.Tensor: 4905 return global_out 4906 4907 out = make_fx(functionalize(f))() 4908 out = normalize_devices(out) 4909 self.assertExpectedInline( 4910 out.code, 4911 """\ 4912 4913 4914 4915def forward(self) -> torch.Tensor: 4916 _tensor_constant0 = self._tensor_constant0 4917 return _tensor_constant0 4918 """, 4919 ) 4920 4921 def test_functionalize_optional_tensorlist1(self, device): 4922 def f(a, b) -> torch.Tensor: 4923 # at::index has OptionalTensorList arguments, 4924 # test that here 4925 return a[b] 4926 4927 a = torch.arange(4).reshape(2, 2) 4928 b = torch.ones(2, dtype=torch.long) 4929 out = make_fx(functionalize(f))(a, b) 4930 out = normalize_devices(out) 4931 self.assertExpectedInline( 4932 out.code, 4933 """\ 4934 4935 4936 4937def forward(self, a_1, b_1) -> torch.Tensor: 4938 index = torch.ops.aten.index.Tensor(a_1, [b_1]); a_1 = b_1 = None 4939 return index 4940 """, 4941 ) 4942 4943 @unittest.skipIf(IS_FBCODE, "fails in fbcode") 4944 def test_functionalize_optional_tensorlist2(self, device): 4945 def f(a, b) -> torch.Tensor: 4946 # See https://github.com/pytorch/pytorch/pull/77846 4947 return torch.ops.aten.index(a, b) 4948 4949 a = torch.arange(4).reshape(2, 2) 4950 b = torch.ones(2, dtype=torch.long) 4951 out = make_fx(functionalize(f))(a, b) 4952 self.assertExpectedInline( 4953 out.code, 4954 """\ 4955 4956 4957 4958def forward(self, a_1, b_1) -> torch.Tensor: 4959 unbind = torch.ops.aten.unbind.int(b_1); b_1 = None 4960 getitem = unbind[0] 4961 getitem_1 = unbind[1]; unbind = None 4962 index = torch.ops.aten.index.Tensor(a_1, [getitem, getitem_1]); a_1 = getitem = getitem_1 = None 4963 return index 4964 """, 4965 ) 4966 4967 def test_resize_program_inputs(self, device): 4968 def f(x): 4969 x.resize_(10) 4970 x.fill_(2) 4971 4972 fn = make_fx(functionalize(f)) 4973 out = fn(torch.zeros(0, device=device)) 4974 out = normalize_devices(out) 4975 self.assertExpectedInline( 4976 (out.code), 4977 """\ 4978 4979 4980 4981def forward(self, x_1): 4982 resize = torch.ops.aten.resize.default(x_1, [10]) 4983 fill = torch.ops.aten.fill.Scalar(resize, 2); resize = None 4984 resize_ = 
torch.ops.aten.resize_.default(x_1, [10]); x_1 = None 4985 copy_ = torch.ops.aten.copy_.default(resize_, fill); resize_ = fill = copy_ = None 4986 return None 4987 """, 4988 ) 4989 4990 4991def construct_sum_pyop(): 4992 class MySum(HigherOrderOperator): 4993 def __init__(self): 4994 super().__init__("mysum") 4995 4996 def __call__(self, *args, **kwargs): 4997 return super().__call__(*args, **kwargs) 4998 4999 mysum = MySum() 5000 5001 @mysum.py_impl(torch._C._functorch.TransformType.Vmap) 5002 def mysum_batch_rule(interpreter, x, dim): 5003 if not torch._C._functorch.is_batchedtensor(x): 5004 with interpreter.lower(): 5005 x = x.view_as(x) # unnecessary, just here to test the dispatch 5006 return mysum(x, dim) 5007 5008 bdim = torch._C._functorch.maybe_get_bdim(x) 5009 value = torch._C._functorch.get_unwrapped(x) 5010 5011 with interpreter.lower(): 5012 value = value.movedim(bdim, 0) 5013 result = mysum(value, dim + 1) 5014 5015 return torch._C._functorch._add_batch_dim(result, 0, interpreter.level()) 5016 5017 @mysum.py_impl(torch._C._functorch.TransformType.Grad) 5018 def mysum_grad_rule(interpreter, x, dim): 5019 level = interpreter.level() 5020 5021 class MySum(torch.autograd.function._SingleLevelFunction): 5022 @staticmethod 5023 def forward(ctx, x, dim): 5024 ctx.x_shape = x.shape 5025 ctx.dim = dim 5026 x = torch._C._functorch._unwrap_for_grad(x, level) 5027 with torch.enable_grad(), interpreter.lower(): 5028 x = x.view_as(x) # unnecessary, just here to test the dispatch 5029 y = mysum(x, dim) 5030 5031 y = torch._C._functorch._wrap_for_grad(y, level) 5032 return y 5033 5034 @staticmethod 5035 def backward(ctx, gy): 5036 return gy.unsqueeze(ctx.dim).expand(ctx.x_shape), None 5037 5038 with enable_single_level_autograd_function(): 5039 return MySum.apply(x, dim) 5040 5041 @mysum.py_impl(torch._C.DispatchKey.AutogradCPU) 5042 def mysum_autograd_cpu(x, dim): 5043 return torch.sum(x, dim) 5044 5045 @mysum.py_impl(torch._C.DispatchKey.AutogradCUDA) 5046 def mysum_autograd_cuda(x, dim): 5047 return torch.sum(x, dim) 5048 5049 return mysum 5050 5051 5052sum_pyop = construct_sum_pyop() 5053 5054 5055@markDynamoStrictTest 5056class TestHigherOrderOperatorInteraction(TestCase): 5057 def test_basic_sum(self, device): 5058 x = torch.randn(2, 3, 4, device=device) 5059 result = sum_pyop(x, 1) 5060 self.assertEqual(result, torch.sum(x, 1)) 5061 5062 def test_vmap_sum(self, device): 5063 x = torch.randn(2, 3, 4, device=device) 5064 result = vmap(sum_pyop, (0, None))(x, 0) 5065 self.assertEqual(result, torch.sum(x, 1)) 5066 5067 result = vmap(vmap(sum_pyop, (0, None)), (0, None))(x, 0) 5068 self.assertEqual(result, torch.sum(x, 2)) 5069 5070 def test_grad_sum(self, device): 5071 x = torch.randn(3, device=device) 5072 gx = grad(sum_pyop)(x, 0) 5073 self.assertEqual(gx, torch.ones_like(x)) 5074 5075 def test_grad_grad_sum(self, device): 5076 x = torch.randn(3, requires_grad=True, device=device) 5077 5078 def f(x): 5079 # higher order grad. 
Requires a non-linearity 5080 return sum_pyop(x.sin(), 0) 5081 5082 def grad_f_sum(x): 5083 return grad(f)(x).sum() 5084 5085 ggx = grad(grad_f_sum)(x) 5086 self.assertEqual(ggx, -x.sin()) 5087 5088 def test_vmap_grad_sum(self, device): 5089 x = torch.randn(2, 3, device=device) 5090 gx = vmap(grad(sum_pyop), (0, None))(x, 0) 5091 self.assertEqual(gx, torch.ones_like(x)) 5092 5093 def test_no_grad_outside_grad(self, device): 5094 x = torch.randn(3, device=device, requires_grad=True) 5095 with torch.no_grad(): 5096 y = grad(sum_pyop)(x, 0) 5097 self.assertEqual(y, torch.ones_like(x)) 5098 self.assertFalse(y.requires_grad) 5099 5100 def test_no_grad_inside_grad(self, device): 5101 def f(x): 5102 with torch.no_grad(): 5103 shift = sum_pyop(x**2, 0) 5104 return sum_pyop(x**2, 0) - shift 5105 5106 x = torch.randn(3, device=device) 5107 y = grad(f)(x) 5108 self.assertEqual(y, 2 * x) 5109 y = grad(lambda x: grad(f)(x).sum())(x) 5110 self.assertEqual(y, torch.full_like(x, 2)) 5111 5112 x = torch.randn(3, device=device, requires_grad=True) 5113 y = grad(f)(x) 5114 (z,) = torch.autograd.grad(y.sum(), x) 5115 self.assertEqual(z, torch.full_like(x, 2)) 5116 5117 def test_grad_name_wrapping(self, device): 5118 def my_fn(x): 5119 return x.sum() 5120 5121 grad_fn = grad(my_fn) 5122 self.assertEqual(grad_fn.__name__, "my_fn") 5123 5124 def test_functional_call_multiple_dicts(self): 5125 mod = nn.Linear(1, 1) 5126 x = torch.randn((1, 1)) 5127 params = ({"weight": torch.zeros(1, 1)}, {"bias": torch.ones(1)}) 5128 functional_call(mod, params, x) 5129 5130 5131def traceable(f): 5132 f = allow_in_graph(f) 5133 5134 @wraps(f) 5135 def wrapper(*args, **kwargs): 5136 return f(*args, **kwargs) 5137 5138 return wrapper 5139 5140 5141@markDynamoStrictTest 5142class TestCompileTransforms(TestCase): 5143 @skipIfRocm(msg="test leaks memory on ROCm") 5144 # torch.compile is not supported on Windows CUDA. 5145 # Triton only supports GPU with SM70 or later. 
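    # The test below computes a batched Hessian of a functional_call'd MLP by
    # composing jacfwd(jacrev(...)) under vmap, then checks that compiling the
    # transformed function with torch.compile matches eager execution.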
5146 @expectedFailureIf((IS_WINDOWS and TEST_CUDA) or (TEST_CUDA and not SM70OrLater)) 5147 def test_compile_vmap_hessian(self, device): 5148 # The model and inputs are a smaller version 5149 # of code at benchmark repo: 5150 # https://github.com/pytorch/benchmark/blob/main/userbenchmark/functorch/vmap_hessian_fc.py 5151 D = 2 5152 B = 4 5153 5154 x = torch.randn(B, D, device=device) 5155 5156 model = nn.Sequential(nn.Linear(D, D), nn.ReLU()).to(device) 5157 5158 params_and_buffers = ( 5159 dict(model.named_parameters()), 5160 dict(model.named_buffers()), 5161 ) 5162 5163 def predict(params_and_buffers, x): 5164 out = torch.func.functional_call(model, params_and_buffers, x) 5165 return out, out 5166 5167 fn = vmap( 5168 jacfwd(jacrev(predict, argnums=1, has_aux=True), argnums=1, has_aux=True), 5169 in_dims=(None, 0), 5170 ) 5171 5172 expected = fn(params_and_buffers, x) 5173 5174 opt_fn = torch.compile(traceable(fn)) 5175 actual = opt_fn(params_and_buffers, x) 5176 self.assertEqual(actual, expected) 5177 5178 # torch.compile is not supported on Windows 5179 @torch._dynamo.config.patch(suppress_errors=False) 5180 def test_grad_deprecated_api(self, device): 5181 x = torch.randn((), device=device) 5182 y = torch.randn((), device=device) 5183 5184 def wrapper_fn(x, y): 5185 return functorch.grad(torch.mul)(x, y) 5186 5187 actual = wrapper_fn(x, y) 5188 expected = torch.compile(wrapper_fn, backend="eager", fullgraph=True)(x, y) 5189 fn = torch.compile(wrapper_fn, backend="eager", fullgraph=True) 5190 self.assertEqual(actual, expected) 5191 5192 def wrapper_fn(x, y): 5193 return functorch.grad(torch.mul, argnums=(0, 1))(x, y) 5194 5195 actual = wrapper_fn(x, y) 5196 expected = torch.compile(wrapper_fn, backend="eager", fullgraph=True)(x, y) 5197 self.assertEqual(actual, expected) 5198 5199 5200only_for = ("cpu", "cuda") 5201instantiate_device_type_tests( 5202 TestGradTransform, 5203 globals(), 5204 only_for=only_for, 5205) 5206instantiate_device_type_tests( 5207 TestVmapOfGrad, 5208 globals(), 5209 only_for=only_for, 5210) 5211instantiate_device_type_tests( 5212 TestJac, 5213 globals(), 5214 only_for=only_for, 5215) 5216instantiate_device_type_tests( 5217 TestJvp, 5218 globals(), 5219 only_for=only_for, 5220) 5221instantiate_device_type_tests( 5222 TestLinearize, 5223 globals(), 5224 only_for=only_for, 5225) 5226instantiate_device_type_tests( 5227 TestVmapJvpInplaceView, 5228 globals(), 5229 only_for=only_for, 5230) 5231instantiate_device_type_tests( 5232 TestHessian, 5233 globals(), 5234 only_for=only_for, 5235) 5236instantiate_device_type_tests( 5237 TestComposability, 5238 globals(), 5239 only_for=only_for, 5240) 5241instantiate_device_type_tests( 5242 TestExamplesCorrectness, 5243 globals(), 5244 only_for=only_for, 5245) 5246instantiate_device_type_tests( 5247 TestHigherOrderOperatorInteraction, 5248 globals(), 5249 only_for=only_for, 5250) 5251instantiate_device_type_tests( 5252 TestFunctionalize, 5253 globals(), 5254 only_for=only_for, 5255) 5256instantiate_device_type_tests( 5257 TestAutogradFunction, 5258 globals(), 5259 only_for=only_for, 5260) 5261instantiate_device_type_tests( 5262 TestAutogradFunctionVmapAPI, 5263 globals(), 5264 only_for=only_for, 5265) 5266instantiate_device_type_tests( 5267 TestHelpers, 5268 globals(), 5269 only_for=only_for, 5270) 5271instantiate_parametrized_tests( 5272 TestMakeFunctional, 5273) 5274instantiate_device_type_tests( 5275 TestCompileTransforms, 5276 globals(), 5277 only_for=only_for, 5278) 5279 5280if __name__ == "__main__": 5281 run_tests() 5282