1# Owner(s): ["module: functorch"] 2 3# Copyright (c) Facebook, Inc. and its affiliates. 4# All rights reserved. 5# 6# This source code is licensed under the BSD-style license found in the 7# LICENSE file in the root directory of this source tree. 8 9import functools 10import itertools 11import unittest 12 13from common_utils import ( 14 check_vmap_fallback, 15 decorate, 16 expectedFailureIf, 17 generate_vmap_inputs, 18 get_fallback_and_vmap_exhaustive, 19 is_batch_norm_training, 20 is_valid_inplace_sample_input, 21 loop, 22 loop2, 23 opsToleranceOverride, 24 skip, 25 skipOps, 26 tol1, 27 tol2, 28 xfail, 29) 30from functorch_additional_op_db import additional_op_db 31 32import torch 33import torch.autograd.forward_ad as fwAD 34from functorch import grad, jacfwd, jacrev, vjp, vmap 35from torch import Tensor 36from torch._functorch.eager_transforms import _as_tuple, jvp 37from torch.testing._internal.autograd_function_db import autograd_function_db 38from torch.testing._internal.common_cuda import with_tf32_off 39from torch.testing._internal.common_device_type import ( 40 instantiate_device_type_tests, 41 ops, 42 tol, 43 toleranceOverride, 44) 45from torch.testing._internal.common_methods_invocations import op_db 46from torch.testing._internal.common_utils import ( 47 is_iterable_of_tensors, 48 IS_MACOS, 49 IS_X86, 50 noncontiguous_like, 51 parametrize, 52 run_tests, 53 runOnRocm, 54 skipIfRocm, 55 TEST_WITH_ASAN, 56 TEST_WITH_ROCM, 57 TestCase, 58 unMarkDynamoStrictTest, 59) 60from torch.testing._internal.opinfo.core import SampleInput 61from torch.utils import _pytree as pytree 62from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten 63 64 65aten = torch.ops.aten 66 67 68# Version of autograd.grad with some differences: 69# - pytree inputs is allowed (but leaves of the pytree have to all 70# be tensors) 71# - if an input is not used as part of derivatives, we will return a 72# zero-filled tensor for the result 73def _autograd_grad( 74 outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True 75): 76 inputs, inputs_spec = tree_flatten(inputs) 77 diff_inputs = tuple(inp for inp in inputs if inp.requires_grad) 78 if grad_outputs is None: 79 diff_outputs = tuple(out for out in outputs if out.requires_grad) 80 else: 81 diff_grad_outputs = [ 82 (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad 83 ] 84 if len(diff_grad_outputs) == 0: 85 diff_outputs, grad_outputs = (), () 86 else: 87 diff_outputs, grad_outputs = zip(*diff_grad_outputs) 88 grad_inputs = torch.autograd.grad( 89 diff_outputs, 90 diff_inputs, 91 grad_outputs, 92 retain_graph=retain_graph, 93 create_graph=create_graph, 94 allow_unused=True, 95 ) 96 result = [] 97 grad_inputs_iter = iter(grad_inputs) 98 for inp in inputs: 99 if inp.requires_grad: 100 grad_input = next(grad_inputs_iter) 101 if grad_input is None: 102 result.append(torch.zeros_like(inp)) 103 else: 104 result.append(grad_input) 105 else: 106 result.append(torch.zeros_like(inp)) 107 return tree_unflatten(result, inputs_spec) 108 109 110def diff_arg(arg, requires_grad=True): 111 def is_differentiable_arg(arg): 112 if requires_grad: 113 return arg.requires_grad 114 else: 115 return arg.is_floating_point() or arg.is_complex() 116 117 if is_iterable_of_tensors(arg): 118 if all(is_differentiable_arg(a) for a in arg): 119 return True 120 if all(not is_differentiable_arg(a) for a in arg): 121 return False 122 raise RuntimeError("NYI: The test runner can't handle this") 123 return isinstance(arg, Tensor) and is_differentiable_arg(arg) 124 125 126# Given f, returns an f' such that: 127# - f' takes only positional arguments 128# - All arguments to f' are floating-point Tensors 129# - All outputs of f' are floating-point Tensors 130def normalize_op_input_output2( 131 f, args, kwargs, output_process_fn_grad=None, requires_grad=True 132): 133 flat_args, args_spec = tree_flatten(args) 134 diff_argnums = tuple( 135 i 136 for i, arg in enumerate(flat_args) 137 if diff_arg(arg, requires_grad=requires_grad) 138 ) 139 assert len(diff_argnums) > 0 140 primals = tuple(flat_args[i] for i in diff_argnums) 141 142 @functools.wraps(f) 143 def wrapped(*primals): 144 _args = list(flat_args) 145 for num, arg in zip(diff_argnums, primals): 146 _args[num] = arg 147 _args = tree_unflatten(_args, args_spec) 148 result = f(*_args, **kwargs) 149 if output_process_fn_grad is not None: 150 result = output_process_fn_grad(result) 151 if isinstance(result, tuple): 152 result = tuple(r for r in result if torch.is_floating_point(r)) 153 assert len(result) > 0 154 return result 155 156 return wrapped, primals 157 158 159# TODO: consolidate with normalize_op_input_output2 160def normalize_op_input_output3( 161 f, args, kwargs, sample_args, output_process_fn_grad=None 162): 163 flat_args, args_spec = tree_flatten(args) 164 flat_sample_args = pytree.tree_leaves(sample_args) 165 diff_argnums = tuple( 166 i 167 for i, (arg, sample) in enumerate(zip(flat_args, flat_sample_args)) 168 if diff_arg(sample, requires_grad=True) 169 ) 170 assert len(diff_argnums) > 0 171 primals = tuple(flat_args[i] for i in diff_argnums) 172 173 @functools.wraps(f) 174 def wrapped(*primals): 175 _args = list(flat_args) 176 for num, arg in zip(diff_argnums, primals): 177 _args[num] = arg 178 _args = tree_unflatten(_args, args_spec) 179 result = f(*_args, **kwargs) 180 if output_process_fn_grad is not None: 181 result = output_process_fn_grad(result) 182 if isinstance(result, tuple): 183 result = tuple(r for r in result if torch.is_floating_point(r)) 184 assert len(result) > 0 185 return result 186 187 return wrapped, primals 188 189 190def normalize_op_input_output(f, sample, requires_grad=True): 191 args = tuple([sample.input] + list(sample.args)) 192 return normalize_op_input_output2( 193 f, 194 args, 195 sample.kwargs, 196 sample.output_process_fn_grad, 197 requires_grad=requires_grad, 198 ) 199 200 201def ref_vjp(f, *primals): 202 result = f(*primals) 203 204 def wrapped(cotangents): 205 return _autograd_grad(_as_tuple(result), primals, _as_tuple(cotangents)) 206 207 return result, wrapped 208 209 210def simulate_jvp(f, primals, tangents): 211 primals_out, tangents_out = torch.autograd.functional.jvp(f, primals, tangents) 212 return primals_out, tangents_out 213 214 215def ref_jvp(f, primals, tangents): 216 with fwAD.dual_level(): 217 duals = tuple(fwAD.make_dual(p, t) for p, t in zip(primals, tangents)) 218 result_duals = f(*duals) 219 result_duals, spec = tree_flatten(result_duals) 220 primals_out, tangents_out = zip(*(fwAD.unpack_dual(d) for d in result_duals)) 221 return tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec) 222 223 224def get_sample_cotangents(f, sample): 225 fn, primals = normalize_op_input_output(f, sample) 226 output = fn(*primals) 227 return tree_map(torch.randn_like, output) 228 229 230# returns a new function g(*args, *cotangents) 231# that computes vjps and (*args, cotangents) 232def get_vjp_fn_and_args_with_cotangents(f, sample, cotangents): 233 args = tuple([sample.input] + list(sample.args)) 234 kwargs = sample.kwargs 235 flat_args, args_spec = tree_flatten(args) 236 flat_cotangents, cotangents_spec = tree_flatten(cotangents) 237 238 @functools.wraps(f) 239 def wrapped(*args): 240 assert len(args) == len(flat_args) + len(flat_cotangents) 241 actual_args = args[: len(flat_args)] 242 cotangents = args[len(flat_args) :] 243 actual_args = tree_unflatten(actual_args, args_spec) 244 cotangents = tree_unflatten(cotangents, cotangents_spec) 245 246 fn, primals = normalize_op_input_output3( 247 f, actual_args, kwargs, flat_args, sample.output_process_fn_grad 248 ) 249 _, vjp_fn = vjp(fn, *primals) 250 return vjp_fn(cotangents) 251 252 return wrapped, tuple(flat_args + flat_cotangents) 253 254 255# Returns a new function g(*args, *cotangents) that computes vjps and 256# sample (*args, *cotangents) 257def get_vjpfull_variant(f, sample): 258 fn, primals = normalize_op_input_output(f, sample) 259 return _get_vjpfull_variant(fn, primals) 260 261 262def get_vjpfull_variant2(f, args, kwargs): 263 fn, primals = normalize_op_input_output2(f, args, kwargs) 264 return _get_vjpfull_variant(fn, primals) 265 266 267def _get_vjpfull_variant(fn, primals): 268 result = fn(*primals) 269 cotangents = _as_tuple( 270 tree_map(lambda x: torch.randn_like(x, requires_grad=True), result) 271 ) 272 num_primals = len(primals) 273 args = (*primals, *cotangents) 274 275 @functools.wraps(fn) 276 def wrapped(*args): 277 primals = args[:num_primals] 278 cotangents = args[num_primals:] 279 result, vjp_fn = vjp(fn, *primals) 280 if isinstance(result, torch.Tensor): 281 assert len(cotangents) == 1 282 cotangents = cotangents[0] 283 return vjp_fn(cotangents) 284 285 return wrapped, args 286 287 288def get_jvp_variant(f, sample): 289 # We want this higher-order variant of jvp, so that it can 290 # be used to wrap vmap 291 fn, primals = normalize_op_input_output(f, sample, requires_grad=False) 292 tangents = _as_tuple(tree_map(lambda x: torch.randn_like(x), primals)) 293 294 @functools.wraps(f) 295 def wrapped(*args): 296 tangents = args 297 primals_out, tangents_out = jvp(fn, primals, tangents) 298 299 if isinstance(primals_out, torch.Tensor): 300 return (primals_out, tangents_out) 301 else: 302 flat_primals_out = pytree.tree_leaves(primals_out) 303 flat_tangents_out = pytree.tree_leaves(tangents_out) 304 return tuple(flat_primals_out + flat_tangents_out) 305 306 return wrapped, tangents 307 308 309def get_jvp_variant_primals_tangents2( 310 f, args, kwargs, output_process_fn_grad=None, requires_grad=False 311): 312 fn, primals = normalize_op_input_output2( 313 f, args, kwargs, output_process_fn_grad, requires_grad 314 ) 315 tangents = _as_tuple(tree_map(lambda x: torch.randn_like(x), primals)) 316 return _get_jvp_variant(fn, primals, tangents) 317 318 319def get_jvp_variant_primals_tangents(f, sample): 320 # We want this higher-order variant of jvp, so that it can 321 # be used to wrap vmap 322 fn, primals = normalize_op_input_output(f, sample, requires_grad=False) 323 tangents = _as_tuple(tree_map(lambda x: torch.randn_like(x), primals)) 324 return _get_jvp_variant(fn, primals, tangents) 325 326 327def _get_jvp_variant(fn, primals, tangents): 328 @functools.wraps(fn) 329 def wrapped(*args): 330 primals_in = args[: len(primals)] 331 tangents_in = args[len(primals) :] 332 primals_out, tangents_out = jvp(fn, primals_in, tangents_in) 333 334 if isinstance(primals_out, torch.Tensor): 335 return (primals_out, tangents_out) 336 else: 337 flat_primals_out = pytree.tree_leaves(primals_out) 338 flat_tangents_out = pytree.tree_leaves(tangents_out) 339 return tuple(flat_primals_out + flat_tangents_out) 340 341 return wrapped, primals + tangents 342 343 344def is_inplace(op, variant): 345 if hasattr(variant, "__wrapped__"): 346 return variant.__wrapped__ is op.get_inplace() 347 return variant is op.get_inplace() 348 349 350vjp_fail = { 351 xfail("tensor_split"), # data_ptr composite compliance 352 # Very minor accuracy issue on ROCm 353 decorate("nn.functional.scaled_dot_product_attention", decorator=skipIfRocm), 354} 355 356aliasing_ops = { 357 "T", 358 "broadcast_to", 359 "conj", 360 "contiguous", 361 "diagonal", # linalg.diagonal is an alias 362 "expand", 363 "flatten", 364 "imag", 365 "mH", # adjoint is an alias 366 "mT", 367 "movedim", # moveaxis is an alias 368 "narrow", 369 "permute", 370 "positive", 371 # 'ravel', is composite implicit autograd and may call clone 372 "real", 373 "reshape", 374 "resolve_conj", 375 "resolve_neg", 376 "select", 377 "squeeze", 378 "transpose", # swapdims and swapaxes are aliases 379 "unflatten", 380 "unfold", 381 "unsqueeze", 382 "view", 383 "view_as", 384 "view_as_complex", 385 "view_as_real", 386} 387 388aliasing_ops_list_return = { 389 "chunks", 390 "dsplit", 391 "hsplit", 392 "split", 393 "unbind", 394 "vsplit", 395 # 'tensor_split' not composite compliant, see vjp_fail 396} 397 398skip_noncontig = { 399 "_batch_norm_with_update", 400 "as_strided_copy", 401} 402 403 404@unittest.skipIf(TEST_WITH_ASAN, "tests time out with asan, are probably redundant") 405@unMarkDynamoStrictTest 406class TestOperators(TestCase): 407 @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 408 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) 409 @skipOps( 410 "TestOperators", 411 "test_grad", 412 vjp_fail.union( 413 { 414 xfail( 415 "chalf", "", device_type="cpu" 416 ), # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' 417 xfail( 418 "sparse.sampled_addmm", "" 419 ), # RuntimeError: Sparse CSR tensors do not have strides 420 xfail( 421 "sparse.mm", "reduce" 422 ), # RuntimeError: Sparse CSR tensors do not have strides 423 # Non-contiguous Bugs 424 # 425 # AssertionError: Tensor-likes are not close! 426 xfail("_softmax_backward_data", device_type="cpu"), 427 xfail("as_strided"), 428 xfail("as_strided", "partial_views"), 429 # RuntimeError: !self.requires_grad() || self.is_contiguous() 430 xfail("as_strided_scatter"), 431 # RuntimeError: Tensor must have a last dimension with stride 1 432 xfail("view_as_complex"), 433 # query: last dimension must be contiguous 434 # Fused attention kernels require last dim to be contiguous 435 decorate( 436 "nn.functional.scaled_dot_product_attention", 437 decorator=expectedFailureIf(not TEST_WITH_ROCM), 438 ), # Works on ROCm 439 xfail("torch.ops.aten._flash_attention_forward"), 440 xfail("torch.ops.aten._efficient_attention_forward"), 441 # RuntimeError: Expected contiguous tensor, but got 442 # non-contiguous tensor for argument #2 'grad_output' 443 decorate( 444 "_batch_norm_with_update", 445 decorator=expectedFailureIf(TEST_WITH_ROCM), 446 device_type="cuda", 447 ), 448 } 449 ), 450 ) 451 @opsToleranceOverride( 452 "TestOperators", 453 "test_grad", 454 ( 455 tol1( 456 "nn.functional.binary_cross_entropy_with_logits", 457 {torch.float32: tol(atol=1e-04, rtol=1e-04)}, 458 ), 459 tol1("masked.cumprod", {torch.float32: tol(atol=1e-05, rtol=1e-05)}), 460 tol1("svd_lowrank", {torch.float32: tol(atol=3e-04, rtol=3e-04)}), 461 tol1( 462 "linalg.multi_dot", 463 {torch.float32: tol(atol=1e-05, rtol=8e-04)}, 464 device_type="cuda", 465 ), 466 tol1( 467 "linalg.tensorsolve", 468 {torch.float32: tol(atol=3e-04, rtol=3e-04)}, 469 device_type="cuda", 470 ), 471 tol1( 472 "nn.functional.multi_head_attention_forward", 473 {torch.float32: tol(atol=8e-04, rtol=1e-03)}, 474 ), 475 tol1( 476 "__rmatmul__", 477 {torch.float32: tol(atol=3e-04, rtol=3e-04)}, 478 device_type="cuda", 479 ), 480 tol1( 481 "matmul", 482 {torch.float32: tol(atol=3e-04, rtol=3e-04)}, 483 device_type="cuda", 484 ), 485 tol1( 486 "pca_lowrank", 487 {torch.float32: tol(atol=3e-05, rtol=4e-06)}, 488 device_type="cpu", 489 ), 490 ), 491 ) 492 def test_grad(self, device, dtype, op): 493 if op.name in vjp_fail: 494 self.skipTest("Skipped; Expected failures") 495 return 496 497 if not op.supports_autograd: 498 self.skipTest("Skipped! Autograd not supported.") 499 return 500 501 samples = op.sample_inputs(device, dtype, requires_grad=True) 502 503 if is_inplace(op, op.get_op()): 504 self.skipTest("Skipped for redundancy. test_vjp handles in-place testing.") 505 return 506 507 for sample in samples: 508 args = [sample.input] + list(sample.args) 509 kwargs = sample.kwargs 510 511 if op.name not in skip_noncontig: 512 noncontig_sample = sample.noncontiguous() 513 noncontig_args = [noncontig_sample.input] + list(noncontig_sample.args) 514 noncontig_kwargs = noncontig_sample.kwargs 515 516 diff_argnums = tuple(i for i, arg in enumerate(args) if diff_arg(arg)) 517 assert len(diff_argnums) > 0 518 diff_args = tuple(args[i] for i in diff_argnums) 519 520 def wrapped_fn(*args, **kwargs): 521 result = op(*args, **kwargs) 522 if sample.output_process_fn_grad is not None: 523 result = sample.output_process_fn_grad(result) 524 525 def abs_if_complex(t): 526 if t.dtype.is_complex: 527 return t.abs() 528 return t 529 530 # Reduce into single value for grad 531 if isinstance(result, torch.Tensor): 532 return abs_if_complex(result.sum()) 533 result = sum(abs_if_complex(res.sum()) for res in result) 534 return result 535 536 result = grad(wrapped_fn, diff_argnums)(*args, **kwargs) 537 expected = _autograd_grad(_as_tuple(wrapped_fn(*args, **kwargs)), diff_args) 538 self.assertEqual(result, expected) 539 540 if op.name not in skip_noncontig: 541 result_noncontig = grad(wrapped_fn, diff_argnums)( 542 *noncontig_args, **noncontig_kwargs 543 ) 544 self.assertEqual(result_noncontig, expected) 545 546 @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 547 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) 548 @skipOps( 549 "TestOperators", 550 "test_jvp", 551 set( 552 { 553 # Composite ops that do bad things. Need to be fixed in PyTorch core. 554 # RuntimeError: Cannot access data pointer of Tensor that doesn't have storage 555 xfail("tensor_split"), 556 # BUG: silent incorrectness: runs and produces numerical differences 557 skip("nn.functional.max_unpool1d"), # fails everywhere except on mac 558 skip( 559 "nn.functional.max_unpool2d" 560 ), # fails everywhere except on windows 561 skip("nn.functional.max_unpool3d"), # fails everywhere except on mac 562 xfail( 563 "native_batch_norm" 564 ), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents 565 xfail( 566 "_native_batch_norm_legit" 567 ), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents 568 xfail( 569 "_batch_norm_with_update" 570 ), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents 571 xfail("nn.functional.scaled_dot_product_attention"), 572 xfail("torch.ops.aten._flash_attention_forward"), 573 xfail("torch.ops.aten._efficient_attention_forward"), 574 xfail( 575 "nn.functional.rrelu" 576 ), # in-place test errors out with no formula implemented 577 xfail( 578 "NumpyExpMarkDirtyAutogradFunction" 579 ), # TODO: https://github.com/pytorch/pytorch/issues/91280 580 # --- Non-Contiguous Failures! --- 581 # This is expected to fail as the operator 582 # expects last dim to have stride=1 583 xfail("view_as_complex"), 584 # BUG 585 # AssertionError: Tensor-likes are not close! 586 xfail("as_strided"), 587 xfail("as_strided", "partial_views"), 588 xfail("as_strided_scatter"), 589 decorate( 590 "linalg.det", 591 "singular", 592 decorator=expectedFailureIf(IS_MACOS and IS_X86), 593 ), 594 } 595 ), 596 ) 597 @opsToleranceOverride( 598 "TestOperators", 599 "test_jvp", 600 ( 601 tol1( 602 "nn.functional.conv_transpose3d", 603 {torch.float32: tol(atol=1e-04, rtol=1.3e-06)}, 604 device_type="cuda", 605 ), 606 tol1( 607 "linalg.tensorsolve", 608 {torch.float32: tol(atol=1e-04, rtol=1.3e-05)}, 609 device_type="cuda", 610 ), 611 tol1( 612 "masked.prod", 613 {torch.float32: tol(atol=1e-05, rtol=1.3e-05)}, 614 device_type="cuda", 615 ), 616 tol1( 617 "nn.functional.binary_cross_entropy_with_logits", 618 {torch.float32: tol(atol=4e-04, rtol=4e-04)}, 619 ), 620 tol1( 621 "nn.functional.batch_norm", {torch.float32: tol(atol=4e-05, rtol=5e-05)} 622 ), 623 tol1("nn.functional.conv2d", {torch.float32: tol(atol=4e-05, rtol=5e-05)}), 624 tol1("svd_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}), 625 tol1("pca_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}), 626 tol1( 627 "nn.functional.multi_head_attention_forward", 628 {torch.float32: tol(atol=6e-05, rtol=2e-05)}, 629 ), 630 tol2( 631 "linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-5, rtol=2e-5)} 632 ), 633 ), 634 ) 635 def test_jvp(self, device, dtype, op): 636 # TODO: get rid of vjp_decomp when we add decomposition support to 637 # PyTorch's forward-mode ad. Currently the decomposition support only 638 # works for functorch.jvp 639 VJP_DECOMP = { 640 "nn.functional.logsigmoid", 641 } 642 if op.name in VJP_DECOMP: 643 fixme_ref_jvp_local = simulate_jvp 644 else: 645 fixme_ref_jvp_local = ref_jvp 646 647 if not op.supports_forward_ad and op.name not in VJP_DECOMP: 648 self.skipTest("Skipped! Forward AD not supported.") 649 return 650 651 samples = op.sample_inputs(device, dtype, requires_grad=True) 652 653 outplace_variant = op if not is_inplace(op, op.get_op()) else None 654 inplace_variant = op.inplace_variant if op.supports_inplace_autograd else None 655 656 for sample in samples: 657 if outplace_variant: 658 self.jvp_opinfo_test( 659 outplace_variant, 660 sample, 661 sample.output_process_fn_grad, 662 clone_inputs=False, 663 fixme_ref_jvp_local=fixme_ref_jvp_local, 664 test_noncontig=op.name not in skip_noncontig, 665 ) 666 if is_valid_inplace_sample_input(sample, op, inplace_variant): 667 self.jvp_opinfo_test( 668 inplace_variant, 669 sample, 670 sample.output_process_fn_grad, 671 clone_inputs=True, 672 fixme_ref_jvp_local=fixme_ref_jvp_local, 673 test_noncontig=op.name not in skip_noncontig, 674 ) 675 676 def jvp_opinfo_test( 677 self, 678 fn, 679 sample, 680 output_process_fn, 681 clone_inputs, 682 fixme_ref_jvp_local, 683 test_noncontig, 684 ): 685 # NB: we used requires_grad=True to determine where the primals are, 686 # but don't need that information otherwise 687 args = (sample.input,) + sample.args 688 kwargs = sample.kwargs 689 contig_fn, primals = normalize_op_input_output2( 690 fn, args, kwargs, output_process_fn, requires_grad=True 691 ) 692 orig_primals = tree_map(lambda x: x.detach(), primals) 693 orig_tangents = tree_map(lambda x: torch.randn_like(x), primals) 694 695 def maybe_clone_inputs(): 696 if clone_inputs: 697 primals = tree_map(torch.clone, orig_primals) 698 tangents = tree_map(torch.clone, orig_tangents) 699 return primals, tangents 700 return orig_primals, orig_tangents 701 702 primals, tangents = maybe_clone_inputs() 703 expected_primal_outs, expected_tangent_outs = fixme_ref_jvp_local( 704 contig_fn, primals, tangents 705 ) 706 707 primals, tangents = maybe_clone_inputs() 708 primal_outs, tangent_outs = jvp(contig_fn, primals, tangents) 709 710 self.assertEqual(primal_outs, expected_primal_outs) 711 self.assertEqual(tangent_outs, expected_tangent_outs) 712 713 if test_noncontig: 714 noncontig_sample = sample.noncontiguous() 715 noncontig_args = (noncontig_sample.input,) + noncontig_sample.args 716 noncontig_kwargs = sample.kwargs 717 noncontig_fn, primals = normalize_op_input_output2( 718 fn, 719 noncontig_args, 720 noncontig_kwargs, 721 output_process_fn, 722 requires_grad=True, 723 ) 724 noncontig_primals = tree_map(lambda x: x.detach(), primals) 725 noncontig_tangents = tree_map( 726 lambda x: noncontiguous_like(x), orig_tangents 727 ) 728 noncontig_primal_outs, noncontig_tangent_outs = jvp( 729 noncontig_fn, noncontig_primals, noncontig_tangents 730 ) 731 732 self.assertEqual(noncontig_primal_outs, expected_primal_outs) 733 self.assertEqual(noncontig_tangent_outs, expected_tangent_outs) 734 735 @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 736 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) 737 @skipOps( 738 "TestOperators", 739 "test_vjp", 740 vjp_fail.union( 741 { 742 xfail("sparse.sampled_addmm", ""), 743 xfail("sparse.mm", "reduce"), 744 # ---- Non-Contiguous Failures ---- 745 # This is expected to fail as the operator 746 # expects last dim to have stride=1 747 xfail("view_as_complex"), 748 # RuntimeError: query: last dimension must be contiguous 749 # The fused attention kernels require the last dim to be contiguous 750 decorate( 751 "nn.functional.scaled_dot_product_attention", 752 decorator=expectedFailureIf(not TEST_WITH_ROCM), 753 ), # Works on ROCm 754 xfail("torch.ops.aten._flash_attention_forward"), 755 xfail("torch.ops.aten._efficient_attention_forward"), 756 # BUG 757 # AssertionError: Tensor-likes are not close! 758 xfail("as_strided"), 759 xfail("as_strided_scatter"), 760 xfail("_softmax_backward_data", device_type="cpu"), 761 xfail("as_strided", "partial_views"), 762 } 763 ), 764 ) 765 @opsToleranceOverride( 766 "TestOperators", 767 "test_vjp", 768 ( 769 tol1( 770 "nn.functional.conv_transpose3d", 771 {torch.float32: tol(atol=5e-05, rtol=9e-05)}, 772 device_type="cuda", 773 ), 774 tol1( 775 "nn.functional.binary_cross_entropy_with_logits", 776 {torch.float32: tol(atol=1e-04, rtol=1e-04)}, 777 ), 778 tol1( 779 "nn.functional.multi_head_attention_forward", 780 {torch.float32: tol(atol=2e-03, rtol=2e-04)}, 781 ), 782 tol1("__rmatmul__", {torch.float32: tol(atol=1e-05, rtol=1e-05)}), 783 tol1("matmul", {torch.float32: tol(atol=1e-05, rtol=1e-05)}), 784 tol2( 785 "linalg.pinv", "hermitian", {torch.float32: tol(atol=1e-05, rtol=1e-05)} 786 ), 787 tol1("linalg.tensorsolve", {torch.float32: tol(atol=9e-03, rtol=2e-04)}), 788 tol1("linalg.multi_dot", {torch.float32: tol(atol=1e-04, rtol=1e-04)}), 789 tol1("svd_lowrank", {torch.float32: tol(atol=1e-04, rtol=1e-04)}), 790 tol1("pca_lowrank", {torch.float32: tol(atol=1e-04, rtol=1e-04)}), 791 ), 792 ) 793 def test_vjp(self, device, dtype, op): 794 if not op.supports_autograd: 795 self.skipTest("Skipped! Autograd not supported.") 796 return 797 798 samples = op.sample_inputs(device, dtype, requires_grad=True) 799 800 def _test(_op, inplace=False): 801 for sample in samples: 802 if inplace and not is_valid_inplace_sample_input( 803 sample, op, op.inplace_variant 804 ): 805 continue 806 fn, primals = normalize_op_input_output(_op, sample) 807 result = fn(*primals) 808 cotangents = tree_map(lambda x: torch.randn_like(x), result) 809 810 out, vjp_fn = vjp(fn, *primals) 811 self.assertEqual(out, result) 812 result_vjps = vjp_fn(cotangents) 813 814 _, vjp_fn = ref_vjp(fn, *primals) 815 expected_vjps = vjp_fn(cotangents) 816 817 self.assertEqual(result_vjps, expected_vjps) 818 819 if op.name not in skip_noncontig: 820 noncontig_fn, noncontig_primals = normalize_op_input_output( 821 _op, sample.noncontiguous() 822 ) 823 noncontig_cotangents = tree_map( 824 lambda x: noncontiguous_like(x), cotangents 825 ) 826 out_noncontig, vjp_fn = vjp(noncontig_fn, *noncontig_primals) 827 self.assertEqual(out_noncontig, result) 828 noncontig_result_vjps = vjp_fn(noncontig_cotangents) 829 self.assertEqual(noncontig_result_vjps, expected_vjps) 830 831 _test(op) 832 for a_op in op.aliases: 833 _test(a_op) 834 if op.inplace_variant: 835 836 def f(inp, *args, **kwargs): 837 return op.inplace_variant(inp.clone(), *args, **kwargs) 838 839 _test(f, inplace=True) 840 841 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) 842 @skipOps( 843 "TestOperators", 844 "test_vjpvjp", 845 vjp_fail.union( 846 { 847 skip("nn.functional.max_unpool1d"), # silent incorrectness; Flaky 848 skip("nn.functional.max_unpool2d"), # silent incorrectness; Flaky 849 xfail("nn.functional.ctc_loss"), # Not Implemented 850 xfail( 851 "native_layer_norm", "" 852 ), # Expected a proper Tensor but got None for argument #1 'other' 853 xfail("sparse.sampled_addmm", ""), # sparse tensors have no strides 854 xfail("sparse.mm", "reduce"), # sparse tensors have no strides 855 skip("nn.functional.scaled_dot_product_attention"), 856 xfail("torch.ops.aten._flash_attention_forward"), 857 xfail("torch.ops.aten._efficient_attention_forward"), 858 # AssertionError: Tensor-likes are not close! 859 # Mismatched elements: 1 / 15 (6.7%) 860 # Greatest absolute difference: 24.0 at index (2, 4) (up to 1e-05 allowed) 861 # Greatest relative difference: 1.7933241714393998e-06 at index (2, 4) (up to 1.3e-06 allowed) 862 # The failure occurred for item [0] 863 xfail("masked.prod"), 864 } 865 ), 866 ) 867 @opsToleranceOverride( 868 "TestOperators", 869 "test_vjpvjp", 870 ( 871 tol1( 872 "nn.functional.conv_transpose3d", 873 {torch.float32: tol(atol=5e-05, rtol=9e-05)}, 874 device_type="cuda", 875 ), 876 tol1("prod", {torch.float32: tol(atol=2e-05, rtol=1e-04)}), 877 tol1("masked.cumprod", {torch.float32: tol(atol=5e-04, rtol=5e-04)}), 878 tol1("cumprod", {torch.float32: tol(atol=5e-04, rtol=5e-04)}), 879 tol1("linalg.vander", {torch.float32: tol(atol=5e-04, rtol=5e-04)}), 880 tol2( 881 "linalg.det", "singular", {torch.float32: tol(atol=2e-05, rtol=2e-05)} 882 ), 883 ), 884 ) 885 def test_vjpvjp(self, device, dtype, op): 886 if not op.supports_autograd: 887 self.skipTest("Skipped! Autograd not supported.") 888 return 889 if not op.supports_gradgrad: 890 self.skipTest("Skipped! Operation does not support gradgrad") 891 return 892 893 samples = op.sample_inputs(device, dtype, requires_grad=True) 894 895 def test(_op, inplace=False): 896 for sample in samples: 897 if inplace and not is_valid_inplace_sample_input( 898 sample, op, op.inplace_variant 899 ): 900 continue 901 fn, args = get_vjpfull_variant(_op, sample) 902 result = fn(*args) 903 cotangents = tree_map(lambda x: torch.randn_like(x), result) 904 905 # Compute vjp of vjp 906 _, vjp_fn = vjp(fn, *args) 907 result_vjps = vjp_fn(cotangents) 908 909 # Compute ref_vjp of vjp. We could have done ref_vjp of ref_vjp, 910 # but since we're confident that vjp works by itself, this is 911 # an equivalent way to test that. 912 _, vjp_fn = ref_vjp(fn, *args) 913 expected_vjps = vjp_fn(cotangents) 914 915 self.assertEqual(result_vjps, expected_vjps) 916 917 test(op) 918 if op.inplace_variant: 919 920 def fn(inp, *args, **kwargs): 921 return op.inplace_variant(inp.clone(), *args, **kwargs) 922 923 test(fn, inplace=True) 924 925 @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 926 @skipOps( 927 "TestOperators", 928 "test_vmapvjpvjp", 929 vjp_fail.union( 930 { 931 skip("atleast_1d"), # Takes too long 932 skip("atleast_2d"), # Takes too long 933 skip("atleast_3d"), # Takes too long 934 skip("ormqr"), # Takes too long 935 xfail("as_strided"), # incorrect output 936 xfail("as_strided", "partial_views"), # incorrect output 937 xfail("as_strided_scatter"), # incorrect output 938 skip("bernoulli"), # calls random op 939 xfail("bfloat16"), # rank 4 tensor for channels_last 940 xfail("cdouble"), # rank 4 tensor for channels_last 941 xfail("cfloat"), # rank 4 tensor for channels_last 942 xfail("chalf"), # rank 4 tensor for channels_last 943 xfail("double"), # rank 4 tensor for channels_last 944 xfail("float"), # rank 4 tensor for channels_last 945 xfail("half"), # rank 4 tensor for channels_last 946 xfail( 947 "NumpyCubeNotComposableAutogradFunction" 948 ), # Not composable autograd.Function 949 # It looks like you're either (1) calling .item() on a Tensor or 950 # (2) attempting to use a Tensor in some data-dependent control flow or 951 # (3) encountering this error in PyTorch internals. 952 xfail("index_reduce", "prod"), 953 decorate( 954 "linalg.householder_product", decorator=runOnRocm 955 ), # works on ROCm 956 xfail( 957 # nans 958 "masked.softmax", 959 device_type="cpu", 960 ), 961 xfail( 962 "nanquantile", device_type="cpu" 963 ), # vmap not implemented for at::equal. 964 xfail("native_layer_norm"), # vmap: inplace into a regular tensor 965 # got a batched tensor as input while the running_mean or running_var, 966 # which will be updated in place, were not batched. 967 xfail("nn.functional.batch_norm"), 968 xfail( 969 "nn.functional.binary_cross_entropy" 970 ), # vmap: inplace into a regular tensor 971 xfail( 972 "nn.functional.ctc_loss" 973 ), # derivate not implemented for _ctc_loss_backward 974 # flaky on ROCM needs investigation 975 decorate("nn.functional.conv_transpose2d", decorator=skipIfRocm), 976 skip("nn.functional.dropout"), # calls random op 977 skip("nn.functional.dropout2d"), # calls random op 978 skip("nn.functional.dropout3d"), # calls random op 979 skip("nn.functional.alpha_dropout"), # calls random op 980 skip( 981 "nn.functional.feature_alpha_dropout", "with_train" 982 ), # calls random op 983 skip("nn.functional.fractional_max_pool2d"), # calls random op 984 skip("nn.functional.fractional_max_pool3d"), # calls random op 985 xfail("nn.functional.scaled_dot_product_attention"), # randomness 986 xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints 987 xfail("nn.functional.multi_head_attention_forward"), # randomness 988 # It looks like you're either (1) calling .item() on a Tensor or 989 # (2) attempting to use a Tensor in some data-dependent control flow or 990 # (3) encountering this error in PyTorch internals. 991 xfail("nn.functional.gaussian_nll_loss"), 992 # got a batched tensor as input while the running_mean or running_var, 993 # which will be updated in place, were not batched. 994 xfail("nn.functional.instance_norm"), 995 xfail( 996 "nn.functional.layer_norm" 997 ), # vmap: inplace into a regular tensor 998 # RuntimeError: NYI: querying is_contiguous inside of vmap 999 # for memory_format other than torch.contiguous_formats 1000 xfail("nn.functional.max_pool2d"), 1001 # RuntimeError: NYI: Tensor.clone(memory_format) inside vmap is only 1002 # supported with memory_format torch.preserve_format or 1003 # torch.contiguous_format (got ChannelsLast) 1004 xfail("nn.functional.max_unpool2d"), 1005 # RuntimeError: NYI: Tensor.clone(memory_format) inside vmap is only 1006 # supported with memory_format torch.preserve_format 1007 # or torch.contiguous_format (got ChannelsLast)s 1008 xfail("nn.functional.max_unpool2d", "grad"), 1009 xfail( 1010 "nn.functional.rrelu" 1011 ), # RuntimeError: vmap: we do not yet support aten::rrelu_with_noise. 1012 xfail("normal"), # calls random op 1013 xfail("normal", "number_mean"), # calls random op 1014 xfail("pca_lowrank"), # calls random op 1015 xfail( 1016 "quantile", device_type="cpu" 1017 ), # Batching rule not implemented for `at::equal` 1018 xfail( 1019 "scatter_reduce", "prod" 1020 ), # vmap (looks like you are calling item/data-dependent) 1021 xfail( 1022 "sparse.sampled_addmm" 1023 ), # RuntimeError: Sparse CSR tensors do not have strides 1024 xfail( 1025 "sparse.mm", "reduce" 1026 ), # RuntimeError: Sparse CSR tensors do not have strides 1027 xfail("svd_lowrank"), # calls random op 1028 xfail("to"), # rank 4 tensor for channels_last 1029 xfail( 1030 "view_as_complex" 1031 ), # RuntimeError: Tensor must have a last dimension with stride 1 1032 # got a batched tensor as input while the running_mean or running_var, 1033 # which will be updated in place, were not batched. 1034 xfail("nn.functional.batch_norm", "without_cudnn"), 1035 # view doesn't work on sparse 1036 xfail("to_sparse"), 1037 xfail("native_batch_norm"), 1038 xfail("_native_batch_norm_legit"), 1039 # TODO: implement batching rule 1040 xfail("_batch_norm_with_update"), 1041 } 1042 ), 1043 ) 1044 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) 1045 @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) 1046 @opsToleranceOverride( 1047 "TestOperators", 1048 "test_vmapvjpvjp", 1049 ( 1050 tol1("linalg.svd", {torch.float32: tol(atol=1e-03, rtol=5e-04)}), 1051 tol1("linalg.lu", {torch.float32: tol(atol=5e-04, rtol=7e-04)}), 1052 tol1("linalg.lu_factor", {torch.float32: tol(atol=2e-03, rtol=2e-02)}), 1053 tol1("linalg.multi_dot", {torch.float32: tol(atol=2e-03, rtol=2e-04)}), 1054 tol1("svd", {torch.float32: tol(atol=1e-03, rtol=5e-04)}), 1055 tol1("matrix_exp", {torch.float32: tol(atol=1e-03, rtol=5e-04)}), 1056 tol1("masked.prod", {torch.float32: tol(atol=2e-03, rtol=2e-04)}), 1057 ), 1058 ) 1059 @skipOps( 1060 "TestOperators", 1061 "test_vmapvjpvjp", 1062 { 1063 xfail("as_strided", "partial_views"), 1064 xfail("as_strided_copy"), 1065 }, 1066 ) 1067 def test_vmapvjpvjp(self, device, dtype, op): 1068 # Since, we test `vjpvjp` independently, 1069 # for this test, we just verify that vmap 1070 # of `vjpvjp` is correct. 1071 if not op.supports_autograd: 1072 self.skipTest("Skipped! Autograd not supported.") 1073 return 1074 if not op.supports_gradgrad: 1075 self.skipTest("Skipped! Operation does not support gradgrad") 1076 return 1077 1078 samples = op.sample_inputs(device, dtype, requires_grad=True) 1079 1080 # TODO: test in-place 1081 if is_inplace(op, op.get_op()): 1082 self.skipTest("Skipped! NYI: inplace-testing not supported.") 1083 return 1084 1085 for sample in samples: 1086 fn, args = get_vjpfull_variant(op, sample) 1087 result = fn(*args) 1088 cotangents = tree_map(lambda x: torch.randn_like(x), result) 1089 cotangents = pytree.tree_leaves(cotangents) 1090 num_args = len(args) 1091 1092 args_and_cotangents = tuple(args) + tuple(cotangents) 1093 1094 def vjp_of_vjp(*args_and_cotangents): 1095 args = args_and_cotangents[:num_args] 1096 cotangents = args_and_cotangents[num_args:] 1097 result, vjp_fn = vjp(fn, *args) 1098 result_vjps = vjp_fn(cotangents) 1099 result = pytree.tree_leaves(result) 1100 result_vjps = pytree.tree_leaves(result_vjps) 1101 return (*result, *result_vjps) 1102 1103 is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs) 1104 generator = get_fallback_and_vmap_exhaustive( 1105 vjp_of_vjp, 1106 args_and_cotangents, 1107 {}, 1108 is_batch_norm_and_training=is_batch_norm_and_training, 1109 ) 1110 for loop_out, batched_out in generator: 1111 self.assertEqual(loop_out, batched_out) 1112 1113 vmapvjp_fail = vjp_fail.union( 1114 { 1115 # -------------------- ALLOWED FAILURES -------------------------------- 1116 # The following are not bugs and are expected behavior 1117 xfail("masked_select"), # Not possible due to dynamic shapes 1118 skip("bernoulli"), # randomness 1119 skip("normal", ""), # randomness 1120 skip("normal", "number_mean"), # randomness 1121 skip("nn.functional.rrelu"), # randomness 1122 skip("nn.functional.feature_alpha_dropout", "with_train"), # randomness 1123 skip("nn.functional.feature_alpha_dropout", "without_train"), # randomness 1124 skip("nn.functional.dropout"), # randomness 1125 skip("nn.functional.dropout2d"), # randomness 1126 skip("nn.functional.dropout3d", ""), # randomness 1127 skip("nn.functional.alpha_dropout"), # randomness 1128 skip("nn.functional.scaled_dot_product_attention"), # randomness 1129 xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints 1130 skip("nn.functional.multi_head_attention_forward"), # randomness 1131 xfail( 1132 "index_put", "" 1133 ), # not possible due to dynamic shapes; we support a subset 1134 xfail("nn.functional.fractional_max_pool2d"), # random 1135 xfail("nn.functional.fractional_max_pool3d"), # random 1136 xfail("pca_lowrank", ""), # randomness 1137 xfail("svd_lowrank", ""), # randomness 1138 xfail("to_sparse", ""), # non-dense output 1139 skip( 1140 "to" 1141 ), # RuntimeError: required rank 4 tensor to use channels_last format 1142 xfail("as_strided", "partial_views"), 1143 xfail( 1144 "NumpyCubeNotComposableAutogradFunction" 1145 ), # Not composable autograd.Function 1146 # ---------------------------------------------------------------------- 1147 # ---------------------------- BUGS ------------------------------------ 1148 # All of the following are bugs and need to be fixed 1149 skip( 1150 "linalg.svdvals" 1151 ), # # really annoying thing where it passes correctness check but not has_batch_rule 1152 skip("native_batch_norm"), 1153 skip("_native_batch_norm_legit"), 1154 # TODO: implement batching rule 1155 skip("_batch_norm_with_update"), 1156 xfail("__getitem__", ""), # dynamic error 1157 xfail("nanquantile", device_type="cpu"), # checks q via a .item() call 1158 xfail("nn.functional.gaussian_nll_loss"), # checks var for if any value < 0 1159 xfail("narrow"), # .item() call 1160 xfail("quantile", device_type="cpu"), # checks q via a .item() call 1161 xfail("view_as_complex"), # Tensor must have a last dimension with stride 1 1162 # required rank 4 tensor to use channels_last format 1163 xfail("bfloat16"), 1164 xfail("double"), 1165 xfail("float"), 1166 xfail("half"), 1167 xfail("cdouble", ""), 1168 xfail("cfloat", ""), 1169 xfail("chalf", ""), 1170 xfail("scatter_reduce", "prod"), # item call 1171 # Batching rule not implemented for aten::_use_cudnn_ctc_loss.Tensor 1172 xfail("nn.functional.ctc_loss", device_type="cuda"), 1173 # NYI: querying is_contiguous inside of vmap for memory_format other than torch.contiguous_format 1174 xfail("nn.functional.max_unpool2d"), 1175 xfail("nn.functional.max_unpool2d", "grad"), 1176 xfail("sparse.sampled_addmm", ""), 1177 xfail("sparse.mm", "reduce"), 1178 xfail("as_strided_scatter", ""), # calls as_strided 1179 xfail("index_reduce", "prod"), # .item() call 1180 # --------------------------------------------------------------------- 1181 } 1182 ) 1183 1184 @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 1185 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) 1186 @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) 1187 @opsToleranceOverride( 1188 "TestOperators", 1189 "test_vmapvjp", 1190 ( 1191 tol1( 1192 "linalg.svd", 1193 {torch.float32: tol(atol=5e-04, rtol=1e-04)}, 1194 device_type="cuda", 1195 ), 1196 tol1( 1197 "svd", {torch.float32: tol(atol=5e-04, rtol=1e-04)}, device_type="cuda" 1198 ), 1199 tol1( 1200 "linalg.householder_product", 1201 {torch.float32: tol(atol=3e-04, rtol=9e-04)}, 1202 ), 1203 tol1( 1204 "matrix_exp", 1205 {torch.float32: tol(atol=5e-04, rtol=1e-04)}, 1206 device_type="cuda", 1207 ), 1208 tol1( 1209 "nn.functional.layer_norm", 1210 {torch.float32: tol(atol=3e-4, rtol=1e-4)}, 1211 device_type="cpu", 1212 ), 1213 tol1( 1214 "native_layer_norm", 1215 {torch.float32: tol(atol=3e-4, rtol=1e-4)}, 1216 device_type="cpu", 1217 ), 1218 ), 1219 ) 1220 @skipOps( 1221 "TestOperators", 1222 "test_vmapvjp", 1223 vmapvjp_fail.union( 1224 { 1225 xfail("as_strided"), 1226 xfail("as_strided_copy"), 1227 xfail("as_strided", "partial_views"), 1228 } 1229 ), 1230 ) 1231 def test_vmapvjp(self, device, dtype, op): 1232 if not op.supports_autograd: 1233 self.skipTest("Skipped! Autograd not supported.") 1234 return 1235 1236 samples = op.sample_inputs(device, dtype, requires_grad=True) 1237 1238 # TODO: test in-place 1239 if is_inplace(op, op.get_op()): 1240 self.skipTest("Skipped! NYI: inplace-testing not supported.") 1241 return 1242 for sample in samples: 1243 cotangents = get_sample_cotangents(op, sample) 1244 fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents) 1245 is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs) 1246 generator = get_fallback_and_vmap_exhaustive( 1247 fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training 1248 ) 1249 for loop_out, batched_out in generator: 1250 self.assertEqual(loop_out, batched_out) 1251 1252 vmapjvpall_fail = { 1253 # -------------------- ALLOWED FAILURES -------------------------------- 1254 # The following are expected (not a bug) 1255 skip("bernoulli", ""), # randomness 1256 skip("nn.functional.dropout"), # randomness 1257 skip("nn.functional.rrelu"), # randomness 1258 skip("nn.functional.dropout2d", ""), 1259 skip("nn.functional.dropout3d", ""), 1260 skip("nn.functional.scaled_dot_product_attention"), # randomness 1261 xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints 1262 skip("nn.functional.multi_head_attention_forward"), # randomness 1263 skip("nn.functional.alpha_dropout"), # randomness 1264 skip("nn.functional.feature_alpha_dropout", "without_train"), 1265 skip("nn.functional.feature_alpha_dropout", "with_train"), 1266 xfail( 1267 "nn.functional.fractional_max_pool2d" 1268 ), # Cannot access data pointer of Tensor that doesn't have storage 1269 xfail( 1270 "nn.functional.fractional_max_pool3d" 1271 ), # Cannot access data pointer of Tensor that doesn't have storage 1272 # Not actually a problem: embedding with max_norm mutates the weight 1273 # and causes different runs to produce different results. 1274 # skip because this is flaky depending on what the max_norm is! 1275 skip("nn.functional.embedding", ""), 1276 skip("to"), # RuntimeError: required rank 4 tensor to use channels_last format 1277 xfail( 1278 "NumpyExpMarkDirtyAutogradFunction" 1279 ), # vmap: inplace into a regular tensor 1280 # ---------------------------------------------------------------------- 1281 # ---------------------------- BUGS ------------------------------------ 1282 # The following are bugs that we should fix 1283 xfail("masked.mean"), # silent incorrectness (nan difference) 1284 xfail("as_strided", "partial_views"), # Tensor-likes are not close! 1285 xfail( 1286 "nn.functional.soft_margin_loss", "" 1287 ), # soft_margin_loss_backward does not support forward-ad 1288 xfail("tensor_split"), # data_ptr composite compliance 1289 xfail("quantile"), # at::equal batching rule (cpu), also, in-place vmap (cuda) 1290 skip("as_strided"), # Test runner cannot handle this 1291 # requires special handling, and does not yet have a batching rule. Feel free to file a github issue! 1292 xfail("as_strided_scatter"), 1293 xfail( 1294 "nn.functional.gaussian_nll_loss" 1295 ), # .item or data-dependent control flow 1296 xfail("scatter"), # forward-mode AD does not support at::scatter 1297 xfail( 1298 "nanquantile" 1299 ), # at::equal batching rule (cpu), also, in-place vmap (cuda) 1300 xfail("view_as_complex"), # Tensor must have a last dimension with stride 1 1301 skip("pca_lowrank", ""), # randomness 1302 skip("svd_lowrank", ""), # randomness 1303 xfail("double"), # required rank 4 tensor to use channels_last format 1304 xfail("cdouble"), # required rank 4 tensor to use channels_last format 1305 # potential silent incorrectness 1306 skip( 1307 "nn.functional.max_unpool1d" 1308 ), # Flaky, seems to sometimes his max_unpool2d 1309 skip("nn.functional.max_unpool2d"), # fails everywhere except on mac 1310 skip("nn.functional.max_unpool3d"), # fails everywhere except on mac 1311 # erroring because running_mean and running_var aren't differentiable 1312 xfail("nn.functional.batch_norm"), 1313 xfail("nn.functional.batch_norm", "without_cudnn"), 1314 xfail("native_batch_norm"), 1315 xfail("_native_batch_norm_legit"), 1316 # TODO: implement batching rule 1317 xfail("_batch_norm_with_update"), 1318 # ---------------------------------------------------------------------- 1319 } 1320 1321 @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 1322 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) 1323 @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) 1324 @opsToleranceOverride( 1325 "TestOperators", 1326 "test_vmapjvpall", 1327 ( 1328 tol1( 1329 "nn.functional.conv_transpose3d", 1330 {torch.float32: tol(atol=2e-04, rtol=9e-3)}, 1331 device_type="cuda", 1332 ), 1333 tol1( 1334 "linalg.householder_product", 1335 {torch.float32: tol(atol=2e-04, rtol=9e-3)}, 1336 ), 1337 ), 1338 ) 1339 @skipOps( 1340 "TestOperators", 1341 "test_vmapjvpall", 1342 vmapjvpall_fail.union( 1343 { 1344 xfail("as_strided_copy"), 1345 decorate( 1346 "linalg.det", 1347 "singular", 1348 decorator=expectedFailureIf(IS_MACOS and IS_X86), 1349 ), 1350 } 1351 ), 1352 ) 1353 # This is technically a superset of test_vmapjvp. We should either delete test_vmapjvp 1354 # or figure out if we can split vmapjvpall. It's useful to keep test_vmapjvp intact 1355 # because that corresponds to "batched forward-mode AD" testing in PyTorch core 1356 def test_vmapjvpall(self, device, dtype, op): 1357 if is_inplace(op, op.get_op()): 1358 # TODO: test in-place 1359 self.skipTest("Skipped! NYI: inplace-testing not supported.") 1360 return 1361 1362 samples = op.sample_inputs(device, dtype, requires_grad=False) 1363 1364 if not op.supports_forward_ad: 1365 self.skipTest("Skipped! Forward AD not supported.") 1366 return 1367 1368 for sample in samples: 1369 arg_values = [sample.input] + list(sample.args) 1370 kwarg_values = sample.kwargs 1371 args = tuple(arg_values) + tuple(kwarg_values) 1372 fn, args = get_jvp_variant_primals_tangents(op, sample) 1373 is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values) 1374 generator = get_fallback_and_vmap_exhaustive( 1375 fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training 1376 ) 1377 for loop_out, batched_out in generator: 1378 self.assertEqual(loop_out, batched_out) 1379 1380 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) 1381 @skipOps( 1382 "TestOperators", 1383 "test_vmapjvpall_has_batch_rule", 1384 vmapjvpall_fail.union( 1385 { 1386 skip( 1387 "to" 1388 ), # RuntimeError: required rank 4 tensor to use channels_last format 1389 xfail( 1390 "cdouble" 1391 ), # RuntimeError: required rank 4 tensor to use channels_last format 1392 xfail("cumprod"), 1393 xfail("masked_fill"), 1394 xfail("fill"), 1395 skip("masked.mean"), # ??? 1396 xfail("masked_scatter"), 1397 xfail("put"), 1398 xfail("take"), 1399 xfail("nn.functional.feature_alpha_dropout", "without_train"), 1400 xfail("nn.functional.dropout2d", ""), 1401 xfail("pca_lowrank", ""), 1402 xfail("svd_lowrank", ""), 1403 xfail("nn.functional.feature_alpha_dropout", "with_train"), 1404 xfail("special.log_ndtr", ""), 1405 xfail("fft.ihfft2"), # conj_physical fallback 1406 xfail("fft.ihfftn"), # conj_physical fallback 1407 xfail("nn.functional.max_unpool3d", "grad"), 1408 xfail("nn.functional.max_unpool2d", "grad"), 1409 xfail("nn.functional.soft_margin_loss", ""), 1410 xfail("nn.functional.max_unpool1d", "grad"), 1411 xfail("nn.functional.embedding", ""), 1412 xfail( 1413 "scatter_reduce", "sum" 1414 ), # aten::scatter_reduce.two hit the vmap fallback 1415 xfail( 1416 "scatter_reduce", "mean" 1417 ), # aten::scatter_reduce.two hit the vmap fallback 1418 xfail( 1419 "scatter_reduce", "amin" 1420 ), # aten::scatter_reduce.two hit the vmap fallback 1421 xfail( 1422 "scatter_reduce", "amax" 1423 ), # aten::scatter_reduce.two hit the vmap fallback 1424 xfail("nn.functional.glu"), 1425 xfail("nn.functional.bilinear"), # trilinear doesn't have batching rule 1426 xfail("linalg.lu", ""), 1427 xfail("nn.functional.dropout3d", ""), 1428 xfail("as_strided_scatter", ""), 1429 xfail("masked.cumprod", ""), 1430 xfail("renorm"), # hit vmap fallback, which is disabled 1431 xfail("t_copy"), 1432 xfail("unsqueeze_copy"), 1433 } 1434 ), 1435 ) 1436 @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) 1437 def test_vmapjvpall_has_batch_rule(self, device, dtype, op): 1438 if is_inplace(op, op.get_op()): 1439 # TODO: test in-place 1440 self.skipTest("Skipped! NYI: inplace-testing not supported.") 1441 return 1442 1443 samples = op.sample_inputs(device, dtype, requires_grad=False) 1444 1445 if not op.supports_forward_ad: 1446 self.skipTest("Skipped! Forward AD not supported.") 1447 return 1448 1449 def test(): 1450 for sample in samples: 1451 arg_values = [sample.input] + list(sample.args) 1452 kwarg_values = sample.kwargs 1453 args = tuple(arg_values) + tuple(kwarg_values) 1454 fn, args = get_jvp_variant_primals_tangents(op, sample) 1455 is_batch_norm_and_training = is_batch_norm_training( 1456 op.name, kwarg_values 1457 ) 1458 for loop_out, batched_out in get_fallback_and_vmap_exhaustive( 1459 fn, 1460 args, 1461 {}, 1462 is_batch_norm_and_training=is_batch_norm_and_training, 1463 compute_loop_out=False, 1464 ): 1465 pass 1466 1467 check_vmap_fallback(self, test, op, dry_run=False) 1468 1469 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) 1470 @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) 1471 @skipOps( 1472 "TestOperators", 1473 "test_vmapvjp_has_batch_rule", 1474 vmapvjp_fail.union( 1475 { 1476 skip( 1477 "to" 1478 ), # RuntimeError: required rank 4 tensor to use channels_last format 1479 xfail("view_as_complex"), 1480 xfail("cummax"), 1481 xfail("cummin"), 1482 xfail("fill"), 1483 xfail( 1484 "narrow" 1485 ), # Batching rule not implemented for `narrow.Tensor` (and view op) 1486 xfail("special.log_ndtr"), 1487 xfail("linalg.householder_product"), 1488 xfail("masked_fill"), 1489 xfail("masked_scatter"), 1490 xfail("masked_select"), 1491 xfail("nanquantile"), 1492 xfail("ormqr"), 1493 xfail("put"), 1494 xfail( 1495 "scatter_reduce", "sum" 1496 ), # aten::scatter_reduce.two hit the vmap fallback 1497 xfail( 1498 "scatter_reduce", "mean" 1499 ), # aten::scatter_reduce.two hit the vmap fallback 1500 xfail( 1501 "scatter_reduce", "amin" 1502 ), # aten::scatter_reduce.two hit the vmap fallback 1503 xfail( 1504 "scatter_reduce", "amax" 1505 ), # aten::scatter_reduce.two hit the vmap fallback 1506 xfail("quantile"), 1507 xfail("renorm"), 1508 xfail("take"), 1509 xfail("tensor_split"), 1510 xfail("to_sparse"), 1511 xfail("unfold"), 1512 xfail("unfold_copy"), 1513 xfail("nn.functional.dropout"), 1514 xfail("fft.ihfft2"), 1515 xfail("fft.ihfftn"), 1516 xfail("nn.functional.gaussian_nll_loss"), 1517 xfail("nn.functional.bilinear"), 1518 xfail("nn.functional.fractional_max_pool3d"), 1519 xfail("nn.functional.ctc_loss"), 1520 xfail("nn.functional.rrelu"), 1521 xfail("nn.functional.embedding_bag"), 1522 xfail("nn.functional.fractional_max_pool2d"), 1523 xfail("nn.functional.feature_alpha_dropout", "with_train"), 1524 xfail("pca_lowrank", ""), 1525 xfail("nn.functional.dropout2d", ""), 1526 xfail("nn.functional.feature_alpha_dropout", "without_train"), 1527 xfail("svd_lowrank", ""), 1528 xfail("nn.functional.max_unpool2d", ""), 1529 xfail("nn.functional.multi_margin_loss", ""), 1530 xfail("nn.functional.multilabel_margin_loss", ""), 1531 xfail("nn.functional.pdist", ""), 1532 xfail("scatter_reduce", "prod"), 1533 xfail("nn.functional.max_unpool1d", ""), 1534 xfail("nn.functional.max_unpool3d", ""), 1535 xfail("nn.functional.max_unpool3d", "grad"), 1536 xfail("nn.functional.soft_margin_loss", ""), 1537 xfail("nn.functional.max_unpool1d", "grad"), 1538 xfail("nn.functional.max_unpool2d", "grad"), 1539 xfail("linalg.lu", ""), 1540 xfail("cdouble", ""), 1541 xfail("cfloat", ""), 1542 xfail("chalf", ""), 1543 xfail( 1544 "index_reduce", "prod" 1545 ), # aten::index_reduce hit the vmap fallback which is currently disabled 1546 xfail( 1547 "index_reduce", "mean" 1548 ), # aten::index_reduce hit the vmap fallback which is currently disabled 1549 xfail( 1550 "index_reduce", "amax" 1551 ), # aten::index_reduce hit the vmap fallback which is currently disabled 1552 xfail( 1553 "index_reduce", "amin" 1554 ), # aten::index_reduce hit the vmap fallback which is currently disabled 1555 xfail("nn.functional.dropout3d", ""), 1556 xfail("as_strided_scatter", ""), 1557 xfail("_segment_reduce", "offsets"), 1558 xfail("_segment_reduce", "lengths"), 1559 xfail("sparse.sampled_addmm", ""), 1560 xfail("sparse.mm", "reduce"), 1561 xfail("native_batch_norm"), 1562 xfail("_native_batch_norm_legit"), 1563 # TODO: implement batching rule 1564 xfail("_batch_norm_with_update"), 1565 xfail("native_dropout_backward"), 1566 xfail( 1567 "index_fill" 1568 ), # aten::_unique hit the vmap fallback which is currently disabled 1569 xfail("t_copy"), 1570 xfail("unsqueeze_copy"), 1571 } 1572 ), 1573 ) 1574 def test_vmapvjp_has_batch_rule(self, device, dtype, op): 1575 if not op.supports_autograd: 1576 self.skipTest("Skipped! Autograd not supported.") 1577 return 1578 1579 samples = op.sample_inputs(device, dtype, requires_grad=True) 1580 1581 # TODO: test in-place 1582 if is_inplace(op, op.get_op()): 1583 self.skipTest("Skipped! NYI: inplace-testing not supported.") 1584 return 1585 1586 def test(): 1587 for sample in samples: 1588 cotangents = get_sample_cotangents(op, sample) 1589 fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents) 1590 is_batch_norm_and_training = is_batch_norm_training( 1591 op.name, sample.kwargs 1592 ) 1593 for loop_out, batched_out in get_fallback_and_vmap_exhaustive( 1594 fn, 1595 args, 1596 {}, 1597 is_batch_norm_and_training=is_batch_norm_and_training, 1598 compute_loop_out=False, 1599 ): 1600 pass 1601 for a_op in op.aliases: 1602 fn, args = get_vjp_fn_and_args_with_cotangents( 1603 a_op, sample, cotangents 1604 ) 1605 for loop_out, batched_out in get_fallback_and_vmap_exhaustive( 1606 fn, 1607 args, 1608 {}, 1609 is_batch_norm_and_training=is_batch_norm_and_training, 1610 compute_loop_out=False, 1611 ): 1612 pass 1613 1614 check_vmap_fallback(self, test, op, dry_run=False) 1615 1616 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) 1617 @skipOps( 1618 "TestOperators", 1619 "test_vjpvmap", 1620 vjp_fail.union( 1621 { 1622 skip("bernoulli", ""), # vjpvmap testing can't handle randomness 1623 skip("normal", ""), # vjpvmap testing can't handle randomness 1624 skip( 1625 "normal", "number_mean" 1626 ), # vjpvmap testing can't handle randomness 1627 skip("nn.functional.rrelu"), # randomness 1628 skip("nn.functional.feature_alpha_dropout", "with_train"), # randomness 1629 skip( 1630 "nn.functional.feature_alpha_dropout", "without_train" 1631 ), # randomness 1632 skip("nn.functional.scaled_dot_product_attention"), 1633 xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints 1634 skip("nn.functional.multi_head_attention_forward"), # randomness 1635 skip("nn.functional.alpha_dropout"), # randomness 1636 skip( 1637 "to" 1638 ), # RuntimeError: required rank 4 tensor to use channels_last format 1639 skip("to_sparse", ""), # non-dense output 1640 skip("ormqr", ""), # takes too long 1641 xfail( 1642 "NumpyCubeNotComposableAutogradFunction" 1643 ), # Not composable autograd.Function 1644 # fallback path doesn't work 1645 # All of the following are bugs and need to be fixed 1646 xfail("__getitem__", ""), 1647 xfail("index_put", ""), 1648 xfail("view_as_complex"), 1649 xfail("nn.functional.gaussian_nll_loss"), 1650 xfail("masked_select"), 1651 xfail( 1652 "narrow" 1653 ), # Batching rule not implemented for `narrow.Tensor` (and view op) 1654 skip( 1655 "nn.functional.fractional_max_pool3d" 1656 ), # generator works on cpu, fails on cuda 1657 skip( 1658 "nn.functional.fractional_max_pool2d" 1659 ), # generator works on cpu, fails on cuda 1660 xfail("column_stack", ""), 1661 xfail("nn.functional.dropout2d", ""), 1662 xfail("svd_lowrank", ""), 1663 xfail("pca_lowrank", ""), 1664 xfail("clamp"), 1665 # something weird happening with channels_last 1666 xfail("bfloat16"), 1667 xfail("double"), 1668 xfail("float"), 1669 xfail("half"), 1670 xfail("cdouble"), 1671 xfail("cfloat"), 1672 xfail("nn.functional.dropout3d", ""), 1673 xfail("as_strided_scatter", ""), 1674 xfail("sparse.sampled_addmm", ""), 1675 xfail("sparse.mm", "reduce"), 1676 xfail("native_batch_norm"), 1677 xfail("_native_batch_norm_legit"), 1678 # TODO: implement batching rule 1679 xfail("_batch_norm_with_update"), 1680 xfail("as_strided", "partial_views"), 1681 } 1682 ), 1683 ) 1684 def test_vjpvmap(self, device, dtype, op): 1685 # NB: there is no vjpvmap_has_batch_rule test because that is almost 1686 # certainly redundant with the vmap_has_batch_rule test in test_vmap.py 1687 1688 # one-off skip 1689 if op.name == "nn.functional.dropout": 1690 self.skipTest("Skipped!") 1691 1692 if not op.supports_autograd: 1693 # If the op doesn't support autograd, vmap(op) won't either 1694 self.skipTest("Skipped! Autograd not supported.") 1695 return 1696 1697 # TODO: test in-place 1698 if is_inplace(op, op.get_op()): 1699 self.skipTest("Skipped! NYI: inplace-testing not supported.") 1700 return 1701 1702 samples = op.sample_inputs(device, dtype, requires_grad=True) 1703 batch_norm_fns = ( 1704 "nn.functional.batch_norm", 1705 "nn.functional.instance_norm", 1706 ) # instance norm calls batch norm 1707 is_batch_norm = op.name in batch_norm_fns 1708 1709 for sample in samples: 1710 args = [sample.input] + list(sample.args) 1711 kwargs = sample.kwargs 1712 1713 is_batch_norm_and_training = is_batch_norm and is_batch_norm_training( 1714 op.name, kwargs 1715 ) 1716 generator = generate_vmap_inputs( 1717 args, kwargs, is_batch_norm_and_training=is_batch_norm_and_training 1718 ) 1719 1720 for batched_args, in_dims, kwargs in generator: 1721 vmapped_op = vmap(op, in_dims) 1722 fn, primals = normalize_op_input_output2( 1723 vmapped_op, batched_args, kwargs, sample.output_process_fn_grad 1724 ) 1725 result = fn(*primals) 1726 cotangents = tree_map(lambda x: torch.randn_like(x), result) 1727 1728 _, vjp_fn = vjp(fn, *primals) 1729 result_vjps = vjp_fn(cotangents) 1730 1731 _, vjp_fn = ref_vjp(fn, *primals) 1732 expected_vjps = vjp_fn(cotangents) 1733 1734 self.assertEqual(result_vjps, expected_vjps) 1735 1736 def _compare_jacobians_of_vjp( 1737 self, fn, cotangents_and_primals, argnums=None, atol_rtol=None 1738 ): 1739 if argnums is None: 1740 argnums = tuple(range(len(cotangents_and_primals))) 1741 1742 def get_vjp(cotangents, *primals): 1743 _, vjp_fn = vjp(fn, *primals) 1744 return vjp_fn(cotangents) 1745 1746 jacobian_jvp = jacfwd(get_vjp, argnums)(*cotangents_and_primals) 1747 jacobian_vjp = jacrev(get_vjp, argnums)(*cotangents_and_primals) 1748 1749 # For dtype changing operations, the jacobians have different dtype. 1750 jacobian_jvp = tree_map(lambda x: x.to(torch.float), jacobian_jvp) 1751 jacobian_vjp = tree_map(lambda x: x.to(torch.float), jacobian_vjp) 1752 1753 if atol_rtol is not None: 1754 (atol, rtol) = atol_rtol 1755 self.assertEqual(jacobian_jvp, jacobian_vjp, atol=atol, rtol=rtol) 1756 else: 1757 self.assertEqual(jacobian_jvp, jacobian_vjp) 1758 1759 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) 1760 @skipOps( 1761 "TestOperators", 1762 "test_jvpvjp", 1763 vjp_fail.union( 1764 { 1765 xfail("to_sparse", ""), # NYI 1766 # RuntimeError: Trying to set a forward gradient that has a different size than that of the original Tensor, 1767 # this is not supported. Tensor is of size [5, 2, 3] while the given forward gradient is of size [1, 2, 3]. 1768 xfail("normal", ""), 1769 xfail("cdist", ""), # NYI: forward-AD for _cdist_forward 1770 xfail("cholesky", ""), # NYI: forward-AD for cholesky 1771 xfail( 1772 "nn.functional.embedding_bag", "" 1773 ), # NYI: forward-AD for _embedding_bag 1774 xfail( 1775 "nn.functional.grid_sample", "" 1776 ), # NYI: forward AD for grid_sampler_2d 1777 xfail("grid_sampler_2d", ""), # NYI: forward AD for grid_sampler_2d 1778 xfail( 1779 "nn.functional.hardsigmoid", "" 1780 ), # NYI: forward AD for hardsigmoid_backward 1781 xfail( 1782 "nn.functional.huber_loss", "" 1783 ), # NYI: forward AD for huber_loss_backward 1784 xfail("NumpyCubeNotComposableAutogradFunction"), # not composable 1785 xfail("ormqr", ""), # NYI: forward AD for ormqr 1786 xfail( 1787 "nn.functional.multilabel_margin_loss", "" 1788 ), # NYI: multilabel_margin_loss_forward 1789 xfail( 1790 "nn.functional.soft_margin_loss", "" 1791 ), # NYI: forward-AD for soft_margin_loss_backward 1792 xfail("nn.functional.ctc_loss", ""), # NYI: forward-AD for _ctc_loss 1793 xfail("nn.functional.pdist", ""), # NYI: forward-AD with _pdist_forward 1794 skip("nn.functional.scaled_dot_product_attention"), 1795 xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints 1796 xfail( 1797 "nn.functional.multi_margin_loss", "" 1798 ), # NYI: forward AD with multi_margin_loss 1799 skip( 1800 "linalg.householder_product", "", device_type="cuda" 1801 ), # flaky, I'm not sure why 1802 xfail("sparse.sampled_addmm", ""), # Sparse tensors have no strides 1803 xfail( 1804 "_segment_reduce", "offsets" 1805 ), # NYI: forward-AD for _segment_reduce 1806 xfail("sparse.mm", "reduce"), # Sparse tensors have no strides 1807 xfail("index_reduce", "prod"), # NYI: forward-AD for index_reduce 1808 xfail("index_reduce", "mean"), # NYI: forward-AD for index_reduce 1809 xfail("index_reduce", "amax"), # NYI: forward-AD for index_reduce 1810 xfail("index_reduce", "amin"), # NYI: forward-AD for index_reduce 1811 xfail( 1812 "_segment_reduce", "lengths" 1813 ), # NYI: forward-AD for _segment_reduce 1814 xfail("native_dropout_backward"), # NYI 1815 } 1816 ), 1817 ) 1818 @opsToleranceOverride( 1819 "TestOperators", 1820 "test_jvpvjp", 1821 ( 1822 tol1("masked.prod", {torch.float32: tol(atol=1e-04, rtol=1.3e-05)}), 1823 tol1("masked.cumprod", {torch.float32: tol(atol=1e-04, rtol=5e-04)}), 1824 tol1( 1825 "cumprod", 1826 {torch.float32: tol(atol=1e-03, rtol=5e-04)}, 1827 device_type="cuda", 1828 ), 1829 tol1( 1830 "linalg.det", 1831 {torch.float32: tol(atol=3e-05, rtol=5e-06)}, 1832 device_type="cuda", 1833 ), 1834 tol1( 1835 "linalg.vander", 1836 {torch.float32: tol(atol=1e-04, rtol=1.3e-05)}, 1837 device_type="cuda", 1838 ), 1839 tol1( 1840 "nn.functional.group_norm", {torch.float32: tol(atol=1e-03, rtol=1e-03)} 1841 ), 1842 tol2( 1843 "linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-03, rtol=5e-03)} 1844 ), 1845 ), 1846 ) 1847 def test_jvpvjp(self, device, dtype, op): 1848 if not op.supports_autograd: 1849 self.skipTest("Skipped! Autograd not supported.") 1850 return 1851 1852 samples = op.sample_inputs(device, dtype, requires_grad=True) 1853 1854 # TODO: test in-place 1855 if is_inplace(op, op.get_op()): 1856 self.skipTest("Skipped! NYI: inplace-testing not supported.") 1857 return 1858 1859 for sample in samples: 1860 fn, primals = normalize_op_input_output(op, sample) 1861 result = fn(*primals) 1862 cotangents = tree_map(lambda x: torch.randn_like(x), result) 1863 1864 primals_tangents = tree_map(lambda x: torch.randn_like(x), primals) 1865 cotangents_tangents = tree_map(lambda x: torch.randn_like(x), cotangents) 1866 1867 def push_vjp(primals, cotangents): 1868 _, vjp_fn = vjp(fn, *primals) 1869 return vjp_fn(cotangents) 1870 1871 result = jvp( 1872 push_vjp, (primals, cotangents), (primals_tangents, cotangents_tangents) 1873 ) 1874 self.assertEqual(len(result), 2) 1875 1876 def tree_map2(fn, first, second): 1877 flat_first, spec_first = tree_flatten(first) 1878 flat_second, spec_second = tree_flatten(second) 1879 assert spec_first == spec_second 1880 flat_result = [fn(f, s) for f, s in zip(flat_first, flat_second)] 1881 return tree_unflatten(flat_result, spec_first) 1882 1883 def reference(primals, cotangents, primals_tangents, cotangents_tangents): 1884 with fwAD.dual_level(): 1885 primal_duals = tree_map2(fwAD.make_dual, primals, primals_tangents) 1886 _, vjp_fn = ref_vjp(fn, *primal_duals) 1887 1888 cotangent_duals = tree_map2( 1889 fwAD.make_dual, cotangents, cotangents_tangents 1890 ) 1891 result = vjp_fn(cotangent_duals) 1892 1893 flat_result, spec = tree_flatten(result) 1894 primals_out, tangents_out = zip( 1895 *[fwAD.unpack_dual(r) for r in flat_result] 1896 ) 1897 tangents_out = [ 1898 t if t is not None else torch.zeros_like(p) 1899 for p, t in zip(primals_out, tangents_out) 1900 ] 1901 expected = ( 1902 tree_unflatten(primals_out, spec), 1903 tree_unflatten(tangents_out, spec), 1904 ) 1905 return expected 1906 1907 expected = reference( 1908 primals, cotangents, primals_tangents, cotangents_tangents 1909 ) 1910 self.assertEqual(result, expected) 1911 1912 @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 1913 @skipOps( 1914 "TestOperators", 1915 "test_vmapjvpvjp", 1916 vjp_fail.union( 1917 { 1918 # Following operators take too long, hence skipped 1919 skip("atleast_1d"), 1920 skip("atleast_2d"), 1921 skip("atleast_3d"), 1922 skip("meshgrid", "list_of_tensors"), 1923 skip("meshgrid", "variadic_tensors"), 1924 skip("broadcast_tensors"), 1925 skip("linalg.lstsq"), 1926 skip("nn.functional.bilinear"), 1927 skip("native_layer_norm"), 1928 skip("ormqr"), 1929 # Not actually a problem 1930 xfail("NumpyCubeNotComposableAutogradFunction"), # not composable 1931 xfail( 1932 "NumpyExpMarkDirtyAutogradFunction" 1933 ), # vmap: inplace into a regular tensor 1934 # Potential bugs/errors 1935 xfail("as_strided"), # AssertionError: Tensor-likes are not close! 1936 xfail( 1937 "as_strided", "partial_views" 1938 ), # AssertionError: Tensor-likes are not close! 1939 xfail("as_strided_copy"), # AssertionError: Tensor-likes are not close! 1940 xfail( 1941 "as_strided_scatter" 1942 ), # AssertionError: Tensor-likes are not close! 1943 xfail("bernoulli"), # calls random op 1944 xfail("bfloat16"), # required rank 4 tensor to use channels_last format 1945 xfail("cdist"), # Forward AD not implemented and no decomposition 1946 xfail("cdouble"), # required rank 4 tensor to use channels_last format 1947 xfail("cfloat"), # required rank 4 tensor to use channels_last format 1948 xfail("chalf"), # required rank 4 tensor to use channels_last format 1949 xfail("cholesky"), # Forward AD not implemented and no decomposition 1950 xfail("ormqr"), # Forward AD not implemented and no decomposition 1951 xfail("double"), # required rank 4 tensor to use channels_last format 1952 xfail("float"), # required rank 4 tensor to use channels_last format 1953 xfail("half"), # required rank 4 tensor to use channels_last format 1954 xfail("index_reduce", "prod"), # NYI: forward AD for index_reduce 1955 xfail("index_reduce", "mean"), # NYI: forward AD for index_reduce 1956 xfail("index_reduce", "amax"), # NYI: forward AD for index_reduce 1957 xfail("index_reduce", "amin"), # NYI: forward AD for index_reduce 1958 xfail( 1959 "mvlgamma", "mvlgamma_p_1" 1960 ), # vmap: inplace into a regular tensor 1961 xfail( 1962 "mvlgamma", "mvlgamma_p_3" 1963 ), # vmap: inplace into a regular tensor 1964 xfail( 1965 "mvlgamma", "mvlgamma_p_5" 1966 ), # vmap: inplace into a regular tensor 1967 xfail("nanquantile"), # Batching rule not implemented for aten::equal 1968 # RuntimeError: Batch norm got a batched tensor as input while the 1969 # running_mean or running_var, which will be updated in place, 1970 # were not batched. 1971 xfail("nn.functional.batch_norm"), 1972 xfail("nn.functional.batch_norm", "without_cudnn"), 1973 xfail( 1974 "nn.functional.ctc_loss" 1975 ), # ForwardAD not implemented and no decomposition 1976 xfail("nn.functional.dropout2d"), # calls random op 1977 xfail("nn.functional.dropout3d"), # calls random op 1978 xfail("nn.functional.dropout"), # calls random op 1979 xfail("nn.functional.scaled_dot_product_attention"), # randomness 1980 xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints 1981 xfail("nn.functional.multi_head_attention_forward"), # randomness 1982 xfail( 1983 "nn.functional.embedding_bag" 1984 ), # Forward AD not implemented and no decomposition 1985 xfail("nn.functional.alpha_dropout"), # calls randomn op 1986 xfail( 1987 "nn.functional.feature_alpha_dropout", "with_train" 1988 ), # calls random op 1989 xfail("nn.functional.fractional_max_pool2d"), # calls random op 1990 xfail("nn.functional.fractional_max_pool3d"), # calls random op 1991 xfail("nn.functional.gaussian_nll_loss"), # data depenedant flow 1992 xfail( 1993 "nn.functional.grid_sample" 1994 ), # Forward AD not implemented and no decomposition 1995 xfail( 1996 "grid_sampler_2d" 1997 ), # Forward AD not implemented and no decomposition 1998 xfail( 1999 "nn.functional.hardsigmoid" 2000 ), # Forward AD not implemented and no decomposition 2001 xfail( 2002 "nn.functional.hinge_embedding_loss" 2003 ), # vmap: inplace into a regular tensor 2004 xfail( 2005 "nn.functional.huber_loss" 2006 ), # Forward AD not implemented and no decomposition 2007 # RuntimeError: Batch norm got a batched tensor as input while the 2008 # running_mean or running_var, which will be updated in place, 2009 # were not batched. 2010 xfail("nn.functional.instance_norm"), 2011 # NYI: Tensor.clone(memory_format) inside vmap is only supported with 2012 # memory_format torch.preserve_format or torch.contiguous_format (got ChannelsLast) 2013 xfail("nn.functional.max_unpool2d"), 2014 xfail("nn.functional.max_unpool2d", "grad"), 2015 xfail( 2016 "nn.functional.multi_margin_loss" 2017 ), # Forward AD not implemented and no decomposition 2018 xfail( 2019 "nn.functional.multilabel_margin_loss" 2020 ), # Forward AD not implemented and no decomposition 2021 xfail( 2022 "nn.functional.pdist" 2023 ), # Forward AD not implemented and no decomposition 2024 xfail( 2025 "nn.functional.rrelu" 2026 ), # vmap: we do not yet support aten::rrelu_with_noise. 2027 xfail( 2028 "nn.functional.soft_margin_loss" 2029 ), # Forward AD not implemented and no decomposition 2030 xfail("normal"), # calls random op 2031 xfail("normal", "number_mean"), # calls random op 2032 xfail("pca_lowrank"), # calls random op 2033 xfail("quantile"), # Batching rule not implemented for aten::equal 2034 xfail( 2035 "scatter_reduce", "prod" 2036 ), # Forward AD not implemented and no decomposition 2037 xfail( 2038 "_segment_reduce", "lengths" 2039 ), # Forward AD not implemented and no decomposition 2040 xfail( 2041 "_segment_reduce", "offsets" 2042 ), # Forward AD not implemented and no decomposition 2043 xfail( 2044 "sparse.sampled_addmm" 2045 ), # RuntimeError: Sparse CSR tensors do not have strides 2046 xfail( 2047 "sparse.mm", "reduce" 2048 ), # RuntimeError: Sparse CSR tensors do not have strides 2049 xfail("svd_lowrank"), # calls random op 2050 xfail( 2051 "to" 2052 ), # RuntimeError: required rank 4 tensor to use channels_last format 2053 xfail("to_sparse"), # Forward AD not implemented and no decomposition 2054 xfail( 2055 "view_as_complex" 2056 ), # RuntimeError: Tensor must have a last dimension with stride 1 2057 # RuntimeError: Batch norm got a batched tensor as 2058 # input while the running_mean or running_var, which will be updated in 2059 # place, were not batched. 2060 xfail("native_batch_norm"), 2061 xfail("_native_batch_norm_legit"), 2062 # TODO: implement batching rule 2063 xfail("_batch_norm_with_update"), 2064 xfail("native_dropout_backward"), 2065 } 2066 ), 2067 ) 2068 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) 2069 @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) 2070 @opsToleranceOverride( 2071 "TestOperators", 2072 "test_vmapjvpvjp", 2073 ( 2074 tol1("linalg.svd", {torch.float32: tol(atol=5e-04, rtol=5e-04)}), 2075 tol1( 2076 "linalg.householder_product", 2077 {torch.float32: tol(atol=5e-03, rtol=5e-03)}, 2078 ), 2079 tol1("linalg.multi_dot", {torch.float32: tol(atol=5e-04, rtol=5e-04)}), 2080 tol2( 2081 "linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-04, rtol=5e-04)} 2082 ), 2083 tol1( 2084 "nn.functional.conv_transpose2d", 2085 {torch.float32: tol(atol=5e-04, rtol=5e-04)}, 2086 ), 2087 tol1("svd", {torch.float32: tol(atol=5e-04, rtol=5e-04)}), 2088 tol1("matrix_exp", {torch.float32: tol(atol=5e-04, rtol=5e-04)}), 2089 ), 2090 ) 2091 def test_vmapjvpvjp(self, device, dtype, op): 2092 # Since we test `jvpvjp` separately, 2093 # in this we just check that vmap of `jvpvjp` 2094 # is correct. 2095 if not op.supports_autograd: 2096 self.skipTest("Skipped! Autograd not supported.") 2097 return 2098 2099 samples = op.sample_inputs(device, dtype, requires_grad=True) 2100 2101 # TODO: test in-place 2102 if is_inplace(op, op.get_op()): 2103 self.skipTest("Skipped! NYI: inplace-testing not supported.") 2104 return 2105 2106 for sample in samples: 2107 fn, primals = normalize_op_input_output(op, sample) 2108 result = fn(*primals) 2109 cotangents = tree_map(lambda x: torch.randn_like(x), result) 2110 2111 primals_tangents = tree_map(lambda x: torch.randn_like(x), primals) 2112 cotangents_tangents = tree_map(lambda x: torch.randn_like(x), cotangents) 2113 2114 def push_vjp(primals, cotangents): 2115 _, vjp_fn = vjp(fn, *primals) 2116 return vjp_fn(cotangents) 2117 2118 args, spec = tree_flatten( 2119 ((primals, cotangents), (primals_tangents, cotangents_tangents)) 2120 ) 2121 2122 def jvp_of_vjp(*args): 2123 (primals, tangents) = tree_unflatten(args, spec) 2124 primals_out, tangents_out = jvp(push_vjp, primals, tangents) 2125 2126 flat_primals_out = pytree.tree_leaves(primals_out) 2127 flat_tangents_out = pytree.tree_leaves(tangents_out) 2128 return tuple(flat_primals_out + flat_tangents_out) 2129 2130 is_batch_norm_and_training = is_batch_norm_training(op, sample.kwargs) 2131 generator = get_fallback_and_vmap_exhaustive( 2132 jvp_of_vjp, 2133 args, 2134 {}, 2135 is_batch_norm_and_training=is_batch_norm_and_training, 2136 ) 2137 for loop_out, batched_out in generator: 2138 self.assertEqual(loop_out, batched_out) 2139 2140 def _make_extremal_inputs(self, shape, device): 2141 if shape is None: 2142 return (None,) 2143 return ( 2144 torch.full(shape, -1000.0, device=device), 2145 torch.zeros(shape, device=device), 2146 torch.full(shape, 1000.0, device=device), 2147 ) 2148 2149 def _arg_and_kwarg_options(self, args_options, kwargs_options): 2150 return itertools.product(*args_options, kwargs_options) 2151 2152 def test_extremal_numerics_nll_loss(self, device): 2153 N, C = 3, 4 2154 d1, d2, d3 = 5, 6, 7 2155 shapes = ( 2156 ((N, C), (N,), (C,)), 2157 ((N, C), (N,), None), 2158 ((N, C, d1, d2, d3), (N, d1, d2, d3), (C,)), 2159 ((N, C, d1, d2, d3), (N, d1, d2, d3), None), 2160 ) 2161 kwargs_options = ( 2162 {"ignore_index": 0, "reduction": "mean"}, 2163 {"reduction": "sum"}, 2164 {"reduction": "none"}, 2165 {}, 2166 ) 2167 for input_shape, target_shape, weight_shape in shapes: 2168 input_options = self._make_extremal_inputs(input_shape, device) 2169 for input, kwargs in self._arg_and_kwarg_options( 2170 (input_options,), kwargs_options 2171 ): 2172 if weight_shape is None: 2173 weight = None 2174 else: 2175 weight = torch.randn(weight_shape, device=device) 2176 target = torch.randint(0, C, target_shape, device=device) 2177 target[ 2178 0 2179 ] = 1 # since we're ignoring index 0, at least one element must be non-zero 2180 2181 fn = functools.partial( 2182 torch.nn.functional.nll_loss, target=target, weight=weight, **kwargs 2183 ) 2184 result = fn(input) 2185 cotangents = torch.randn_like(result, device=device) 2186 self._compare_jacobians_of_vjp(fn, (cotangents, input)) 2187 2188 def test_extremal_numerics_l1_loss(self, device): 2189 N, C, H, W = 3, 4, 5, 6 2190 shapes = ((N, C), (N, C, H), (N, C, H, W)) 2191 kwargs_options = ({"reduction": "sum"}, {"reduction": "none"}, {}) 2192 for shape in shapes: 2193 input_options = self._make_extremal_inputs(shape, device) 2194 target_options = self._make_extremal_inputs(shape, device) 2195 for input, target, kwargs in self._arg_and_kwarg_options( 2196 (input_options, target_options), kwargs_options 2197 ): 2198 result = torch.nn.functional.l1_loss(input, target) 2199 cotangents = torch.randn_like(result, device=device) 2200 self._compare_jacobians_of_vjp( 2201 torch.nn.functional.l1_loss, (cotangents, input, target) 2202 ) 2203 2204 def test_extremal_numerics_mse_loss(self, device): 2205 N, C, H, W = 3, 4, 5, 6 2206 shapes = ((N, C), (N, C, H), (N, C, H, W)) 2207 kwargs_options = ({"reduction": "sum"}, {"reduction": "none"}, {}) 2208 for shape in shapes: 2209 input_options = self._make_extremal_inputs(shape, device) 2210 target_options = self._make_extremal_inputs(shape, device) 2211 for input, target, kwargs in self._arg_and_kwarg_options( 2212 (input_options, target_options), kwargs_options 2213 ): 2214 result = torch.nn.functional.mse_loss(input, target) 2215 cotangents = torch.randn_like(result, device=device) 2216 self._compare_jacobians_of_vjp( 2217 torch.nn.functional.mse_loss, (cotangents, input, target) 2218 ) 2219 2220 def test_extremal_numerics_softmax(self, device): 2221 N, C, H, W = 3, 4, 5, 6 2222 shapes = ((N, C), (N, C, H), (N, C, H, W)) 2223 kwargs_options = ({"dim": 1}, {}) 2224 for shape in shapes: 2225 input_options = self._make_extremal_inputs(shape, device) 2226 for input, kwargs in self._arg_and_kwarg_options( 2227 (input_options,), kwargs_options 2228 ): 2229 result = torch.nn.functional.softmax(input) 2230 cotangents = torch.randn_like(result, device=device) 2231 self._compare_jacobians_of_vjp( 2232 torch.nn.functional.softmax, (cotangents, input) 2233 ) 2234 2235 def test_extremal_numerics_log_softmax(self, device): 2236 N, C, H, W = 3, 4, 5, 6 2237 shapes = ((N, C), (N, C, H), (N, C, H, W)) 2238 kwargs_options = ({"dim": 1}, {}) 2239 for shape in shapes: 2240 input_options = self._make_extremal_inputs(shape, device) 2241 for input, kwargs in self._arg_and_kwarg_options( 2242 (input_options,), kwargs_options 2243 ): 2244 result = torch.nn.functional.log_softmax(input) 2245 cotangents = torch.randn_like(result, device=device) 2246 self._compare_jacobians_of_vjp( 2247 torch.nn.functional.log_softmax, (cotangents, input) 2248 ) 2249 2250 def test_extremal_numerics_cross_entropy(self, device): 2251 N, C = 3, 4 2252 d1, d2, d3 = 5, 6, 7 2253 shapes = ( 2254 ((N, C), (N,), (C,)), 2255 ((N, C), (N,), None), 2256 ((N, C), (N, C), (C,)), 2257 ((N, C), (N, C), None), 2258 ((C,), (), (C,)), 2259 ((C,), (), None), 2260 ((C,), (C,), (C,)), 2261 ((C,), (C,), None), 2262 ((N, C, d1, d2, d3), (N, d1, d2, d3), (C,)), 2263 ((N, C, d1, d2, d3), (N, d1, d2, d3), None), 2264 ((N, C, d1, d2, d3), (N, C, d1, d2, d3), (C,)), 2265 ((N, C, d1, d2, d3), (N, C, d1, d2, d3), None), 2266 ) 2267 for input_shape, target_shape, weight_shape in shapes: 2268 input_options = self._make_extremal_inputs(input_shape, device) 2269 kwargs_options = [{"reduction": "sum"}, {"reduction": "none"}, {}] 2270 if input_shape != target_shape: 2271 kwargs_options.append({"ignore_index": 0, "reduction": "mean"}) 2272 2273 for input, kwargs in self._arg_and_kwarg_options( 2274 (input_options,), kwargs_options 2275 ): 2276 if weight_shape is None: 2277 weight = None 2278 else: 2279 weight = torch.randn(weight_shape, device=device) 2280 2281 if input_shape == target_shape: 2282 target = torch.rand(target_shape, device=device) 2283 elif len(target_shape) == 0: 2284 target = torch.tensor( 2285 1, device=device 2286 ) # must be non-zero since ignore_index may be 0 2287 else: 2288 target = torch.randint(0, C, target_shape, device=device) 2289 2290 fn = functools.partial( 2291 torch.nn.functional.cross_entropy, 2292 target=target, 2293 weight=weight, 2294 **kwargs, 2295 ) 2296 result = fn(input) 2297 cotangents = torch.randn_like(result, device=device) 2298 self._compare_jacobians_of_vjp( 2299 fn, (cotangents, input), atol_rtol=(1e-4, 1e-5) 2300 ) 2301 2302 def test_extremal_numerics_binary_cross_entropy(self, device): 2303 N, C, H, W = 3, 4, 5, 6 2304 shapes = ((N, C), (N, C, H), (N, C, H, W)) 2305 for shape in shapes: 2306 weight_options = self._make_extremal_inputs(shape, device) 2307 kwargs_options = [{"reduction": "sum"}, {"reduction": "none"}, {}] 2308 2309 for weight, kwargs in self._arg_and_kwarg_options( 2310 (weight_options,), kwargs_options 2311 ): 2312 input = torch.rand(shape, device=device) 2313 target = torch.rand(shape, device=device) 2314 fn = functools.partial( 2315 torch.nn.functional.binary_cross_entropy, 2316 target=target, 2317 weight=weight, 2318 **kwargs, 2319 ) 2320 result = fn(input) 2321 cotangents = torch.randn_like(result, device=device) 2322 self._compare_jacobians_of_vjp( 2323 fn, (cotangents, input), atol_rtol=(1e-4, 2e-5) 2324 ) 2325 2326 def test_extremal_numerics_layer_norm(self, device): 2327 N, C, H, W = 3, 4, 5, 6 2328 shapes = ((N, C), (N, C, H), (N, C, H, W)) 2329 for shape in shapes: 2330 input_options = self._make_extremal_inputs(shape, device) 2331 normalized_shape = shape[1:] 2332 weight_options = self._make_extremal_inputs(normalized_shape, device) 2333 bias_options = self._make_extremal_inputs(normalized_shape, device) 2334 2335 for input, bias, weight in self._arg_and_kwarg_options( 2336 (input_options, bias_options, weight_options), () 2337 ): 2338 2339 def fn(input, weight, bias): 2340 return torch.nn.functional.layer_norm( 2341 input, normalized_shape, weight=weight, bias=bias 2342 ) 2343 2344 result = fn(input, weight, bias) 2345 cotangents = torch.randn_like(result, device=device) 2346 self._compare_jacobians_of_vjp(fn, (cotangents, input, weight, bias)) 2347 2348 @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 2349 @ops( 2350 op_db + additional_op_db + autograd_function_db, 2351 allowed_dtypes=(torch.float32, torch.double), 2352 ) 2353 @skipOps( 2354 "TestOperators", 2355 "test_vmap_autograd_grad", 2356 { 2357 # The size of tensor a (4) must match the size of tensor b (10) at non-singleton dimension 0 2358 xfail("masked_select"), 2359 xfail("nn.functional.max_unpool2d", "grad"), # contiguous call 2360 xfail("nn.functional.max_unpool2d"), # contiguous call 2361 xfail("to_sparse"), # dispatch key issue 2362 xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints 2363 # https://github.com/pytorch/pytorch/issues/96560#issuecomment-2151063723 2364 # ** minor accuracy issue for float32 on ROCm 2365 decorate("xlogy", decorator=skipIfRocm), 2366 # numerical inconsistencies, look like bugs 2367 skip( 2368 "matrix_exp", dtypes=(torch.float32,), device_type="cuda" 2369 ), # fails on linux, passes on windows 2370 skip( 2371 "ldexp", dtypes=(torch.float32,), device_type="cpu" 2372 ), # fails on all but mac 2373 skip("__rmatmul__"), # flaky needs investigation 2374 skip("matmul"), # flaky needs investigation 2375 skip("nn.functional.conv_transpose3d"), # flaky needs investigation 2376 skip("nn.functional.conv_transpose2d"), # flaky needs investigation 2377 skip("nn.functional.conv_transpose1d"), # flaky needs investigation 2378 skip( 2379 "nn.functional.layer_norm", dtypes=(torch.float32,), device_type="cpu" 2380 ), # fails on windows 2381 skip( 2382 "linalg.lu_factor", dtypes=(torch.float32,), device_type="cuda" 2383 ), # fails on all but windows 2384 skip( 2385 "linalg.lu_factor_ex", dtypes=(torch.float32,), device_type="cuda" 2386 ), # fails on all but windows 2387 skip("linalg.multi_dot", "", device_type="cpu"), 2388 skip("sparse.sampled_addmm", ""), 2389 skip("sparse.mm", "reduce"), 2390 skip("native_layer_norm", "", device_type="cpu"), 2391 # RuntimeError: Expected contiguous tensor, but got 2392 # non-contiguous tensor for argument #2 'grad_output' 2393 decorate( 2394 "_batch_norm_with_update", 2395 decorator=expectedFailureIf(TEST_WITH_ROCM), 2396 device_type="cuda", 2397 ), 2398 }, 2399 ) 2400 @opsToleranceOverride( 2401 "TestOperators", 2402 "test_vmap_autograd_grad", 2403 ( 2404 tol1( 2405 "ldexp", 2406 {torch.float32: tol(atol=3e-04, rtol=1.6e-06)}, 2407 device_type="cuda", 2408 ), 2409 tol1( 2410 "linalg.householder_product", 2411 {torch.float32: tol(atol=5e-04, rtol=9e-03)}, 2412 device_type="cuda", 2413 ), 2414 tol1( 2415 "linalg.householder_product", 2416 {torch.float32: tol(atol=6e-03, rtol=1e-03)}, 2417 device_type="cpu", 2418 ), 2419 tol1( 2420 "linalg.multi_dot", 2421 {torch.float32: tol(atol=2e-04, rtol=1e-04)}, 2422 device_type="cuda", 2423 ), 2424 tol2( 2425 "linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-06, rtol=5e-06)} 2426 ), 2427 tol1("nn.functional.conv3d", {torch.float32: tol(atol=5e-04, rtol=9e-03)}), 2428 tol1( 2429 "nn.functional.conv2d", 2430 {torch.float32: tol(atol=3e-05, rtol=5e-06)}, 2431 device_type="cuda", 2432 ), 2433 tol1("svd_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}), 2434 tol1("pca_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}), 2435 ), 2436 ) 2437 def test_vmap_autograd_grad(self, device, dtype, op): 2438 def is_differentiable(inp): 2439 return isinstance(inp, Tensor) and ( 2440 inp.grad_fn is not None or inp.requires_grad 2441 ) 2442 2443 def get_flat_differentiable(tree): 2444 flattened = pytree.tree_leaves(tree) 2445 return tuple(i for i in flattened if is_differentiable(i)) 2446 2447 def get_differentiable_linked(list1, list2): 2448 paired_list = zip(list1, list2) 2449 paired_list = tuple( 2450 (first, second) 2451 for (first, second) in paired_list 2452 if is_differentiable(first) 2453 ) 2454 return zip(*paired_list) 2455 2456 def filter_none(out): 2457 flattened = pytree.tree_leaves(out) 2458 return tuple(o for o in flattened if o is not None) 2459 2460 if not op.supports_autograd: 2461 self.skipTest("Skipped! Autograd not supported.") 2462 return 2463 2464 sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) 2465 2466 for sample_input in sample_inputs: 2467 fn, primals = normalize_op_input_output(op, sample_input) 2468 out = fn(*primals) 2469 cotangents = tree_map(torch.randn_like, out) 2470 2471 def compute_grad(cotangents): 2472 out_flattened = out 2473 cotangents_flattened = cotangents 2474 if not isinstance(out_flattened, torch.Tensor): 2475 out_flattened = pytree.tree_leaves(out) 2476 cotangents_flattened = pytree.tree_leaves(cotangents) 2477 out_flattened, cotangents_flattened = get_differentiable_linked( 2478 out_flattened, cotangents_flattened 2479 ) 2480 2481 return filter_none( 2482 torch.autograd.grad( 2483 out_flattened, 2484 get_flat_differentiable(primals), 2485 cotangents_flattened, 2486 retain_graph=True, 2487 allow_unused=True, 2488 ) 2489 ) 2490 2491 is_batch_norm_and_training = is_batch_norm_training(op, sample_input.kwargs) 2492 generator = get_fallback_and_vmap_exhaustive( 2493 compute_grad, 2494 (cotangents,), 2495 {}, 2496 is_batch_norm_and_training=is_batch_norm_and_training, 2497 ) 2498 for loop_out, batched_out in generator: 2499 self.assertEqual(loop_out, batched_out) 2500 2501 def test_vmapvmapjvp_linalg_solve(self): 2502 ops = [op for op in op_db if op.name == "linalg.solve"] 2503 assert len(ops) > 0 2504 2505 # this specializes a lot of code from the get_fallback_and_vmap_exhaustive test. If we need this more 2506 # generally, this could go for a refactor 2507 2508 B0 = 2 2509 B1 = 3 2510 2511 # we want to check the case where A will be seen as contiguous by jvp but during the vmap calls will become 2512 # non-contiguous because vmap will expand. This will happen during both levels of vmap 2513 A = torch.randn(4, 4) 2514 k = torch.randn(4, 5, B1, B0) 2515 fn, args = get_jvp_variant_primals_tangents( 2516 torch.linalg.solve, SampleInput(A, args=(k,)) 2517 ) 2518 2519 in_dims_all = (None, -1, None, -1) 2520 batched_out = vmap(vmap(fn, in_dims=in_dims_all), in_dims=in_dims_all)(*args) 2521 loop_out = loop2(fn, in_dims_all, in_dims_all, 0, 0, B0, B1, *args) 2522 self.assertEqual(loop_out, batched_out) 2523 2524 @ops( 2525 filter(lambda op: op.name in aliasing_ops, op_db + additional_op_db), 2526 allowed_dtypes=(torch.float,), 2527 ) 2528 @parametrize("grad_op", ["jvp", "vjp"]) 2529 def test_view_then_inplace(self, device, dtype, op, grad_op): 2530 for sample_input in op.sample_inputs(device, dtype): 2531 2532 def f(x): 2533 op(sample_input.input, *sample_input.args, **sample_input.kwargs).copy_( 2534 x 2535 ) 2536 return x 2537 2538 without_grad = op( 2539 sample_input.input, *sample_input.args, **sample_input.kwargs 2540 ) 2541 if grad_op == "jvp": 2542 with self.assertRaisesRegex( 2543 RuntimeError, 2544 "During a grad .* attempted to call in-place operation", 2545 ): 2546 jvp( 2547 f, 2548 (torch.randn_like(without_grad),), 2549 (torch.randn_like(without_grad),), 2550 ) 2551 else: 2552 assert grad_op == "vjp" 2553 with self.assertRaisesRegex( 2554 RuntimeError, 2555 "During a grad .* attempted to call in-place operation", 2556 ): 2557 vjp(f, torch.randn_like(without_grad)) 2558 2559 @ops( 2560 filter( 2561 lambda op: op.name in aliasing_ops_list_return, op_db + additional_op_db 2562 ), 2563 allowed_dtypes=(torch.float,), 2564 ) 2565 @parametrize("grad_op", ["jvp", "vjp"]) 2566 def test_view_then_inplace_list_return(self, device, dtype, op, grad_op): 2567 for sample_input in op.sample_inputs(device, dtype): 2568 2569 def f(x): 2570 op(sample_input.input, *sample_input.args, **sample_input.kwargs)[ 2571 0 2572 ].copy_(x) 2573 return x 2574 2575 without_grad = op( 2576 sample_input.input, *sample_input.args, **sample_input.kwargs 2577 )[0] 2578 with self.assertRaisesRegex( 2579 RuntimeError, "During a grad .* attempted to call in-place operation" 2580 ): 2581 if grad_op == "jvp": 2582 jvp( 2583 f, 2584 (torch.randn_like(without_grad),), 2585 (torch.randn_like(without_grad),), 2586 ) 2587 else: 2588 assert grad_op == "vjp" 2589 vjp(f, torch.randn_like(without_grad)) 2590 2591 @parametrize("grad_op", ["jvp", "vjp"]) 2592 def test_view_then_inplace_special(self, grad_op): 2593 # some things in __getitem__ use at::index, which doesn't alias, so this tests a subset of them that do alias 2594 ops = [ 2595 lambda x: x[0], 2596 lambda x: x[0, 0, 0], 2597 lambda x: x[:1], 2598 lambda x: x[:, :1], 2599 lambda x: x[:, :1, :], 2600 ] 2601 2602 for op in ops: 2603 2604 def f(x): 2605 op(captured).copy_(x) 2606 return x 2607 2608 captured = torch.randn(4, 3, 3) 2609 without_grad = op(captured) 2610 if grad_op == "jvp": 2611 with self.assertRaisesRegex( 2612 RuntimeError, 2613 "During a grad .* attempted to call in-place operation", 2614 ): 2615 jvp( 2616 f, 2617 (torch.randn_like(without_grad),), 2618 (torch.randn_like(without_grad),), 2619 ) 2620 else: 2621 assert grad_op == "vjp" 2622 with self.assertRaisesRegex( 2623 RuntimeError, 2624 "During a grad .* attempted to call in-place operation", 2625 ): 2626 vjp(f, torch.randn_like(without_grad)) 2627 2628 @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 2629 # NOTE: [three-transform testing] 2630 # We only test the autograd_function_db tests here. 2631 # 2632 # Usually testing the composition of two transforms is sufficient to convince 2633 # ourselves that an operator is correctly implemented. For the following cases, 2634 # we want to be extra sure, so we send those through some three-transform tests: 2635 # - autograd.Function. The mechanism is via PyDispatcher/HigherOrderOperator, not the 2636 # regular PyTorch dispatcher, so it's good to exercise more caution. 2637 @ops(autograd_function_db, allowed_dtypes=(torch.float32,)) 2638 @skipOps( 2639 "TestOperators", 2640 "test_vmapvjpvmap", 2641 { 2642 xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable 2643 }, 2644 ) 2645 def test_vmapvjpvmap(self, device, dtype, op): 2646 samples = op.sample_inputs(device, dtype, requires_grad=True) 2647 B = 2 2648 for sample in samples: 2649 args = [sample.input] + list(sample.args) 2650 kwargs = sample.kwargs 2651 generator = generate_vmap_inputs(args, kwargs, batch_size=B) 2652 for batched_args, in_dims, kwargs in generator: 2653 inner_vmapped_op = vmap(op, in_dims) 2654 inner_mapped_op = functools.partial(loop, op, in_dims, 0, B) 2655 2656 inner_vmapped_fn, primals = normalize_op_input_output2( 2657 inner_vmapped_op, 2658 batched_args, 2659 kwargs, 2660 sample.output_process_fn_grad, 2661 ) 2662 inner_mapped_fn, _ = normalize_op_input_output2( 2663 inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad 2664 ) 2665 result = inner_mapped_fn(*primals) 2666 cotangents = tree_map(lambda x: torch.rand_like(x), result) 2667 2668 def apply_vjp(fn): 2669 def inner(primals, cotangents): 2670 _, vjp_fn = vjp(fn, *primals) 2671 return vjp_fn(cotangents) 2672 2673 return inner 2674 2675 vjpvmap_fn = apply_vjp(inner_vmapped_fn) 2676 vjpmap_fn = apply_vjp(inner_mapped_fn) 2677 batched_args = (primals, cotangents) 2678 generator = generate_vmap_inputs(batched_args, {}) 2679 2680 for batched_args, in_dims, _ in generator: 2681 # strategy: compare vmap(vjp(vmap(op)) vs map(vjp(map(op)) 2682 vmapvjpvmap_fn = vmap(vjpvmap_fn, in_dims) 2683 mapvjpmap_fn = functools.partial(loop, vjpmap_fn, in_dims, 0, B) 2684 2685 result = vmapvjpvmap_fn(*batched_args) 2686 expected = mapvjpmap_fn(*batched_args) 2687 self.assertEqual(result, expected) 2688 2689 # See NOTE: [three-transform testing] 2690 @ops(autograd_function_db, allowed_dtypes=(torch.float32,)) 2691 @skipOps( 2692 "TestOperators", 2693 "test_vjpvmapvmap", 2694 { 2695 xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable 2696 }, 2697 ) 2698 def test_vjpvmapvmap(self, device, dtype, op): 2699 samples = op.sample_inputs(device, dtype, requires_grad=True) 2700 B = 2 2701 for sample in samples: 2702 args = [sample.input] + list(sample.args) 2703 kwargs = sample.kwargs 2704 generator = generate_vmap_inputs(args, kwargs, batch_size=B) 2705 for batched_args, inner_in_dims, kwargs in generator: 2706 inner_vmapped_op = vmap(op, inner_in_dims) 2707 inner_mapped_op = functools.partial(loop, op, inner_in_dims, 0, B) 2708 generator = generate_vmap_inputs(batched_args, kwargs) 2709 for batched_args, in_dims, kwargs in generator: 2710 # strategy: compare vjp(vmap(vmap(op)) vs vjp(map(map(op)) 2711 vmapped_op = vmap(inner_vmapped_op, in_dims) 2712 mapped_op = functools.partial(loop, inner_mapped_op, in_dims, 0, B) 2713 2714 vmapped_fn, primals = normalize_op_input_output2( 2715 vmapped_op, batched_args, kwargs, sample.output_process_fn_grad 2716 ) 2717 mapped_fn, _ = normalize_op_input_output2( 2718 mapped_op, batched_args, kwargs, sample.output_process_fn_grad 2719 ) 2720 2721 result = mapped_fn(*primals) 2722 cotangents = tree_map(lambda x: torch.rand_like(x), result) 2723 2724 _, vjp_fn = vjp(mapped_fn, *primals) 2725 expected_vjps = vjp_fn(cotangents) 2726 2727 _, vjp_fn = vjp(vmapped_fn, *primals) 2728 result_vjps = vjp_fn(cotangents) 2729 2730 self.assertEqual(result_vjps, expected_vjps) 2731 2732 # See NOTE: [three-transform testing] 2733 @ops(autograd_function_db, allowed_dtypes=(torch.float32,)) 2734 @skipOps( 2735 "TestOperators", 2736 "test_vjpvjpvmap", 2737 { 2738 xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable 2739 }, 2740 ) 2741 def test_vjpvjpvmap(self, device, dtype, op): 2742 samples = op.sample_inputs(device, dtype, requires_grad=True) 2743 B = 2 2744 for sample in samples: 2745 args = [sample.input] + list(sample.args) 2746 kwargs = sample.kwargs 2747 generator = generate_vmap_inputs(args, kwargs, batch_size=B) 2748 for batched_args, in_dims, kwargs in generator: 2749 inner_vmapped_op = vmap(op, in_dims) 2750 inner_mapped_op = functools.partial(loop, op, in_dims, 0, B) 2751 2752 vjpmap_fn, args = get_vjpfull_variant2( 2753 inner_mapped_op, batched_args, kwargs 2754 ) 2755 vjpvmap_fn, _ = get_vjpfull_variant2( 2756 inner_vmapped_op, batched_args, kwargs 2757 ) 2758 2759 vjpvjpvmap_fn, new_args = get_vjpfull_variant2(vjpvmap_fn, args, {}) 2760 vjpvjpmap_fn, _ = get_vjpfull_variant2(vjpmap_fn, args, {}) 2761 2762 expected = vjpvjpmap_fn(*new_args) 2763 result = vjpvjpvmap_fn(*new_args) 2764 self.assertEqual(result, expected) 2765 2766 # We're generally convinced that jvp x vmap works (vmap turns an operator 2767 # into another operator and we test jvp support for operators). So 2768 # we only test it on the things we're not sure about: 2769 # - the autograd.Function <> functorch interaction 2770 @ops(autograd_function_db, allowed_dtypes=(torch.float32,)) 2771 @skipOps( 2772 "TestOperators", 2773 "test_jvpvmap", 2774 { 2775 xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable 2776 }, 2777 ) 2778 def test_jvpvmap(self, device, dtype, op): 2779 samples = op.sample_inputs(device, dtype, requires_grad=True) 2780 B = 2 2781 for sample in samples: 2782 args = [sample.input] + list(sample.args) 2783 kwargs = sample.kwargs 2784 generator = generate_vmap_inputs(args, kwargs, batch_size=B) 2785 for batched_args, in_dims, kwargs in generator: 2786 inner_vmapped_op = vmap(op, in_dims) 2787 inner_mapped_op = functools.partial(loop, op, in_dims, 0, B) 2788 2789 jvpvmap_op, primals = get_jvp_variant_primals_tangents2( 2790 inner_vmapped_op, 2791 batched_args, 2792 kwargs, 2793 sample.output_process_fn_grad, 2794 ) 2795 jvpmap_op, _ = get_jvp_variant_primals_tangents2( 2796 inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad 2797 ) 2798 2799 expected = jvpmap_op(*primals) 2800 result = jvpvmap_op(*primals) 2801 self.assertEqual(result, expected) 2802 2803 # See NOTE: [three-transform testing] 2804 @ops(autograd_function_db, allowed_dtypes=(torch.float32,)) 2805 @skipOps( 2806 "TestOperators", 2807 "test_jvpvmapvmap", 2808 { 2809 xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable 2810 }, 2811 ) 2812 def test_jvpvmapvmap(self, device, dtype, op): 2813 samples = op.sample_inputs(device, dtype, requires_grad=True) 2814 B = 2 2815 for sample in samples: 2816 args = [sample.input] + list(sample.args) 2817 kwargs = sample.kwargs 2818 generator = generate_vmap_inputs(args, kwargs, batch_size=B) 2819 for batched_args, inner_in_dims, kwargs in generator: 2820 inner_vmapped_op = vmap(op, inner_in_dims) 2821 inner_mapped_op = functools.partial(loop, op, inner_in_dims, 0, B) 2822 generator = generate_vmap_inputs(batched_args, kwargs) 2823 for batched_args, in_dims, kwargs in generator: 2824 # strategy: compare jvp(vmap(vmap(op)) vs jvp(map(map(op)) 2825 vmapped_op = vmap(inner_vmapped_op, in_dims) 2826 mapped_op = functools.partial(loop, inner_mapped_op, in_dims, 0, B) 2827 2828 jvpvmapvmap_fn, primals = get_jvp_variant_primals_tangents2( 2829 vmapped_op, batched_args, kwargs, sample.output_process_fn_grad 2830 ) 2831 jvpmapmap_fn, _ = get_jvp_variant_primals_tangents2( 2832 mapped_op, batched_args, kwargs, sample.output_process_fn_grad 2833 ) 2834 2835 expected = jvpmapmap_fn(*primals) 2836 result = jvpvmapvmap_fn(*primals) 2837 self.assertEqual(result, expected) 2838 2839 # See NOTE: [three-transform testing] 2840 @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 2841 @ops(autograd_function_db, allowed_dtypes=(torch.float32,)) 2842 @skipOps( 2843 "TestOperators", 2844 "test_vmapjvpvmap", 2845 { 2846 xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable 2847 }, 2848 ) 2849 def test_vmapjvpvmap(self, device, dtype, op): 2850 samples = op.sample_inputs(device, dtype, requires_grad=True) 2851 B = 2 2852 for sample in samples: 2853 args = [sample.input] + list(sample.args) 2854 kwargs = sample.kwargs 2855 generator = generate_vmap_inputs(args, kwargs, batch_size=B) 2856 for batched_args, in_dims, kwargs in generator: 2857 inner_vmapped_op = vmap(op, in_dims) 2858 inner_mapped_op = functools.partial(loop, op, in_dims, 0, B) 2859 2860 jvpvmap_fn, primals = get_jvp_variant_primals_tangents2( 2861 inner_vmapped_op, 2862 batched_args, 2863 kwargs, 2864 sample.output_process_fn_grad, 2865 ) 2866 jvpmap_fn, _ = get_jvp_variant_primals_tangents2( 2867 inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad 2868 ) 2869 2870 generator = generate_vmap_inputs(primals, {}) 2871 2872 for batched_args, in_dims, _ in generator: 2873 # strategy: compare vmap(jvp(vmap(op)) vs map(jvp(map(op)) 2874 vmapjvpvmap_fn = vmap(jvpvmap_fn, in_dims) 2875 mapjvpmap_fn = functools.partial(loop, jvpmap_fn, in_dims, 0, B) 2876 2877 result = vmapjvpvmap_fn(*batched_args) 2878 expected = mapjvpmap_fn(*batched_args) 2879 self.assertEqual(result, expected) 2880 2881 # See NOTE: [three-transform testing] 2882 @ops(autograd_function_db, allowed_dtypes=(torch.float32,)) 2883 @skipOps( 2884 "TestOperators", 2885 "test_jvpjvpvmap", 2886 { 2887 xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable 2888 }, 2889 ) 2890 def test_jvpjvpvmap(self, device, dtype, op): 2891 samples = op.sample_inputs(device, dtype, requires_grad=True) 2892 B = 2 2893 for sample in samples: 2894 args = [sample.input] + list(sample.args) 2895 kwargs = sample.kwargs 2896 generator = generate_vmap_inputs(args, kwargs, batch_size=B) 2897 for batched_args, in_dims, kwargs in generator: 2898 inner_vmapped_op = vmap(op, in_dims) 2899 inner_mapped_op = functools.partial(loop, op, in_dims, 0, B) 2900 2901 jvpmap_fn, args = get_jvp_variant_primals_tangents2( 2902 inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad 2903 ) 2904 jvpvmap_fn, _ = get_jvp_variant_primals_tangents2( 2905 inner_vmapped_op, 2906 batched_args, 2907 kwargs, 2908 sample.output_process_fn_grad, 2909 ) 2910 2911 jvpjvpvmap_fn, new_args = get_jvp_variant_primals_tangents2( 2912 jvpvmap_fn, args, {} 2913 ) 2914 jvpjvpmap_fn, _ = get_jvp_variant_primals_tangents2(jvpmap_fn, args, {}) 2915 2916 expected = jvpjvpmap_fn(*new_args) 2917 result = jvpjvpvmap_fn(*new_args) 2918 self.assertEqual(result, expected) 2919 2920 # See NOTE: [three-transform testing] 2921 @ops(autograd_function_db, allowed_dtypes=(torch.float32,)) 2922 @skipOps( 2923 "TestOperators", 2924 "test_jvpvjpvmap", 2925 { 2926 xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable 2927 }, 2928 ) 2929 def test_jvpvjpvmap(self, device, dtype, op): 2930 samples = op.sample_inputs(device, dtype, requires_grad=True) 2931 B = 2 2932 for sample in samples: 2933 args = [sample.input] + list(sample.args) 2934 kwargs = sample.kwargs 2935 generator = generate_vmap_inputs(args, kwargs, batch_size=B) 2936 for batched_args, in_dims, kwargs in generator: 2937 inner_vmapped_op = vmap(op, in_dims) 2938 inner_mapped_op = functools.partial(loop, op, in_dims, 0, B) 2939 2940 vjpmap_fn, args = get_vjpfull_variant2( 2941 inner_mapped_op, batched_args, kwargs 2942 ) 2943 vjpvmap_fn, _ = get_vjpfull_variant2( 2944 inner_vmapped_op, batched_args, kwargs 2945 ) 2946 2947 jvpvjpvmap_fn, new_args = get_jvp_variant_primals_tangents2( 2948 vjpvmap_fn, args, {} 2949 ) 2950 jvpvjpmap_fn, _ = get_jvp_variant_primals_tangents2(vjpmap_fn, args, {}) 2951 2952 expected = jvpvjpmap_fn(*new_args) 2953 result = jvpvjpvmap_fn(*new_args) 2954 self.assertEqual(result, expected) 2955 2956 def test_data_write_errors_under_transform(self, device): 2957 t = torch.randn(3, 3, device=device) 2958 2959 def fn(t): 2960 t.data = torch.randn(3, 3) 2961 return t.sum() 2962 2963 msg = "mutating directly with `.data` inside functorch transform" 2964 with self.assertRaisesRegex(RuntimeError, msg): 2965 grad(fn)(t) 2966 2967 with self.assertRaisesRegex(RuntimeError, msg): 2968 vjp(fn, t) 2969 2970 with self.assertRaisesRegex(RuntimeError, msg): 2971 jvp(fn, (t,), (torch.randn_like(t),)) 2972 2973 def test_tensor_with_scalar_list(self, device): 2974 x = torch.randn((), device=device) 2975 2976 def func_list_of_scalar(x): 2977 return torch.tensor([x], device=device) 2978 2979 def func(x): 2980 return torch.tensor(x, device=device).view(1) 2981 2982 actual_o, actual_fn = vjp(func_list_of_scalar, x) 2983 expected_o, expected_fn = vjp(func, x) 2984 2985 self.assertEqual(actual_o, expected_o) 2986 self.assertEqual( 2987 expected_fn(torch.ones_like(expected_o)), 2988 actual_fn(torch.ones_like(actual_o)), 2989 ) 2990 2991 2992only_for = ("cpu", "cuda") 2993instantiate_device_type_tests(TestOperators, globals(), only_for=only_for) 2994 2995if __name__ == "__main__": 2996 run_tests() 2997