# Owner(s): ["module: functorch"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import functools
import itertools
import os
import random
import types
import unittest
import warnings
from collections import namedtuple
from typing import OrderedDict
from unittest.case import skipIf

from common_utils import (
    check_vmap_fallback,
    compute_quantities_for_vmap_test,
    decorate,
    DisableVmapFallback,
    generate_vmap_inputs,
    get_fallback_and_vmap_exhaustive,
    is_batch_norm_training,
    is_valid_inplace_sample_input,
    opsToleranceOverride,
    skip,
    skipOps,
    tol1,
    xfail,
)
from functorch_additional_op_db import additional_op_db

import functorch
import torch
import torch.nn.functional as F
from functorch import grad, grad_and_value, jacfwd, jvp, vjp, vmap
from functorch.experimental import chunk_vmap
from torch import Tensor
from torch._C._functorch import reshape_dim_into, reshape_dim_outof
from torch._functorch.make_functional import functional_init_with_buffers
from torch._functorch.vmap import restore_vmap
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.testing._internal.autograd_function_db import autograd_function_db
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_CUDNN_ATTENTION,
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
    with_tf32_off,
)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
    OpDTypes,
    ops,
    tol,
    toleranceOverride,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_WINDOWS,
    markDynamoStrictTest,
    parametrize,
    run_tests,
    skipIfTorchDynamo,
    subtest,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
    unMarkDynamoStrictTest,
    xfailIfTorchDynamo,
)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.utils import _pytree as pytree


def get_platform_specific_sdpa():
    ret = [SDPBackend.MATH]
    if PLATFORM_SUPPORTS_FLASH_ATTENTION:
        ret.append(SDPBackend.FLASH_ATTENTION)
    if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
        ret.append(SDPBackend.EFFICIENT_ATTENTION)
    if PLATFORM_SUPPORTS_CUDNN_ATTENTION:
        ret.append(SDPBackend.CUDNN_ATTENTION)
    return ret


PLATFORM_SPECIFIC_SDPA = get_platform_specific_sdpa()

FALLBACK_REGEX = "There is a performance drop"


class EnableVmapFallbackWarnings:
    def __enter__(self):
        self.prev_state = torch._C._debug_only_are_vmap_fallback_warnings_enabled()
        torch._C._debug_only_display_vmap_fallback_warnings(True)

    def __exit__(self, *ignored):
        torch._C._debug_only_display_vmap_fallback_warnings(self.prev_state)


@markDynamoStrictTest
class TestVmapAPI(TestCase):
    def test_non_tensor_output_raises(self):
        with self.assertRaisesRegex(ValueError, "got type <class 'float'>"):
            vmap(lambda x: 3.14)(torch.ones(3))

        def multiple_outputs(x):
            return x, 3

        with self.assertRaisesRegex(ValueError, "got type <class 'int'>"):
            vmap(multiple_outputs)(torch.ones(3))

    def test_different_map_dim_size_raises(self):
        x = torch.randn(2)
        y = torch.randn(3)
        expected_msg = (
            "Expected all tensors to have the same size in the mapped dimension"
        )
        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(torch.mul)(x, y)
        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y))
        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(lambda z: z["x"] + z["y"], in_dims=({"x": 0, "y": 0},))(
                {"x": x, "y": y}
            )

    def test_func_with_no_inputs(self):
        expected_msg = "got no inputs"

        def foo():
            return torch.randn(3)

        def bar(x):
            return torch.randn(3)

        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(foo)()

        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(bar)()

    def test_func_with_no_tensors(self):
        def foo(x):
            return torch.randn(3)

        with self.assertRaisesRegex(ValueError, "at least one Tensor"):
            vmap(foo, (None,))(1)

    def test_constant_function(self):
        output = vmap(lambda x: torch.tensor(3.14))(torch.ones(3))
        self.assertEqual(output, torch.tensor([3.14, 3.14, 3.14]))

    def test_single_input(self):
        x = torch.randn(2, 3)

        def square(x):
            return x * x

        output = vmap(square)(x)
        self.assertEqual(output, x * x)

    def test_multiple_inputs(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        output = vmap(torch.mul)(x, y)
        self.assertEqual(output, x * y)

    def test_multiple_outputs(self):
        def foo(x):
            return x * x, x * x * x

        x = torch.randn(3)
        outputs = vmap(foo)(x)
        self.assertEqual(outputs[0], x * x)
        self.assertEqual(outputs[1], x * x * x)

    def test_multiple_outputs2(self):
        # This is the same thing as
        # def returns_tuple_of_tensors(x):
        #     return x, x
        def returns_tuple_of_tensors(x):
            return (x, x)

        def returns_list_of_two_tensors(x):
            return [x, x]

        def returns_list_of_one_tensor(x):
            return [x]

        x = torch.randn(3)

        # should not throw
        vmap(returns_tuple_of_tensors)(x)
        vmap(returns_list_of_two_tensors)(x)
        vmap(returns_list_of_one_tensor)(x)

    def test_nested_with_same_map_dim(self):
        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        output = vmap(vmap(torch.mul))(x, y)
        self.assertEqual(output, x * y)

        output = vmap(vmap(vmap(torch.mul)))(x, y)
        self.assertEqual(output, x * y)

    def test_nested_with_diag_embed(self):
        # diag_embed requires special testing because it is registered with conditional functionalization.
        x = torch.randn(3, 3, 5)
        output = vmap(vmap(torch.diag_embed))(x)
        self.assertEqual(output, torch.diag_embed(x))

    def test_nested_with_different_map_dim(self):
        x = torch.randn(2, 3)
        y = torch.randn(5, 3)
        output = vmap(lambda x: vmap(lambda y: x * y)(y))(x)
        self.assertEqual(output.shape, (2, 5, 3))
        self.assertEqual(output, x.view(2, 1, 3) * y)

        z = torch.randn(7, 3)
        output = vmap(lambda x: vmap(lambda y: vmap(lambda z: x * y * z)(z))(y))(x)
        self.assertEqual(output.shape, (2, 5, 7, 3))
        self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z)

    def test_noop_in_inner_vmap(self):
        x = torch.randn(3)
        y = torch.randn(5)
        output = vmap(lambda x: vmap(lambda y: x)(y))(x)
        self.assertEqual(output, x.view(3, 1).expand(3, 5))

    def test_checkpoint(self):
        A = torch.randn((3, 8, 8), dtype=torch.float64, requires_grad=True)

        def get_grad(checkpoint):
            A.grad = None

            def get_loss(A):
                ortho_A, _ = torch.func.vmap(torch.linalg.qr)(A)
                return torch.sum(ortho_A)

            if checkpoint:
                loss = torch.utils.checkpoint.checkpoint(
                    get_loss, A, use_reentrant=False
                )
            else:
                loss = get_loss(A)
            loss.backward()
            return A.grad

        expected = get_grad(checkpoint=False)
        result = get_grad(checkpoint=True)
        self.assertEqual(result, expected)

    def test_unsupported_op_err_msg(self):
        # Unsupported view op
        tensor = torch.randn(2, 3)
        msg = (
            r"Batching rule not implemented for aten::.+; the "
            r"fallback path doesn't work on out= or view ops"
        )
        # TODO: find a view op
        # with self.assertRaisesRegex(RuntimeError, msg):
        #     vmap(torch.ravel)(tensor)

        def out_op(x, y):
            return torch.abs(x, out=y)

        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(out_op)(tensor, tensor)

        # Don't support non-tensor returns. This is a limitation of vmap;
        # functions that don't return tensors must be special cased
        with self.assertRaisesRegex(RuntimeError, "Batching rule not implemented"):
            vmap(torch.equal)(tensor, tensor)

    def test_nonzero_out_dims(self):
        # Basic test
        tensor = torch.randn(2, 3)
        result = vmap(lambda x: x, out_dims=1)(tensor)
        self.assertEqual(result, tensor.permute(1, 0))
        self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # Test that the batch dimension gets permuted to dim 2
        tensor = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x: x, out_dims=2)(tensor)
        self.assertEqual(result, tensor.permute(1, 2, 0, 3))
        self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # negative out_dim
        tensor = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x: x, out_dims=-1)(tensor)
        self.assertEqual(result, tensor.permute(1, 2, 3, 0))
        self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # check that out_dims works on ALL outputs
        tensor = torch.randn(2, 3, 5, 7)
        other = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x, y: (x, y), out_dims=2)(tensor, other)
        self.assertEqual(
            result, (tensor.permute(1, 2, 0, 3), other.permute(1, 2, 0, 3))
        )

        # use out_dims with the maximum vmap-able tensor dims (64 dims)
        ndims = 64
        shape = [2] + [1] * (ndims - 1)
        expected_shape = [1, 1, 2] + [1] * (ndims - 3)
        tensor = torch.randn(shape)
        result = vmap(lambda x: x, out_dims=2)(tensor)
        self.assertEqual(result.shape, expected_shape)

        # test something that is not the identity function
        def foo(x, y):
            return x, x * y, x * y * y

        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        result = vmap(foo, out_dims=1)(x, y)
        self.assertEqual(
            result,
            (
                x.permute(1, 0, 2),
                (x * y).permute(1, 0, 2),
                (x * y * y).permute(1, 0, 2),
            ),
        )

    def test_multiple_out_dims(self):
        def foo(x):
            return x, x

        def bar(x, y):
            return x, x, x, x * y

        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        result = vmap(foo, out_dims=(0, 1))(x)
        self.assertEqual(result, (x, x.permute(1, 0, 2)))

        result = vmap(bar, out_dims=(-1, 0, 1, 2))(x, y)
        expected = (
            x.permute(1, 2, 0),
            x,
            x.permute(1, 0, 2),
            (x * y).permute(1, 2, 0),
        )
        self.assertEqual(result, expected)

    def test_nested_out_dims(self):
        y = torch.randn(2, 3, 5, 7)

        # Inner vmap has non-zero out_dim
        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y)
        self.assertEqual(result.shape, (2, 5, 3, 7))
        self.assertEqual(result, y.permute(0, 2, 1, 3))

        # all vmaps have non-zero out_dim
        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y), out_dims=1)(y)
        self.assertEqual(result.shape, (5, 2, 3, 7))
        self.assertEqual(result, y.permute(2, 0, 1, 3))

        # throwing in some negative out_dims
        result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y), out_dims=-1)(y)
        self.assertEqual(result.shape, (5, 7, 3, 2))
        self.assertEqual(result, y.permute(2, 3, 1, 0))

        # testing fn that isn't the identity
        x = torch.randn(2, 3)
        y = torch.randn(5, 3)
        result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x), out_dims=-1)(y)
        self.assertEqual(result.shape, (3, 2, 5))
        self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0))

    def test_out_dims_edge_case(self):
        def foo(x):
            return x

        # Test that we accept out_dims=(1,) for a function with one output.
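        # (A singleton tuple out_dims should behave exactly like the plain int 1,
        # which is what the comparison below verifies.)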
        tensor = torch.randn(2, 3)
        expected = vmap(foo, out_dims=1)(tensor)
        result = vmap(foo, out_dims=(1,))(tensor)
        self.assertEqual(result, expected)

    def test_out_dims_none_tuple(self):
        def foo(x):
            return x, "hello world"

        tensor = torch.randn(2, 3)
        result = vmap(foo, out_dims=(0, None))(tensor)
        self.assertEqual(result[1], "hello world")
        self.assertEqual(result[0], tensor)

        def foo(x):
            x.add_(1)
            return None, "hello world"

        result = vmap(foo, out_dims=(None, None))(tensor)
        self.assertEqual(result, (None, "hello world"))

    def test_out_dims_none(self):
        def foo(x):
            return x

        tensor = torch.randn(2, 3)
        with self.assertRaisesRegex(
            ValueError, "can not return a BatchedTensor when out_dim is None"
        ):
            vmap(foo, out_dims=None)(tensor)

        def foo(x):
            x.add_(1)
            return "hello world"

        result = vmap(foo, out_dims=None)(tensor)
        self.assertEqual(result, "hello world")

    def test_out_dims_normal_tensor(self):
        def foo(x):
            return torch.arange(3)

        tensor = torch.randn(2, 3)
        result = vmap(foo)(tensor)
        self.assertEqual(result.shape, [2, 3])

        result = vmap(foo, out_dims=None)(tensor)
        self.assertEqual(result, torch.arange(3))

    def test_pytree_returns(self):
        x = torch.randn(2, 3)

        def f(x):
            y = x.sin()
            return y, (y, y), [y, (y, y)]

        y0, (y1, y2), (y3, (y4, y5)) = vmap(f)(x)
        self.assertEqual(y0, x.sin())
        self.assertEqual(y0, y1)
        self.assertEqual(y2, y1)
        self.assertEqual(y2, y3)
        self.assertEqual(y4, y3)
        self.assertEqual(y5, y4)

    def test_pytree_odict_returns(self):
        x = torch.randn(2, 3)

        def f(t):
            y = t.sin()
            return OrderedDict([("sin", y), ("cos", t.cos())])

        out = vmap(f)(x)
        assert isinstance(out, OrderedDict)
        expected = f(x)
        self.assertEqual(out["sin"], expected["sin"])
        self.assertEqual(out["cos"], expected["cos"])

    def test_pytree_returns_outdims(self):
        x = torch.randn(2, 3)

        def f(x):
            y = x.sin()
            return y, (y, y)

        y0, (y1, y2) = vmap(f, out_dims=(0, (0, 1)))(x)
        self.assertEqual(y0, x.sin())
        self.assertEqual(y1, x.sin())
        self.assertEqual(y2, x.sin().t())

    def test_pytree_returns_broadcast_simple(self):
        x = torch.randn(2, 3)

        def f(x):
            y = x.sin()
            return y, (y, y)

        y0, (y1, y2) = vmap(f, out_dims=1)(x)
        self.assertEqual(y0, x.sin().t())
        self.assertEqual(y1, y0)
        self.assertEqual(y2, y0)

    def test_pytree_returns_broadcast_nested(self):
        x = torch.randn(2, 3)

        def f(x):
            y = x.sin()
            return y, (y, y)

        y0, (y1, y2) = vmap(f, out_dims=(0, 1))(x)
        self.assertEqual(y0, x.sin())
        self.assertEqual(y1, y0.t())
        self.assertEqual(y2, y0.t())

    def test_out_dims_must_be_int_or_collection_of_int_err_msg(self):
        msg = "must be an int, None or a python collection of ints"
        tensor = torch.randn(2, 3)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims="lol")(tensor)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=("lol",))(tensor)

    def test_out_dims_and_num_outputs_mismatch_err_msg(self):
        msg = "not compatible"
        x = torch.randn(2, 3, 5)

        # Too many out_dims
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=(0, 0))(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: (x, x, x), out_dims=(0, 0, 0, 0))(x)
        # Too few out_dims
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: (x, x), out_dims=(0,))(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: (x, x, x), out_dims=(0, 0))(x)

    def test_out_dim_out_of_bounds_err_msg(self):
        # TODO(rzou): This error message isn't that great. It comes straight
        # from maybe_wrap_dim. Consider doing a try-catch-(add some context) to
        # the error message in the future in C++
        msg = "Dimension out of range"
        x = torch.randn(2, 3, 5)
        with self.assertRaisesRegex(IndexError, msg):
            vmap(lambda x: x, out_dims=3)(x)
        with self.assertRaisesRegex(IndexError, msg):
            vmap(lambda x: x, out_dims=-4)(x)

    def test_non_zero_in_dims(self):
        tensor = torch.randn(2, 3, 5)

        # Implicit out_dims = 0; vmap will move the batch dim to the front.
        output = vmap(lambda x: x, (1,))(tensor)
        self.assertEqual(output, tensor.permute(1, 0, 2))
        self.assertEqual(output.data_ptr(), tensor.data_ptr())

        x = torch.randn(2, 3)
        y = torch.randn(3, 2)
        output = vmap(torch.mul, (0, 1))(x, y)
        self.assertEqual(output, x * y.t())
        output = vmap(torch.mul, (1, 0))(x, y)
        self.assertEqual(output, x.t() * y)

    def test_none_in_dims(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        # None in_dim for a Tensor means we don't map over it
        output = vmap(torch.mul, (0, None))(x, y)
        self.assertEqual(output.shape, (2, 2, 3))
        self.assertEqual(output, x.view(2, 1, 3) * y)

        # None in_dim for non-tensor arguments
        output = vmap(torch.mul, (0, None))(x, 2)
        self.assertEqual(output, x * 2)

    def test_nested_non_default_in_dims(self):
        x = torch.rand(5, 2, 3)
        y = torch.rand(3, 5, 2)
        result = vmap(vmap(vmap(torch.mul), (1, 0)), (1, 2))(x, y)
        self.assertEqual(result, x.permute(1, 2, 0) * y.permute(2, 0, 1))

    def test_nested_negative_in_dims(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        output = vmap(torch.mul, (-1, -1))(x, y)
        self.assertEqual(output.shape, (3, 2))
        self.assertEqual(output, (x * y).permute(1, 0))

    def test_non_default_in_dims_out_dims(self):
        x = torch.randn(2, 3, 5)

        # Same in_dim as out_dim, vmap over identity
        result = vmap(lambda x: x, in_dims=1, out_dims=1)(x)
        self.assertEqual(result, x)
        self.assertEqual(result.data_ptr(), x.data_ptr())

        # Different in_dim from out_dim, vmap over identity
        result = vmap(lambda x: x, in_dims=2, out_dims=1)(x)
        self.assertEqual(result.shape, (2, 5, 3))
        self.assertEqual(result, x.transpose(1, 2))
        self.assertEqual(result.data_ptr(), x.data_ptr())

        def foo(x):
            return x * 2

        # Same in_dim as out_dim, vmap over operation
        result = vmap(foo, in_dims=1, out_dims=1)(x)
        self.assertEqual(result, x * 2)

        # Different in_dim from out_dim, vmap over operation
        result = vmap(foo, in_dims=2, out_dims=1)(x)
        self.assertEqual(result.shape, (2, 5, 3))
        self.assertEqual(result, (x * 2).transpose(1, 2))

        # Basic nested test.
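        # Both vmap levels map over dim 1 and write their output back to dim 1,
        # so the composed call should leave the layout of `x * 2` unchanged.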
        result = vmap(vmap(foo, 1, 1), 1, 1)(x)
        self.assertEqual(result, x * 2)

    def test_item_throws(self):
        def f(x):
            return x.item()

        with self.assertRaisesRegex(RuntimeError, r"item\(\) on a Tensor"):
            vmap(f)(torch.randn(3))

    def test_data_dependent_control_flow_throws(self):
        def f(x):
            if x:
                return x
            return 0

        with self.assertRaisesRegex(RuntimeError, r"data-dependent control flow"):
            vmap(f)(torch.randn(3))

    def test_accepts_nested_inputs(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        # Single layer of nesting
        out = vmap(lambda z: z[0] + z[1])((x, y))
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=(0,))((x, y))
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y))
        self.assertEqual(out, x + y)

        out = vmap(lambda z: z[0] + z[1])([x, y])
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=(0,))([x, y])
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, y])
        self.assertEqual(out, x + y)

        out = vmap(lambda z: z["x"] + z["y"])({"x": x, "y": y})
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z["x"] + z["y"], in_dims=(0,))({"x": x, "y": y})
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z["x"] + z["y"], in_dims=({"x": 0, "y": 0},))(
            {"x": x, "y": y}
        )
        self.assertEqual(out, x + y)

        # Multiple layers of nesting
        out_fn = vmap(lambda z: z["x"][0] + z["x"][1][0] + z["y"][0] + z["y"][1])
        out = out_fn({"x": [x, (x,)], "y": [y, y]})
        self.assertEqual(out, x + x + y + y)

    def test_in_dims_wrong_type_err_msg(self):
        x = torch.randn(3)
        y = torch.randn(3)
        msg = r"expected `in_dims` to be int or a \(potentially nested\) tuple"
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, [0, 0])(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, set({0}))(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, "lol")(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=[0, 0])([x, y])
        # The following should not throw
        vmap(torch.mul, (0, 0))(x, y)

    def test_not_enough_in_dims_err_msg(self):
        x = torch.randn(3)
        y = torch.randn(3)
        msg = r"in_dims is not compatible with the structure of `inputs`"

        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, (0,))(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, (0, 0, 0))(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=([0],))([x, y])
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))([x, y])
        # The following should not throw
        vmap(torch.mul, (0, 0))(x, y)

    def test_integer_in_dim_but_not_tensor_input_err_msg(self):
        def foo(xy):
            return xy[0] * xy[1]

        def bar(x, yz):
            return x * yz[0] * yz[1]

        x = torch.randn(2, 3)

        # the following are errors in jax (and will always be errors)
        msg = "Got in_dim=0 for an input but the input is of type"
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.sum)(x, 0)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.sum, (0, 0))(x, 0)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, 1])
        # The following should not throw
        vmap(torch.sum, (0, None))(x, 0)

    def test_in_dim_not_in_tensor_err_msg(self):
        def foo(x):
            return x * x

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        msg = r"Got in_dim=-?\w for some input, but that input is a Tensor of dimensionality \w"
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo)(torch.randn([]))
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo, in_dims=(0,))(torch.randn([]))
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo, in_dims=(-3,))(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo, in_dims=(2,))(y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=([3, 0],))([x, y])
        # the following should not throw
        vmap(foo, in_dims=(0,))(torch.randn(2, 3))
        vmap(foo, in_dims=(1,))(torch.randn(2, 3))

    def test_fallback_does_not_warn_by_default(self):
        op = torch._test_functorch_fallback
        x = torch.randn(11)
        y = torch.randn(11)
        with warnings.catch_warnings(record=True) as wa:
            torch.vmap(op)(x, y)
            # The single warning here is the "vmap is experimental"
            # warning, not a warning from the vmap fallback path.
            self.assertEqual(len(wa), 1)

    @unittest.expectedFailure
    def test_fallback_warns_when_warnings_are_enabled(self):
        # NB: One day we will implement a batching rule for torch.atan2.
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = torch._test_functorch_fallback
        x = torch.randn(11)
        y = torch.randn(11)
        with warnings.catch_warnings(record=True) as wa:
            with EnableVmapFallbackWarnings():
                torch.vmap(op)(x, y)
            self.assertEqual(len(wa), 2)
            self.assertRegex(str(wa[-1].message), FALLBACK_REGEX)

    def _assert_uses_vmap_fallback(self, vmap_args, inputs):
        return
        # with warnings.catch_warnings(record=True) as wa:
        #     with EnableVmapFallbackWarnings():
        #         result = vmap(*vmap_args)(*inputs)
        #     self.assertEqual(len(wa), 2)
        #     self.assertRegex(str(wa[-1].message), FALLBACK_REGEX)

    def test_fallback_zero_dim(self):
        op = torch._test_functorch_fallback
        x = torch.randn(11)
        y = torch.randn(11)
        self._assert_uses_vmap_fallback((op,), (x, y))

        B0, B1 = 0, 3
        x = torch.randn(B0, 11)
        y = torch.randn(11)

        msg = "The fallback path does not support vmap over dims of size 0"

        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, (0, None))(x, y)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, (None, 0))(y, x)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op)(x, x)

        x = torch.randn(B0, B1, 11)
        y = torch.randn(B1, 11)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, (0, None))(x, y)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, (None, 0))(y, x)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op)(x, x)

    def test_fallback_warning(self):
        # We use a dummy function _test_functorch_fallback
        # defined in prim_native_functions.cpp for this
        op = torch._test_functorch_fallback

        x = torch.randn(5, 7, 11)
        y = torch.randn(5, 7, 11)

        self._assert_uses_vmap_fallback((op,), (x, y))

        x = torch.randn(7, 11, 5)
        y = torch.randn(5, 7, 11)
        result = vmap(op, (2, 0))(x, y)
        self.assertEqual(result, op(x.permute(2, 0, 1), y))

        # nested vmap
        x = torch.randn(7, 11, 5)
        y = torch.randn(5, 7, 11)
        result = vmap(vmap(op), (2, 0))(x, y)
        self.assertEqual(result, op(x.permute(2, 0, 1), y))

        # big batch size (total 10000)
        x = torch.randn(100, 10, 10, 5)
        y = torch.randn(100, 10, 10)
        result = vmap(vmap(vmap(op)))(x, y)
        self.assertEqual(result, op(x, y.view(100, 10, 10, 1)))

    # TODO: No clue what is wrong here.
    @unittest.skip
    def test_fallback_masked_fill(self):
        # NB: One day we will implement a batching rule for masked_fill
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        def run_test(batch_size):
            B0 = batch_size
            x = torch.randn(B0, 7, 11, 13)
            dim = 0
            index = torch.tensor([0, 4, 2])
            values = torch.randn(B0, 3, 13)

            self._assert_uses_vmap_fallback(
                (torch.index_add, (0, None, None, 0)), (x, dim, index, values)
            )

            result = vmap(torch.index_add, (0, None, None, 0))(x, dim, index, values)
            expected = torch.index_add(x, dim + 1, index, values.view(B0, 3, 1, 13))
            self.assertEqual(result, expected)

        run_test(batch_size=5)
        run_test(batch_size=1237)

    def test_fallback_multiple_returns(self):
        # NB: One day we will implement a batching rule for torch.var_mean
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        B0, B1, B2 = 2, 3, 1237
        tensor = torch.randn(B0, 10)

        self._assert_uses_vmap_fallback((torch.var_mean,), (tensor,))

        # fallback correctness on torch.var_mean
        result = vmap(torch.var_mean)(tensor)
        expected = torch.var_mean(tensor, dim=1)
        self.assertEqual(result, expected)

        # nested vmap
        tensor = torch.randn(B0, B1, 10)
        result = vmap(vmap(torch.var_mean))(tensor)
        expected = torch.var_mean(tensor, dim=2)
        self.assertEqual(result, expected)

        # big batch size, nested vmap
        tensor = torch.randn(B0, B1, B2, 10)
        result = vmap(vmap(vmap(torch.var_mean)))(tensor)
        expected = torch.var_mean(tensor, dim=3)
        self.assertEqual(result, expected)

    def test_inplace_fallback_unary(self):
        # Test the in-place fallback on an in-place method that takes no
        # additional Tensor arguments. This is the simplest case of the fallback.
        # NB: One day we will implement a batching rule for acos_.
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = Tensor.acos_
        B0, B1, B2 = 2, 3, 10000

        x = torch.randn(B0, 5)
        self._assert_uses_vmap_fallback((op,), (x,))

        # Single vmap
        x_orig = torch.rand(B0, 5)
        x = x_orig.clone()
        result = vmap(op)(x)
        self.assertTrue(result is x)
        self.assertEqual(result, x_orig.acos())

        # Single vmap + different out_dim produces a view(!)
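        # The in-place op still writes into `x`; with out_dims=(1,) the returned
        # tensor is a transposed view whose ._base is `x` (checked below).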
        x_orig = torch.rand(B0, 5)
        x = x_orig.clone()
        result = vmap(op, out_dims=(1,))(x)
        self.assertTrue(result._base is x)
        self.assertEqual(result, x_orig.t().acos())

        # Nested vmap
        x_orig = torch.randn(B0, B1, 5)
        x = x_orig.clone()
        result = vmap(vmap(op))(x)
        self.assertTrue(result is x)
        self.assertEqual(result, x_orig.acos())

        # Nested vmap, large batch size
        x_orig = torch.randn(B0, B1, B2, 5)
        x = x_orig.clone()
        result = vmap(vmap(vmap(op)))(x)
        self.assertTrue(result is x)
        self.assertEqual(result, x_orig.acos())

    def test_inplace_fallback_nary_same_levels(self):
        # NB: One day we will implement a batching rule for atan2_
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = Tensor.atan2_
        outplace_op = torch.atan2

        x = torch.randn(5, 7, 11)
        y = torch.randn(5, 7, 11)
        self._assert_uses_vmap_fallback((op,), (x, y))

        # Single vmap
        B0 = 5
        x_orig = torch.randn(7, 11, B0)
        x = x_orig.clone()
        y = torch.randn(B0, 7, 11)
        vmap(op, (2, 0))(x, y)
        self.assertEqual(x, outplace_op(x_orig, y.movedim(0, 2)))

        # Nested vmap
        B0, B1 = 5, 7
        x_orig = torch.randn(B1, 11, B0)
        x = x_orig.clone()
        y = torch.randn(B0, B1, 11)
        vmap(vmap(op), (2, 0))(x, y)
        self.assertEqual(x, outplace_op(x_orig, y.movedim([0, 1], [2, 0])))

        # big batch size (total 10000)
        B0, B1, B2 = 100, 10, 10
        x_orig = torch.randn(B0, B1, B2, 5)
        x = x_orig.clone()
        y = torch.randn(B0, B1, B2)
        vmap(vmap(vmap(op)))(x, y)
        self.assertEqual(x, outplace_op(x_orig, y.view(B0, B1, B2, 1)))

    # ("Fallback isInplaceVmapCompatible check is broken")
    @unittest.expectedFailure
    def test_inplace_fallback_nary_different_levels(self):
        # NB: One day we will implement a batching rule for atan2_
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = Tensor.atan2_
        outplace_op = torch.atan2
        B0, B1 = 2, 3

        x = torch.rand(B0, 7)
        y = torch.rand(7)
        self._assert_uses_vmap_fallback((op, (0, None)), (x, y))

        # op(left, right): All of the levels in right are found in left
        x_orig = torch.rand(B0, 7)
        x = x_orig.clone()
        y = torch.rand(7)
        vmap(op, in_dims=(0, None))(x, y)
        self.assertEqual(x, outplace_op(x_orig, y))

        x_orig = torch.rand(B0, B1, 7)
        x = x_orig.clone()
        y = torch.rand(B0, 7)
        vmap(vmap(op, in_dims=(0, None)))(x, y)
        self.assertEqual(x, outplace_op(x_orig, y.view(B0, 1, 7)))

        # op(left, right): Some of the levels in right are not found in left
        msg = r"vmap: aten::atan2_\(self, \*extra_args\) is not possible"
        x = torch.rand(7)
        y = torch.rand(B0, 7)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(None, 0))(x, y)

        x = torch.rand(B1, 7)
        y = torch.rand(B0, 7)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 0))(x, y)

        x = torch.rand(B1, 7)
        y = torch.rand(7, B0)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 1))(x, y)

        x = torch.rand(B0, 7)
        y = torch.rand(B0, B1, 7)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(vmap(op, in_dims=(None, 0)))(x, y)

    def test_backward_unsupported_interaction(self):
        x = torch.randn(3, requires_grad=True)
        y = torch.randn(5)
        grad = torch.randn_like(x)
        err_msg = r"backward\(\) called inside a functorch transform"

        def backward_on_vmapped_tensor(x):
            x.sum().backward()

        # FIXME
        return self.skipTest(
            "error: element 0 of tensors does not require grad and does not have a grad_fn"
        )
        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(backward_on_vmapped_tensor)(x)

        def backward_with_vmapped_grad(x, grad):
            x.backward(grad)

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(backward_with_vmapped_grad)(x, grad)

        def completely_unrelated_backward(y):
            x.sum().backward()
            return y

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(completely_unrelated_backward)(y)

    @unittest.expectedFailure
    def test_grad_unsupported_interaction(self):
        input_tensor = torch.randn(3, requires_grad=True)
        err_msg = "autograd.grad.* called inside torch.vmap"

        captured = torch.randn(3, requires_grad=True)

        def output_to_grad_is_vmapped(input_tensor):
            output = (captured * input_tensor).sum()
            return torch.autograd.grad([output], [captured])[0]

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(output_to_grad_is_vmapped)(input_tensor)

        output = (input_tensor**2).sum()

        def input_to_grad_is_vmapped(input_tensor):
            return torch.autograd.grad([output], [input_tensor])[0]

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(input_to_grad_is_vmapped)(input_tensor)

    def test_batched_gradient_basic(self):
        N = 3
        x = torch.randn(N, requires_grad=True)
        y = torch.randn(N)

        def vjp_mul(v):
            return torch.autograd.grad([x * y], [x], grad_outputs=[v])[0]

        batched_v = torch.eye(N)
        jacobian = vmap(vjp_mul)(batched_v)
        self.assertEqual(jacobian, torch.diagflat(y))

    def test_functools_partial(self):
        x = torch.randn(3)
        y = torch.randn(2, 3)
        result = vmap(functools.partial(torch.mul, x))(y)
        self.assertEqual(result, x * y)

    def test_nn_module(self):
        tensor = torch.randn(2, 3)
        model = torch.nn.Linear(3, 3, bias=False)
        result = vmap(model)(tensor)
        self.assertEqual(result, model(tensor))

    def test_fallback_with_undefined_grad(self):
        B0 = 7
        x = torch.randn(2, 3, 4, 5, requires_grad=True)
        weight = torch.randn(3, 3, 1, 1)
        v = torch.randn(B0, 2, 3, 4, 5)

        def get_vjp(v):
            result = torch.nn.functional.conv2d(x, weight)
            (grad_x,) = torch.autograd.grad(result, x, v)
            return grad_x

        # Runs vmap(get_vjp)(v), which should not error out.
        # The backward formula for convolution returns an undefined
        # Tensor for grad_bias because the original bias does not exist.
        #
        # In the future we'll probably add a batching rule for convolution
        # backward. When this happens, we should modify this test to use a
        # different op (and/or create and use a dummy operator) to avoid bitrot.
        self._assert_uses_vmap_fallback([get_vjp], [v])

    def test_reshape_dim_into(self):
        x = torch.randn(2, 3, 5, 7)

        y = reshape_dim_into(0, 0, x)
        self.assertEqual(y, x.reshape(6, 5, 7))

        y = reshape_dim_into(0, 1, x)
        self.assertEqual(y, x.movedim(0, 1).reshape(3, 2 * 5, 7))

        y = reshape_dim_into(0, 2, x)
        self.assertEqual(y, x.movedim(0, 2).reshape(3, 5, 2 * 7))

        y = reshape_dim_into(1, 2, x)
        self.assertEqual(y, x.movedim(1, 2).reshape(2, 5, 3 * 7))

        y = reshape_dim_into(0, -2, x)
        self.assertEqual(y, x.movedim(0, 1).reshape(3, 2 * 5, 7))

        y = reshape_dim_into(0, -1, x)
        self.assertEqual(y, x.movedim(0, 2).reshape(3, 5, 2 * 7))

        y = reshape_dim_into(-4, -1, x)
        self.assertEqual(y, x.movedim(0, 2).reshape(3, 5, 2 * 7))

    def test_reshape_dim_outof(self):
        x = torch.randn(12, 12, 12).permute(2, 1, 0)

        y = reshape_dim_outof(0, 2, x)
        self.assertEqual(y, x.reshape(2, 6, 12, 12))

        y = reshape_dim_outof(1, 4, x)
        self.assertEqual(y, x.reshape(12, 4, 3, 12))

        y = reshape_dim_outof(2, 6, x)
        self.assertEqual(y, x.reshape(12, 12, 6, 2))

        y = reshape_dim_outof(-1, 6, x)
        self.assertEqual(y, x.reshape(12, 12, 6, 2))

        # Case: `0` sized dim.
        x = torch.randn(12, 12, 0)
        y = reshape_dim_outof(-1, 6, x)
        self.assertEqual(y.shape, torch.Size((12, 12, 6, 0)))

    def test_batch_rule_does_not_need_to_handle_no_batched_input(self):
        def f(x, y):
            res = torch.dot(y, torch.ones(2))
            return x + res

        x = torch.randn(7, 5)
        y = torch.randn(3, 2)
        out = vmap(vmap(f, in_dims=(0, None)), in_dims=(None, 0))(x, y)
        expected = torch.mv(y, torch.ones(2)).view(3, 1, 1) + x
        self.assertEqual(out, expected)

    def test_decomposition_under_python_dispatcher(self):
        # This test will raise an error if the vmap fallback gets invoked.
        # Here we test that decomps registered to FuncTorchBatchedDecomposition
        # are respected by the Python Dispatcher.
        t = torch.ones(3, 3) * 5
        with DisableVmapFallback():
            with torch._dispatch.python.enable_python_dispatcher():
                o = torch.vmap(torch.square)(t)
        self.assertEqual(o, torch.square(t))

    def _test_vmap_autocast(self, device):
        if torch.device(device).type == "cpu":
            amp_dtype = torch.bfloat16
        else:
            amp_dtype = torch.float16

        a_float32 = torch.rand(4, 2, 3, device=device)
        b_float32 = torch.rand(4, 3, 2, device=device)
        c_float32 = torch.rand(4, 2, 2, device=device)
        d_float32 = torch.rand(4, 3, 2, device=device)

        # Case 1, autocast inside vmapped function
        def func1(x, y, z, w):
            with torch.autocast(dtype=amp_dtype, device_type=device):
                e_float16 = torch.matmul(x, y)
                assert e_float16.dtype == amp_dtype, e_float16.dtype
                f_float16 = torch.matmul(z, e_float16)
                assert f_float16.dtype == amp_dtype, f_float16.dtype
            return torch.matmul(w, f_float16.float())

        expected = func1(a_float32, b_float32, c_float32, d_float32)
        out = vmap(func1)(a_float32, b_float32, c_float32, d_float32)
        assert expected.allclose(out)

        # Case 2, autocast decorator inside vmapped function
        @torch.autocast(dtype=amp_dtype, device_type=device)
        def func2(x, y, z, w):
            e_float16 = torch.matmul(x, y)
            assert e_float16.dtype == amp_dtype, e_float16.dtype
            f_float16 = torch.matmul(z, e_float16)
            assert f_float16.dtype == amp_dtype, f_float16.dtype
            return torch.matmul(w, f_float16)

        expected = func2(a_float32, b_float32, c_float32, d_float32)
        out = vmap(func2)(a_float32, b_float32, c_float32, d_float32)
        assert expected.allclose(out)

        # Case 3, autocast is outside vmapped function
        def func3(x, y, z, w):
            e_float16 = torch.matmul(x, y)
            assert e_float16.dtype == amp_dtype, e_float16.dtype
            f_float16 = torch.matmul(z, e_float16)
            assert f_float16.dtype == amp_dtype, f_float16.dtype
            return torch.matmul(w, f_float16)

        with torch.autocast(dtype=amp_dtype, device_type=device):
            expected = func3(a_float32, b_float32, c_float32, d_float32)
            out = vmap(func3)(a_float32, b_float32, c_float32, d_float32)

        assert expected.allclose(out)

    @unittest.skip("Somehow, vmap and autocast do not work on CPU")
    def test_vmap_autocast_cpu(self):
        self._test_vmap_autocast("cpu")

    @skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
    def test_vmap_autocast_cuda(self):
        self._test_vmap_autocast("cuda")

    def test_restore_vmap_pytree_input_output(self):
        def f(x, y):
            output0 = x[0] + x[1]
            output1 = y
            return {"a": output0, "b": output1}

        B = 2
        x0 = torch.randn(B, 3)
        x1 = torch.randn(B)
        y = torch.randn(4, B)

        out, out_dims = restore_vmap(f, ((0, 0), 1), B, "error")((x0, x1), y)
        expected = vmap(f, in_dims=((0, 0), 1), out_dims={"a": 0, "b": 1})((x0, x1), y)
        self.assertEqual(out, expected)
        self.assertEqual(out_dims, {"a": 0, "b": 1})

    def test_restore_vmap_no_vmapped_inputs(self):
        def f(x, y, z):
            return x, y * z, z

        B = 2
        # Mix of tensor and non-tensor inputs
        x = torch.randn(3)
        y = torch.randn(4)
        z = 5
        out, out_dims = restore_vmap(f, (None, None, None), B, "error")(x, y, z)
        self.assertEqual(out, f(x, y, z))
        self.assertEqual(out_dims, (None, None, None))

    def test_restore_vmap_unexpanded_outputs(self):
        def f(x, y):
            # Mix of tensor and non-tensor outputs
            return 3 * y, y.sum(), None

        B = 2
        x = torch.randn(B, 3)
        y = torch.randn(4)
        out, out_dims = restore_vmap(f, (0, None), B, "error")(x, y)
        self.assertEqual(out, f(None, y))
        self.assertEqual(out_dims, (None, None, None))

    def test_data_attribute(self):
        def foo(x):
            y = x.data
            return x

        with self.assertRaisesRegex(
            RuntimeError, "accessing `data` under vmap transform"
        ):
            torch.func.vmap(foo)(torch.randn(3, 3))

        def foo(x):
            x.data = torch.ones(3, 3)
            return x

        with self.assertRaisesRegex(
            RuntimeError, "mutating directly with `.data` under vmap"
        ):
            torch.func.vmap(foo)(torch.randn(3, 3))


def slice_inputs(inputs, bdims, i):
    result = []
    for inp, bdim in zip(inputs, bdims):
        if bdim is None:
            result.append(inp)
        else:
            result.append(inp.select(bdim, i))
    return tuple(result)


def reference_vmap(op, inputs, in_dims=0, out_dims=0, return_nt=False):
    if isinstance(in_dims, int):
        in_dims = (in_dims,) * len(inputs)
    bdim_sizes = [inp.size(dim) for inp, dim in zip(inputs, in_dims) if dim is not None]
    assert all(bdim_size == bdim_sizes[0] for bdim_size in bdim_sizes)
    bdim_size = bdim_sizes[0]
    results = tuple(op(*slice_inputs(inputs, in_dims, i)) for i in range(bdim_size))

    assert len(results) > 0
    op_has_single_return = not isinstance(results[0], tuple)
    if op_has_single_return:
        assert all(isinstance(result, torch.Tensor) for result in results)
        if isinstance(out_dims, int):
            out_dims = (out_dims,) * 1
        if return_nt:
            return torch.nested.nested_tensor(list(results))
        else:
            return torch.stack(results, dim=out_dims[0])

    assert all(isinstance(result, tuple) for result in results)
    num_returns = len(results[0])
    assert all(len(result) == num_returns for result in results)
    if isinstance(out_dims, int):
        out_dims = (out_dims,) * num_returns
    if return_nt:
        return tuple(
            torch.nested.nested_tensor(list(result_shards))
            for result_shards in zip(*results)
        )
    else:
        return tuple(
            torch.stack(result_shards, out_dim)
            for result_shards, out_dim in zip(zip(*results), out_dims)
        )


class TensorFactory:
    @staticmethod
    def rand(size, device="cpu", dtype=torch.float):
        return torch.rand(size, device=device, dtype=dtype)

    @staticmethod
    def randn(size, device="cpu", dtype=torch.float):
        return torch.randn(size, device=device, dtype=dtype)

    @staticmethod
    def randp1(size, device="cpu", dtype=torch.float):
        return torch.rand(size, device=device, dtype=dtype) + 1


# Tests vmap(op, in_dims, out_dims)(*inputs) by comparing the output to a
# (slow) sequential map+stack fallback.
#
# check_view: Test if the first returned output is a view of the first input
# check_propagates_grad: Test if the operation propagates gradients.
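# in_dims / out_dims are forwarded unchanged to both the real vmap call and the
# reference_vmap implementation above, so the two results are directly comparable.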
def _vmap_test(
    self,
    op,
    inputs,
    in_dims=0,
    out_dims=0,
    check_view=False,
    check_propagates_grad=True,
):
    result = vmap(op, in_dims, out_dims)(*inputs)
    are_nested = [t.is_nested for t in pytree.tree_leaves(result)]
    reference_result = reference_vmap(
        op, inputs, in_dims, out_dims, return_nt=any(are_nested)
    )
    self.assertEqual(result, reference_result)
    op_has_single_return = not isinstance(result, tuple)

    if check_view:
        result_as_tuple = (result,) if op_has_single_return else result
        for output in result_as_tuple:
            input0_base = inputs[0] if inputs[0]._base is None else inputs[0]._base
            self.assertTrue(
                output._base is input0_base,
                msg="result was not a view of the first input!",
            )

    if not check_propagates_grad:
        return
    # Assuming input[0] is a floating-point tensor. Check if the vmap
    # operation propagates the requires_grad flag to the zeroth output.
    # Some vmap operators are implemented in a way that assumes that
    # they are composite with respect to autograd. If the operator ever is
    # changed to not be composite with respect to autograd, then the
    # following check should fail.
    inputs_clone = list(inputs)
    inputs_clone[0] = inputs[0].clone().requires_grad_()
    result = vmap(op, in_dims, out_dims)(*inputs_clone)
    result_as_tuple = (result,) if op_has_single_return else result
    self.assertTrue(result_as_tuple[0].requires_grad)


def should_allow_vmap_fallback_usage(fn):
    return getattr(fn, "_allow_vmap_fallback_usage", False)


def allowVmapFallbackUsage(fn):
    fn._allow_vmap_fallback_usage = True
    return fn


# All tests of TestVmapBase check that the slow vmap fallback is never invoked.
# This is so that we can incrementally add batching rules for operators to
# replace the slow vmap fallback path for said operators. To skip this check,
# please use the allowVmapFallbackUsage decorator.
#
# NB: Don't add tests to TestVmapBase directly, unless you want them to run
# on every subclass of TestVmapBase. Add them to e.g. TestVmapOperators.
#
# NB: TestVmapBase is a nested class. This prevents test runners from picking
# it up and running it.


class Namespace:
    class TestVmapBase(TestCase):
        def __init__(self, method_name="runTest"):
            super().__init__(method_name)

            test_method = getattr(self, method_name, None)
            if test_method is None:
                return

            if not should_allow_vmap_fallback_usage(test_method):
                setattr(
                    self,
                    method_name,
                    self._wrap_method_with_vmap_fallback_check(test_method),
                )

        def _wrap_method_with_vmap_fallback_check(self, method):
            # msg = (
            #     'Expected the test to not invoke the vmap fallback path, i.e., '
            #     'all of the operators being tested in this test should have batching '
            #     'rules implemented. If you are intentionally testing something to '
            #     'do with the fallback path, use allowVmapFallbackUsage. Otherwise, '
            #     'please make sure that batching rules are implemented for the '
            #     'operator(s) being tested.'
1418 # ) 1419 1420 @functools.wraps(method) 1421 def wrapper(self, *args, **kwargs): 1422 with warnings.catch_warnings(record=True): 1423 warnings.simplefilter("always") 1424 with EnableVmapFallbackWarnings(): 1425 method(*args, **kwargs) 1426 # for captured_warning in wa: 1427 # self.assertNotRegex(str(captured_warning.message), FALLBACK_REGEX, msg) 1428 1429 return types.MethodType(wrapper, self) 1430 1431 @allowVmapFallbackUsage 1432 def test_vmap_fallback_check_ok(self): 1433 # One day we'll implement a batching rule for torch.var_mean. 1434 # When that happens, please change the example to use an 1435 # operator that doesn't have a batching rule implemented. 1436 op_using_fallback = torch.var_mean 1437 vmap(op_using_fallback)(torch.rand(3)) 1438 1439 @unittest.expectedFailure 1440 def test_vmap_fallback_check(self): 1441 @self._wrap_method_with_vmap_fallback_check 1442 def no_fallback(self): 1443 pass 1444 1445 # One day we'll implement a batching rule for torch.var_mean. 1446 # When that happens, please change the example to use an 1447 # operator that doesn't have a batching rule implemented. 1448 op_using_fallback = torch.var_mean 1449 1450 @self._wrap_method_with_vmap_fallback_check 1451 def uses_fallback(self): 1452 vmap(op_using_fallback)(torch.rand(3)) 1453 1454 no_fallback(self) 1455 1456 with self.assertRaises(AssertionError): 1457 uses_fallback(self) 1458 1459 1460def _make_case(op, input_getter=TensorFactory.randn): 1461 return (op, input_getter) 1462 1463 1464@markDynamoStrictTest 1465class TestVmapOperators(Namespace.TestVmapBase): 1466 def _vmap_test(self, *args, **kwargs): 1467 return _vmap_test(self, *args, **kwargs) 1468 1469 def _vmap_view_test(self, *args, **kwargs): 1470 self._vmap_test(*args, **kwargs, check_view=True) 1471 1472 def _test_unary(self, op, getter, device, *args, **kwargs): 1473 test = functools.partial(self._vmap_test, *args, **kwargs) 1474 B0, B1 = 7, 11 1475 1476 # Single vmap, various in_dims / out_dims 1477 test(op, [getter([B0, 3], device)]) 1478 test(op, [getter([2, 5, B0, 3], device)], in_dims=2) 1479 test(op, [getter([2, 5, B0, 3], device)], in_dims=2, out_dims=2) 1480 1481 # Doubly nested vmap 1482 test(vmap(op), [getter([B0, B1], device)]) 1483 test(vmap(op), [getter([B1, 2, 5, B0, 3], device)], in_dims=2) 1484 test( 1485 vmap(op, in_dims=2), 1486 [getter([2, 5, B0, B1, 3], device)], 1487 in_dims=2, 1488 out_dims=2, 1489 ) 1490 1491 @parametrize( 1492 "case", 1493 [ 1494 (torch.abs, TensorFactory.randn), 1495 (torch.acos, TensorFactory.rand), 1496 (torch.asin, TensorFactory.rand), 1497 (torch.atan, TensorFactory.rand), 1498 (torch.ceil, TensorFactory.randn), 1499 (torch.cos, TensorFactory.rand), 1500 (torch.cosh, TensorFactory.rand), 1501 (torch.digamma, TensorFactory.rand), 1502 (torch.exp, TensorFactory.randn), 1503 (torch.expm1, TensorFactory.randn), 1504 (torch.floor, TensorFactory.randn), 1505 (torch.frac, TensorFactory.randn), 1506 (torch.lgamma, TensorFactory.rand), 1507 (torch.log, TensorFactory.randp1), 1508 (torch.log10, TensorFactory.randp1), 1509 (torch.log1p, TensorFactory.randp1), 1510 (torch.log2, TensorFactory.randp1), 1511 (torch.neg, TensorFactory.randn), 1512 (torch.reciprocal, TensorFactory.randp1), 1513 (torch.relu, TensorFactory.randn), 1514 (torch.round, TensorFactory.randn), 1515 (torch.rsqrt, TensorFactory.randp1), 1516 (torch.sigmoid, TensorFactory.randn), 1517 (torch.sign, TensorFactory.randn), 1518 (torch.sin, TensorFactory.rand), 1519 (torch.sinh, TensorFactory.rand), 1520 (torch.sqrt, TensorFactory.rand), 1521 
            (torch.tan, TensorFactory.rand),
            (torch.tanh, TensorFactory.rand),
            (torch.trunc, TensorFactory.randn),
        ],
        name_fn=lambda x: x[0].__name__,
    )
    def test_unary_pointwise(self, case):
        op, getter = case
        self._test_unary(op, getter, "cpu")

        # test in-place
        method = getattr(Tensor, f'{op.__name__ + "_"}')
        self._test_unary(method, getter, "cpu", check_propagates_grad=False)

    def test_clone(self):
        # Some basic tests
        self._test_unary(lambda x: x.clone(), TensorFactory.randn, "cpu")
        self._test_unary(
            lambda x: x.clone(memory_format=torch.preserve_format),
            TensorFactory.randn,
            "cpu",
        )
        self._test_unary(
            lambda x: x.clone(memory_format=torch.contiguous_format),
            TensorFactory.randn,
            "cpu",
        )

        # Test that the per-examples are contiguous when using torch.contiguous_format
        def clone_contiguous(x):
            return x.clone(memory_format=torch.contiguous_format)

        B0, B1 = 3, 5
        x = torch.randn(2, B0, 7)
        y = vmap(clone_contiguous, in_dims=1, out_dims=1)(x)
        self.assertTrue(y.movedim(1, 0).is_contiguous())
        self.assertTrue(y[:, 0, :].is_contiguous())

        x = torch.randn(2, B0, 7, B1)
        y = vmap(vmap(clone_contiguous, in_dims=2), in_dims=1)(x)
        self.assertTrue(y.is_contiguous())
        self.assertTrue(y[0][0].is_contiguous())

        msg = r"only supported with memory_format torch.preserve_format or torch.contiguous_format"
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(lambda x: x.clone(memory_format=torch.channels_last))(torch.randn(B0))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(lambda x: x.clone(memory_format=torch.channels_last_3d))(
                torch.randn(B0)
            )

    def test_weird_matmul_case(self):
        # Check that this doesn't crash.
        # https://github.com/pytorch/functorch/issues/417
        x = torch.randn(5, 2, 2, 2)
        y = torch.randn(5, 7, 2)

        vmap(vmap(torch.matmul, in_dims=(None, 0)))(x, y)

    @parametrize(
        "case",
        (
            (torch.clamp_min_, TensorFactory.randn),
            (torch.clamp_max_, TensorFactory.randn),
        ),
        name_fn=lambda x: x[0].__name__,
    )
    def test_clamp_inplace_variant(self, case):
        test = self._vmap_test

        def get_number(getter):
            return getter([]).item()

        op, getter = case
        device = "cpu"
        B0, B1 = 7, 11

        # Single vmap: op(Tensor, Tensor)
        test(
            op,
            (getter([B0, 3], device), getter([B0, 3], device)),
            check_propagates_grad=False,
        )
        test(
            op,
            (getter([B0], device), getter([B0], device)),
            check_propagates_grad=False,
        )
        test(
            op,
            (getter([2, B0, 3], device), getter([2, B0, 3], device)),
            in_dims=(1, 1),
            check_propagates_grad=False,
        )
        test(
            op,
            (getter([B0, 2, 3], device), getter([2, B0, 3], device)),
            in_dims=(0, 1),
            out_dims=1,
            check_propagates_grad=False,
        )
        test(
            op,
            (getter([B0, 2, 3], device), getter([1, 1], device)),
            in_dims=(0, None),
            check_propagates_grad=False,
        )
        test(
            op,
            (getter([B0, 3], device), getter([B0, 3], device)),
            in_dims=(0, 0),
            check_propagates_grad=False,
        )

        # Nested vmap: op(Tensor, Tensor)
        test(
            vmap(op),
            (getter([B0, B1, 2, 3], device), getter([B0, B1, 1, 3], device)),
            check_propagates_grad=False,
        )

        # Python number overload: op(Tensor, Number)
        number = get_number(getter)
        self._test_unary(
            lambda t: op(t, number), getter, device, check_propagates_grad=False
        )

    @parametrize(
        "case",
        [
            subtest(_make_case(torch.clamp_min), name="clamp_min"),
            subtest(_make_case(torch.clamp_max), name="clamp_max"),
        ],
    )
    def test_clamp_variant(self, case):
        test = self._vmap_test

        def get_number(getter):
            return getter([]).item()

        op, getter = case
        device = "cpu"
        B0, B1 = 7, 11

        # Single vmap: op(Tensor, Tensor)
        test(op, (getter([B0, 3], device), getter([B0, 3], device)))
        test(op, (getter([B0], device), getter([B0, 2, 3], device)))
        test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1))
        test(
            op,
            (getter([B0], device), getter([2, B0, 3], device)),
            in_dims=(0, 1),
            out_dims=1,
        )
        test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
        test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(None, 0))

        # Nested vmap: op(Tensor, Tensor)
        test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device)))
        test(
            vmap(op, in_dims=(None, 0)),
            (getter([B0, 2, 3], device), getter([B1, 3], device)),
            in_dims=(0, None),
        )

        # Python number overload: op(Tensor, Number)
        number = get_number(getter)
        self._test_unary(lambda t: op(t, number), getter, device)

    def test_copy_(self):
        x = torch.randn(3)
        y = torch.randn(3)
        vmap(Tensor.copy_)(x, y)
        self.assertEqual(x, y)

        x = torch.randn(3)
        y = torch.randn(3, 2)
        vmap(Tensor.copy_, in_dims=(1, None))(y, x)
        self.assertEqual(y, x.expand(2, 3).t())

        x = torch.randn(3)
        y = torch.randn(2, 3)
        with self.assertRaisesRegex(RuntimeError, "inplace"):
"inplace"): 1704 vmap(Tensor.copy_, in_dims=(None, 0))(x, y) 1705 1706 def test_silu_backward(self): 1707 test = self._vmap_test 1708 device = "cpu" 1709 getter = TensorFactory.randp1 1710 B0 = 7 1711 op = torch.ops.aten.silu_backward 1712 1713 # Single vmap: op(Tensor, Tensor) 1714 test(op, (getter([B0, 3], device), getter([B0, 3], device))) 1715 test(op, (getter([], device), getter([B0], device)), in_dims=(None, 0)) 1716 test(op, (getter([2, B0], device), getter([2], device)), in_dims=(1, None)) 1717 1718 @skipIf( 1719 TEST_WITH_TORCHDYNAMO 1720 and os.getenv("BUILD_ENVIRONMENT", "") == "linux-focal-py3.8-clang10", 1721 "Segfaults with dynamo on focal, see https://github.com/pytorch/pytorch/issues/107173", 1722 ) 1723 @parametrize( 1724 "case", 1725 [ 1726 subtest(_make_case(torch.add), name="add"), 1727 subtest(_make_case(lambda x, y: x + y), name="add_dunder"), 1728 subtest(_make_case(torch.sub), name="sub"), 1729 subtest(_make_case(lambda x, y: x - y), name="sub_dunder"), 1730 subtest(_make_case(torch.mul), name="mul"), 1731 subtest(_make_case(lambda x, y: x * y), name="mul_dunder"), 1732 subtest( 1733 _make_case(torch.div, input_getter=TensorFactory.randp1), name="div" 1734 ), 1735 subtest( 1736 _make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1), 1737 name="div_dunder", 1738 ), 1739 subtest( 1740 _make_case(torch.pow, input_getter=TensorFactory.randp1), name="pow" 1741 ), 1742 subtest( 1743 _make_case(lambda x, y: x**y, input_getter=TensorFactory.randp1), 1744 name="pow_dunder", 1745 ), 1746 ], 1747 ) 1748 def test_arithmetic(self, case): 1749 test = self._vmap_test 1750 1751 def get_number(getter): 1752 return getter([]).item() 1753 1754 op, getter = case 1755 device = "cpu" 1756 B0, B1 = 7, 11 1757 1758 # Single vmap: op(Tensor, Tensor) 1759 test(op, (getter([B0, 3], device), getter([B0, 3], device))) 1760 test(op, (getter([B0], device), getter([B0, 2, 3], device))) 1761 test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1)) 1762 test( 1763 op, 1764 (getter([B0], device), getter([2, B0, 3], device)), 1765 in_dims=(0, 1), 1766 out_dims=1, 1767 ) 1768 test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None)) 1769 test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None)) 1770 1771 # Nested vmap: op(Tensor, Tensor) 1772 test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device))) 1773 test( 1774 vmap(op, in_dims=(None, 0)), 1775 (getter([B0, 2, 3], device), getter([B1, 3], device)), 1776 in_dims=(0, None), 1777 ) 1778 1779 # Python number overload: op(Tensor, Number) (and vice-versa) 1780 number = get_number(getter) 1781 self._test_unary(lambda t: op(t, number), getter, device) 1782 number = get_number(getter) 1783 self._test_unary(lambda t: op(number, t), getter, device) 1784 1785 # Type promotion: op(Logical Scalar Tensor, Logical Scalar Tensor) 1786 test(op, (getter([B0], device), getter([B0], device, dtype=torch.double))) 1787 test(op, (getter([B0], device, dtype=torch.double), getter([B0], device))) 1788 test(op, (getter([B0], device), getter([B0], device))) 1789 1790 # Type promotion: op(Tensor, Logical Scalar Tensor) (and vice-versa) 1791 test(op, (getter([B0, 2], device), getter([B0], device, torch.double))) 1792 test(op, (getter([B0], device, torch.double), getter([B0, 2], device))) 1793 1794 if not torch.cuda.is_available(): 1795 return 1796 1797 # TODO(rzou): fix the following 1798 # # Test cross-device scalars 1799 # number = get_number(getter) 1800 # self._test_unary(lambda t: 
op(t, number), getter, device='cuda') 1801 # self._test_unary(lambda t: op(number, t), getter, device='cuda') 1802 # self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda') 1803 1804 def test_as_strided(self): 1805 def _test(sizes, strides, offset, tensor, lambd): 1806 # bdim at dim 0 test 1807 result = vmap(lambda t: t.as_strided(sizes, strides, offset))(tensor) 1808 expected = vmap(lambd)(tensor) 1809 self.assertTrue(result._base is expected._base) 1810 self.assertEqual(result, expected) 1811 1812 # bdim at dim -1 test 1813 tensor = tensor.movedim(0, -1) 1814 result = vmap(lambda t: t.as_strided(sizes, strides, offset), -1)(tensor) 1815 expected = vmap(lambd, -1)(tensor) 1816 self.assertTrue(result._base is expected._base) 1817 self.assertEqual(result, expected) 1818 1819 # single vmap test 1820 B0 = 5 1821 # Each Tensor has shape [B0, 2, 3]; the expressions below 1822 # are just to get tensors of different strides that have shape [B0, 2, 3] 1823 tensors = [ 1824 # contiguous 1825 torch.randn(B0, 2, 3), 1826 # non-contiguous 1827 torch.randn(B0, 3, 2).transpose(1, 2), 1828 torch.randn(3, 2, B0).movedim(-1, 0).transpose(1, 2), 1829 # non-zero storage offset 1830 torch.randn(2, B0, 2, 3)[1], 1831 torch.randn(2, 2, B0, 3)[1].movedim(1, 0), 1832 # non-contiguous strides, zero storage offset 1833 torch.randn(B0, 2, 4, 3, 7)[:, :, 0, :, 0], 1834 torch.randn(2, 4, B0, 3, 7).movedim(2, 0)[:, :, 0, :, 0], 1835 # non-contiguous strides, non-zero storage offset 1836 torch.randn(B0, 2, 4, 3, 7)[:, :, 2, :, 1], 1837 torch.randn(2, 4, 3, 7, B0).movedim(-1, 0)[:, :, 2, :, 1], 1838 ] 1839 1840 for x in tensors: 1841 S0, S1 = x.stride()[1:] 1842 offset = x.storage_offset() 1843 1844 # Broadcast 1845 _test( 1846 [5, 5, 2, 3], [0, 0, S0, S1], offset, x, lambda x: x.expand(5, 5, 2, 3) 1847 ) 1848 # transpose 1849 _test([3, 2], [S1, S0], offset, x, lambda x: x.transpose(0, 1)) 1850 # select 1851 _test([2], [S0], offset + S1, x, lambda x: x[:, 1]) 1852 # diagonal 1853 _test([2], [S0 + S1], offset, x, lambda x: x.diagonal()) 1854 # strided slice 1855 _test([2], [S1 * 2], offset, x, lambda x: x[0, ::2]) 1856 1857 # Nested vmap test 1858 B1 = 7 1859 x = torch.randn(B1, B0, 2, 3) 1860 S0, S1 = x.stride()[2:] 1861 result = vmap( 1862 vmap(lambda t: t.as_strided([5, 5, 2, 3], [0, 0, S0, S1])), in_dims=1 1863 )(x) 1864 expected = vmap(vmap(lambda t: t.expand(5, 5, 2, 3)), in_dims=1)(x) 1865 self.assertTrue(result._base is expected._base) 1866 self.assertEqual(result, expected) 1867 1868 # Check that mal-formatted size/strides doesn't crash 1869 with self.assertRaisesRegex( 1870 RuntimeError, "size and stride must have the same length" 1871 ): 1872 x = torch.randn(B0, 2, 3).transpose(0, 1) 1873 vmap(lambda x: x.as_strided([1, 1, 1], [1, 1]))(x) 1874 1875 # All the Sanity check #1{a,b,c} cases check that 1876 # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) 1877 # doesn't index memory that is out of bounds of xs[i]. This condition 1878 # is important to the correctness of the as_strided batching rule 1879 # (see NOTE: [When will the as_strided_batching_rule fail?]) 1880 1881 # Sanity check #1a: The maximum indexable location of 1882 # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) 1883 # is less than or equal to the maximum indexable location of xs[i]. 
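        # Illustrative sketch of the condition (not an extra assertion): for a
        # contiguous xs = torch.randn(B0, 3), xs[i] covers storage indices
        # [3 * i, 3 * i + 3). With sizes=[3], strides=[1], offset=1 (the first
        # case below), xs[i].as_strided([3], [1], 1 + xs[i].offset()) reaches
        # storage index 3 * i + 3, i.e. the first element of xs[i + 1], so the
        # batching rule must raise instead of silently reading across examples.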
1884 msg = "This is not supported inside of vmap" 1885 with self.assertRaisesRegex(RuntimeError, msg): 1886 x = torch.randn(B0, 3) 1887 vmap(lambda x: x.as_strided([3], [1], 1))(x) 1888 with self.assertRaisesRegex(RuntimeError, msg): 1889 x = torch.randn(B0, 3, 5) 1890 vmap(lambda x: x.as_strided([4, 4], [4, 1], 0))(x) 1891 with self.assertRaisesRegex(RuntimeError, msg): 1892 x = torch.randn(B0, B1, 3, 5) 1893 vmap(vmap(lambda x: x.as_strided([4, 4], [4, 1], 0)))(x) 1894 1895 # Sanity check #1b: The min indexable location of 1896 # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) 1897 # is greater than or equal to the min indexable location of xs[i]. 1898 with self.assertRaisesRegex(RuntimeError, msg): 1899 x = torch.randn(2, B0, 3)[1] 1900 vmap(lambda x: x.as_strided([3], [1], B0 * 3 - 1))(x) 1901 1902 # Sanity check #1c: 1903 # xs[i] is a zero-dim tensor, but 1904 # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) 1905 # is not 1906 with self.assertRaisesRegex(RuntimeError, msg): 1907 x = torch.randn(B0, 0, 3) 1908 vmap(lambda x: x.as_strided([3], [1]))(x) 1909 1910 def test_nll_loss(self): 1911 test = self._vmap_test 1912 op = F.nll_loss 1913 B = 3 1914 1915 y = torch.randn(B, 2, 5) 1916 t = torch.randint(0, 5, (B, 2)) 1917 test(op, (y, t)) 1918 test(functools.partial(op, reduction="sum"), (y, t)) 1919 test(functools.partial(op, reduction="none"), (y, t)) 1920 1921 y = torch.randn(B, 2, 5) 1922 t = torch.randint(0, 5, (2,)) 1923 test(op, (y, t), in_dims=(0, None)) 1924 test(functools.partial(op, reduction="sum"), (y, t), in_dims=(0, None)) 1925 test(functools.partial(op, reduction="none"), (y, t), in_dims=(0, None)) 1926 1927 def test_adaptive_avg_pool2d(self): 1928 test = self._vmap_test 1929 op = functools.partial(F.adaptive_avg_pool2d, output_size=(3, 3)) 1930 1931 x = torch.randn(3, 5, 7, 9, 11) 1932 test(op, (x,)) 1933 test(op, (x,), in_dims=(1,)) 1934 test(op, (x,), in_dims=(4,)) 1935 1936 def test_bmm(self): 1937 op = torch.bmm 1938 test = self._vmap_test 1939 B0, B1 = 7, 11 1940 1941 # shape mismatch 1942 msg = "" 1943 with self.assertRaisesRegex(RuntimeError, msg): 1944 vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2)) 1945 with self.assertRaisesRegex(RuntimeError, msg): 1946 vmap(op, in_dims=(0, None))(torch.randn(B0, 3, 3, 2), torch.randn(2, 2)) 1947 with self.assertRaisesRegex(RuntimeError, msg): 1948 vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2)) 1949 1950 # left arg is vmapped 1951 test(op, (torch.rand(B0, 2, 3, 5), torch.rand(2, 5, 3)), in_dims=(0, None)) 1952 test( 1953 vmap(op, in_dims=(0, None)), 1954 (torch.rand(B1, B0, 2, 3, 5), torch.rand(2, 5, 3)), 1955 in_dims=(1, None), 1956 ) 1957 1958 # right arg is vmapped 1959 test(op, (torch.rand(2, 5, 3), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0)) 1960 test( 1961 vmap(op, in_dims=(None, 0)), 1962 (torch.rand(2, 5, 3), torch.rand(B1, B0, 2, 3, 5)), 1963 in_dims=(None, 1), 1964 ) 1965 1966 # both args are vmapped 1967 test(op, (torch.rand(B0, 2, 3, 5), torch.rand(B0, 2, 5, 3))) 1968 test( 1969 vmap(op), 1970 (torch.rand(B1, B0, 2, 3, 5), torch.rand(B0, B1, 2, 5, 3)), 1971 in_dims=(1, 0), 1972 ) 1973 test( 1974 vmap(op, in_dims=(0, None)), 1975 (torch.rand(B1, 2, 3, 5), torch.rand(B0, 2, 5, 3)), 1976 in_dims=(None, 0), 1977 ) 1978 1979 def test_cat(self): 1980 test = self._vmap_test 1981 B0, B1 = 5, 7 1982 1983 # Quick hack b/c vmap can't accept a list of tensors as an argument 1984 def get_op(dim): 1985 def op(*tensors): 1986 return 
torch.cat(tensors, dim=dim) 1987 1988 return op 1989 1990 test(get_op(0), (torch.rand(B0, 2), torch.rand(B0, 3))) 1991 test(get_op(0), (torch.rand(B0, 0), torch.rand(B0, 0))) 1992 test(get_op(0), (torch.rand(2), torch.rand(B0, 0)), in_dims=(None, 0)) 1993 test( 1994 get_op(1), 1995 (torch.rand(2, 5), torch.rand(B0, 0), torch.rand(2, 3)), 1996 in_dims=(None, 0, None), 1997 ) 1998 test(get_op(1), (torch.rand(B0, 2, 3), torch.rand(B0, 0))) 1999 test(get_op(1), (torch.rand(B0, 2, 3, 4), torch.rand(0)), in_dims=(0, None)) 2000 test( 2001 get_op(0), 2002 (torch.rand(0), torch.rand(B0, 2), torch.rand(B0, 0)), 2003 in_dims=(None, 0, 0), 2004 ) 2005 test(get_op(0), (torch.rand(2), torch.rand(B0, 3)), in_dims=(None, 0)) 2006 test(get_op(0), (torch.rand(2, 17), torch.rand(3, 17, B0)), in_dims=(None, 2)) 2007 test(get_op(-1), (torch.rand(17, 2), torch.rand(17, 3, B0)), in_dims=(None, 2)) 2008 test( 2009 vmap(get_op(0), in_dims=(0, None)), 2010 (torch.rand(B1, 2), torch.rand(B0, 3)), 2011 in_dims=(None, 0), 2012 ) 2013 test( 2014 vmap(get_op(0), in_dims=(0, 0)), 2015 (torch.rand(B1, 2), torch.rand(B0, B1, 3)), 2016 in_dims=(None, 0), 2017 ) 2018 2019 def test_unsafe_view(self): 2020 # Unsafe view isn't exposed, so we get at it via 2021 # vmap(grad(matmul)) 2022 test = functools.partial(self._vmap_test, check_propagates_grad=False) 2023 B = 2 2024 x = torch.randn(B, 2, 3, 3) 2025 y = torch.randn(B, 3, 3) 2026 2027 def baz(x, y): 2028 return (x @ y).sum() 2029 2030 test(functorch.grad(baz), (x, y)) 2031 2032 def test_conj(self): 2033 op = torch.conj 2034 2035 def run_test(dtype): 2036 def get(shape): 2037 return torch.randn(shape, dtype=dtype) 2038 2039 B0, B1 = 7, 11 2040 test = self._vmap_test 2041 2042 # Single vmap, various in_dims / out_dims 2043 test(op, [get([B0, 3])]) 2044 test(op, [get([2, 5, B0, 3])], in_dims=2) 2045 test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2) 2046 2047 # Doubly nested vmap 2048 test(vmap(op), [get([B0, B1])]) 2049 test(vmap(op), [get([B1, 2, 5, B0, 3])], in_dims=2) 2050 test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], in_dims=2, out_dims=2) 2051 2052 # correctness tests 2053 run_test(torch.float) 2054 run_test(torch.cfloat) 2055 2056 # check that torch.conj on a non-complex tensor returns the same tensor 2057 real_tensor = torch.randn(3) 2058 result = vmap(op)(real_tensor) 2059 self.assertEqual(result.data_ptr(), real_tensor.data_ptr()) 2060 2061 def test_contiguous(self): 2062 op = Tensor.contiguous 2063 2064 self._test_unary(op, TensorFactory.randn, "cpu") 2065 2066 # check that contiguous returns the original tensor if the per-examples 2067 # are already contiguous 2068 B0 = 3 2069 x = torch.randn(B0, 2, 5, 7) 2070 x = x.movedim(0, 2) 2071 result = vmap(Tensor.contiguous, in_dims=2, out_dims=2)(x) 2072 self.assertTrue(result is x) 2073 2074 msg = "NYI: querying is_contiguous inside of vmap for memory_format" 2075 tensor = torch.randn(B0, 3) 2076 with self.assertRaisesRegex(RuntimeError, msg): 2077 vmap(functools.partial(op, memory_format=torch.channels_last))(tensor) 2078 with self.assertRaisesRegex(RuntimeError, msg): 2079 vmap(functools.partial(op, memory_format=torch.channels_last_3d))(tensor) 2080 2081 def test_stride(self): 2082 B0 = 3 2083 2084 x = torch.randn(B0, 2, 5, 7) 2085 2086 def foo(x): 2087 assert x.stride() == (7 * 5, 7, 1) 2088 return x 2089 2090 vmap(foo)(x) 2091 2092 x = torch.randn(2, B0, 5, 7).movedim(1, 0) 2093 2094 def bar(x): 2095 assert x.stride() == (7 * 5 * B0, 7, 1) 2096 return x 2097 2098 vmap(bar)(x) 2099 2100 def test_chunk(self): 
        test = self._vmap_view_test
        op = torch.chunk
        B0, B1, B2 = 7, 11, 13

        # tests for torch.chunk(self, chunks: int, dim)
        test(op, (torch.rand(B0, 2, 1024), 15, -1), in_dims=(0, None, None))
        test(op, (torch.rand(2, B0, 1024), 9, 1), in_dims=(1, None, None))
        test(
            vmap(op, in_dims=(0, None, None)),
            (torch.rand(B1, 1023, B0, 5), 4, 0),
            in_dims=(2, None, None),
        )
        test(
            vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
            (torch.rand(B1, 2, B0, 64, B2),),
            in_dims=2,
        )

    def test_clamp(self):
        clamp_cases = (
            (lambda t: t.clamp(min=-0.5), TensorFactory.randn),
            (lambda t: t.clamp(max=0.5), TensorFactory.randn),
            (lambda t: t.clamp(min=-0.5, max=0.5), TensorFactory.randn),
            (lambda t: t.clamp_min(min=-0.5), TensorFactory.randn),
            (lambda t: t.clamp_max(max=0.5), TensorFactory.randn),
        )
        for op, getter in clamp_cases:
            self._test_unary(op, getter, "cpu")

    def test_comparison_ops(self):
        test = functools.partial(self._vmap_test, check_propagates_grad=False)

        getter = TensorFactory.randn
        B0, B1 = 7, 11

        ops = (
            torch.eq,
            lambda x, y: x == y,
            torch.gt,
            lambda x, y: x > y,
            torch.ge,
            lambda x, y: x >= y,
            torch.le,
            lambda x, y: x <= y,
            torch.lt,
            lambda x, y: x < y,
            torch.ne,
            lambda x, y: x != y,
        )

        for op in ops:
            # Single vmap: op(Tensor, Tensor)
            test(op, (getter([B0, 3]), getter([B0, 3])))
            test(op, (getter([B0]), getter([B0, 2, 3])))
            test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1))
            test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1), out_dims=1)
            test(op, (getter([B0]), getter([2, 3])), in_dims=(0, None))
            test(op, (getter([2, 3]), getter([B0, 3])), in_dims=(0, None))

            # Nested vmap: op(Tensor, Tensor)
            test(vmap(op), (getter([B0, B1, 2, 3]), getter([B0, B1, 3])))
            test(
                vmap(op, in_dims=(None, 0)),
                (getter([B0, 2, 3]), getter([B1, 3])),
                in_dims=(0, None),
            )

            # test number as inputs
            number = getter([]).item()
            self._test_unary(
                lambda t: op(t, number), getter, "cpu", check_propagates_grad=False
            )

    def test_cross_batch_size_three(self):
        # Test the corner case where batch_size is 3 and cross' dim argument is
        # not specified. Per the cross API, dim is then inferred as the first
        # dim with size 3; this test ensures the inferred dim is not the batch
        # dim.
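        # For example (illustrative sketch of the concern): with B0 = 3 the
        # physical input to the batching rule has shape [3, 2, 3]. Naively
        # picking the first size-3 dim of the *physical* tensor would select
        # the batch dim; the rule has to infer the dim from the logical
        # per-example shape [2, 3] instead, i.e. pick the last dim here.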
2178 op = torch.cross 2179 test = self._vmap_test 2180 B0 = B1 = 3 2181 test(op, (torch.rand(B0, 2, 3), torch.rand(B0, 2, 3))) 2182 test( 2183 vmap(op, in_dims=(0, None)), 2184 (torch.rand(B0, B1, 2, 3), torch.rand(B0, B1, 2, 3)), 2185 in_dims=(None, 1), 2186 ) 2187 2188 def test_diagonal(self): 2189 tensor = torch.randn(3, 5, 7, 11, 13) 2190 test = self._vmap_view_test 2191 op = torch.diagonal 2192 test(op, (tensor, 1, 0, 1), in_dims=(0, None, None, None)) 2193 test(op, (tensor, 0, 2, -1), in_dims=(0, None, None, None)) 2194 test(op, (tensor, 2, 1, 2), in_dims=(1, None, None, None)) 2195 test(op, (tensor, 0, -2, -1), in_dims=(1, None, None, None), out_dims=1) 2196 test(vmap(lambda t: op(t, 0, 0, -1)), (tensor,), in_dims=1, out_dims=1) 2197 test( 2198 vmap(vmap(lambda t: op(t, 0, 0, 1), in_dims=1), in_dims=3), 2199 (tensor,), 2200 in_dims=1, 2201 out_dims=1, 2202 ) 2203 2204 def test_dot(self): 2205 op = torch.dot 2206 test = self._vmap_test 2207 B0, B1 = 7, 11 2208 2209 # shape mismatch 2210 msg = "" 2211 with self.assertRaisesRegex(RuntimeError, msg): 2212 vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2)) 2213 with self.assertRaisesRegex(RuntimeError, msg): 2214 vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2)) 2215 with self.assertRaisesRegex(RuntimeError, msg): 2216 vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2)) 2217 2218 # left arg is vmapped 2219 test(op, (torch.rand(B0, 5), torch.rand(5)), in_dims=(0, None)) 2220 test( 2221 vmap(op, in_dims=(0, None)), 2222 (torch.rand(B1, B0, 5), torch.rand(5)), 2223 in_dims=(1, None), 2224 ) 2225 2226 # right arg is vmapped 2227 test(op, (torch.rand(5), torch.rand(B0, 5)), in_dims=(None, 0)) 2228 test( 2229 vmap(op, in_dims=(None, 0)), 2230 (torch.rand(5), torch.rand(B1, B0, 5)), 2231 in_dims=(None, 1), 2232 ) 2233 2234 # both args are vmapped 2235 test(op, (torch.rand(B0, 5), torch.rand(B0, 5))) 2236 test(vmap(op), (torch.rand(B1, B0, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0)) 2237 test( 2238 vmap(op, in_dims=(0, None)), 2239 (torch.rand(B1, 5), torch.rand(B0, 5)), 2240 in_dims=(None, 0), 2241 ) 2242 2243 def test_expand_as(self): 2244 op = torch.Tensor.expand_as 2245 test = self._vmap_view_test 2246 B0, B1, B2 = 7, 11, 13 2247 test(op, (torch.rand(B0, 1, 5), torch.rand(B0, 2, 3, 5))) 2248 test(op, (torch.rand(B0, 1, 5), torch.rand(2, 3, 5)), in_dims=(0, None)) 2249 test(op, (torch.rand(1, 5), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0)) 2250 test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B0, B1, 2, 3, 5))) 2251 test( 2252 vmap(op), 2253 (torch.rand(B0, B1, 1, 5), torch.rand(B1, B0, 2, 3, 5)), 2254 in_dims=(0, 1), 2255 ) 2256 test(vmap(op), (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None)) 2257 test(vmap(vmap(op)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5))) 2258 2259 def test_fill_and_zero_inplace(self): 2260 test = functools.partial(self._vmap_test, check_propagates_grad=False) 2261 B0, B1 = 7, 11 2262 ops = ( 2263 lambda t: t.fill_(0.1), 2264 lambda t: t.fill_(torch.tensor(0.2)), 2265 lambda t: t.zero_(), 2266 ) 2267 2268 for op in ops: 2269 # Single vmap, various in_dims / out_dims 2270 test(op, [TensorFactory.randn([B0, 3])]) 2271 test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2) 2272 test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2, out_dims=2) 2273 2274 # Doubly nested vmap 2275 test(vmap(op), [TensorFactory.randn([B0, B1])]) 2276 test(vmap(op), [TensorFactory.randn([B1, 2, 5, B0, 3])], in_dims=2) 2277 test( 2278 vmap(op, in_dims=2), 2279 
[TensorFactory.randn([2, 5, B0, B1, 3])], 2280 in_dims=2, 2281 out_dims=2, 2282 ) 2283 2284 # test when value is a batched tensor for fill_ operator 2285 B0, B1 = 3, 5 2286 test(Tensor.fill_, [TensorFactory.randn([B0, B1]), TensorFactory.randn(B0)]) 2287 2288 with self.assertRaisesRegex(RuntimeError, ""): 2289 # Runtime Error is thrown when the tensor being written to isn't being vmapped over 2290 vmap(Tensor.fill_, (None, 0))( 2291 TensorFactory.randn([B0, B1]), TensorFactory.randn([B0]) 2292 ) 2293 2294 def _test_complex_views(self, op, dtypes): 2295 test = self._vmap_view_test 2296 2297 def run_test(op, dtype): 2298 def get(shape): 2299 return torch.randn(shape, dtype=dtype) 2300 2301 B0, B1 = 7, 11 2302 2303 # Single vmap, various in_dims / out_dims 2304 test(op, [get([B0, 3])]) 2305 test(op, [get([3, B0])], in_dims=1) 2306 test(op, [get([2, 5, B0, 3])], in_dims=2) 2307 test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2) 2308 2309 # Doubly nested vmap 2310 test(vmap(op), [get([B0, B1])]) 2311 test(vmap(op), [get([B1, 2, 5, 3, B0])], in_dims=4) 2312 test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], in_dims=2, out_dims=2) 2313 2314 for dtype in dtypes: 2315 run_test(op, dtype) 2316 2317 def test_real(self): 2318 self._test_complex_views(torch.real, dtypes=[torch.cfloat, torch.cdouble]) 2319 2320 def test_imag(self): 2321 self._test_complex_views(torch.imag, dtypes=[torch.cfloat, torch.cdouble]) 2322 2323 def test_view_as_real(self): 2324 self._test_complex_views( 2325 torch.view_as_real, dtypes=[torch.cfloat, torch.cdouble] 2326 ) 2327 2328 def test_view_as_complex(self): 2329 def run_test(dtype): 2330 def get(shape): 2331 return torch.randn(shape, dtype=dtype) 2332 2333 op = torch.view_as_complex 2334 test = self._vmap_view_test 2335 B0, B1 = 7, 11 2336 2337 # Single vmap, various in_dims / out_dims 2338 test(op, [get([B0, 3, 2])]) 2339 test(op, [get([2, 5, B0, 3, 2])], in_dims=2) 2340 test(op, [get([2, 5, B0, 3, 2])], in_dims=2, out_dims=2) 2341 2342 # Doubly nested vmap 2343 test(vmap(op), [get([B0, B1, 2])]) 2344 test(vmap(op), [get([B1, 2, 5, B0, 3, 2])], in_dims=2) 2345 test( 2346 vmap(op, in_dims=2), [get([2, 5, B0, B1, 3, 2])], in_dims=2, out_dims=2 2347 ) 2348 2349 # Interesting case #1: Batch dim directly before dim of size 2 2350 test(op, [get([3, B0, 2])], in_dims=1) 2351 test(vmap(op, in_dims=1), [get([3, B1, B0, 2])], in_dims=2) 2352 2353 # Interesting case #2: Batch dim at end of tensor, success cases 2354 # view_as_complex requires that the dim with size 2 have stride 1 2355 # in order for the view to function property 2356 test(op, [get([B0, 2]).transpose(0, 1)], in_dims=1) 2357 test(vmap(op, in_dims=1), [get([B0, B1, 2]).movedim(1, 2)]) 2358 test(vmap(op, in_dims=2), [get([B0, 3, B1, 2]).movedim(2, 3)]) 2359 2360 # Interesting case #3: Batch dim at end of tensor, failure cases 2361 msg = "Tensor must have a last dimension with stride 1" 2362 with self.assertRaisesRegex(RuntimeError, msg): 2363 vmap(op, in_dims=1)(get([2, B0])) 2364 with self.assertRaisesRegex(RuntimeError, msg): 2365 vmap(vmap(op, in_dims=1), in_dims=1)(get([2, B0, B1])) 2366 2367 # Invalid input: no dimension of size 2 2368 msg = "Input tensor must have one or more dimensions" 2369 with self.assertRaisesRegex(RuntimeError, msg): 2370 vmap(op)(get([B0])) 2371 with self.assertRaisesRegex(RuntimeError, msg): 2372 vmap(vmap(op))(get([B0, B1])) 2373 2374 # Invalid input: Batch dim has size 2, but the logical last dim does 2375 # not have size 2 2376 msg = "Tensor must have a last dimension of size 2" 
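            # Concretely (illustrative note): get([3, 2]) with in_dims=1 makes
            # the size-2 dim the batch dim, so each logical example has shape
            # [3]. view_as_complex must look at the logical last dim (size 3),
            # not the physical last dim (size 2), and raise here.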
2377 with self.assertRaisesRegex(RuntimeError, msg): 2378 vmap(op, in_dims=1)(get([3, 2])) 2379 2380 for dtype in [torch.float, torch.double]: 2381 run_test(dtype) 2382 2383 def test_is_complex(self): 2384 ctensor = torch.randn(3, dtype=torch.cfloat) 2385 tensor = torch.randn(3) 2386 2387 def foo(x): 2388 if x.is_complex(): 2389 return torch.tensor(1) 2390 else: 2391 return torch.tensor(0) 2392 2393 self.assertEqual(vmap(foo)(ctensor), torch.tensor([1, 1, 1])) 2394 self.assertEqual(vmap(foo)(tensor), torch.tensor([0, 0, 0])) 2395 2396 def test_is_floating_point(self): 2397 float_tensor = torch.tensor([1.0, 2.0, 3.0]) 2398 long_tensor = torch.tensor([1, 2, 3]) 2399 2400 def foo(x): 2401 if x.is_floating_point(): 2402 return torch.tensor(1) 2403 else: 2404 return torch.tensor(0) 2405 2406 self.assertEqual(vmap(foo)(float_tensor), torch.tensor([1, 1, 1])) 2407 self.assertEqual(vmap(foo)(long_tensor), torch.tensor([0, 0, 0])) 2408 2409 @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") 2410 def test_is_contiguous(self): 2411 def foo(x): 2412 if x.is_contiguous(): 2413 return torch.tensor(1.0) 2414 else: 2415 return torch.tensor(0.0) 2416 2417 B0, B1 = 3, 5 2418 2419 # Single batch dim 2420 contig = torch.randn(B0, 2, 7) 2421 self.assertEqual(vmap(foo)(contig), torch.ones(B0)) 2422 2423 noncontig = torch.randn(2, B0, 7) 2424 self.assertEqual(vmap(foo, in_dims=1)(noncontig), torch.zeros(B0)) 2425 2426 noncontig = torch.randn(2, B0, 7).movedim(1, 0) 2427 self.assertEqual(vmap(foo)(noncontig), torch.zeros(B0)) 2428 2429 noncontig = torch.randn(2, 7, B0) 2430 self.assertEqual(vmap(foo, in_dims=2)(noncontig), torch.zeros(B0)) 2431 2432 # Multiple batch dims 2433 contig = torch.randn(B0, B1, 3) 2434 self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1)) 2435 2436 contig = torch.randn(B1, B0, 3) 2437 self.assertEqual(vmap(vmap(foo), in_dims=1)(contig), torch.ones(B0, B1)) 2438 2439 contig = torch.randn(B1, B0, 3).movedim(0, 1) 2440 self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1)) 2441 2442 noncontig = torch.randn(B0, 3, B1) 2443 self.assertEqual(vmap(vmap(foo, in_dims=1))(noncontig), torch.zeros(B0, B1)) 2444 2445 # is_contiguous on empty tensor is True 2446 def bar(x): 2447 assert x.is_contiguous() 2448 return x 2449 2450 vmap(bar)(torch.randn(B0, 0, 3)) 2451 vmap(bar, in_dims=1)(torch.randn(0, B0, 3)) 2452 vmap(bar)(torch.randn(B0, 0, 3).transpose(-1, -2)) 2453 2454 # is_contiguous with other memory formats 2455 def baz(x, memory_format): 2456 x.is_contiguous(memory_format=memory_format) 2457 return x 2458 2459 msg = "NYI: querying is_contiguous inside of vmap for memory_format" 2460 tensor = torch.randn(B0, 2, 7, 3) 2461 with self.assertRaisesRegex(RuntimeError, msg): 2462 vmap(functools.partial(baz, memory_format=torch.channels_last))(tensor) 2463 with self.assertRaisesRegex(RuntimeError, msg): 2464 vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor) 2465 2466 for mf in (torch.channels_last, torch.channels_last_3d): 2467 2468 @torch.compile(backend="eager", fullgraph=True) 2469 def f(x): 2470 if x.is_contiguous(memory_format=mf): 2471 return x.sin() 2472 return x.cos() 2473 2474 with self.assertRaisesRegex(RuntimeError, msg): 2475 vmap(f)(torch.randn(3, 3)) 2476 2477 def test_unsqueeze(self): 2478 op = torch.unsqueeze 2479 test = self._vmap_view_test 2480 B0, B1 = 7, 11 2481 2482 # unsqueeze dim 0 2483 test(op, (torch.rand(B0, 2, 5), 0), in_dims=(0, None)) 2484 test(op, (torch.rand(2, B0, 5), 0), in_dims=(1, None)) 2485 
2486 # unsqueeze last dim (positive) 2487 test(op, (torch.rand(B0, 2, 5), 2), in_dims=(0, None)) 2488 test(op, (torch.rand(2, B0, 5), 2), in_dims=(1, None)) 2489 2490 # unsqueeze last dim (negative) 2491 test(op, (torch.rand(B0, 2, 5), -1), in_dims=(0, None)) 2492 test(op, (torch.rand(2, B0, 5), -1), in_dims=(1, None)) 2493 2494 # nested vmaps 2495 def unsqueeze_0(x): 2496 return torch.unsqueeze(x, 0) 2497 2498 def unsqueeze_last(x): 2499 return torch.unsqueeze(x, -1) 2500 2501 # bdims in canonical order 2502 test(vmap(unsqueeze_0), (torch.rand(B0, B1, 2),)) 2503 test(vmap(unsqueeze_last), (torch.rand(B0, B1, 2),)) 2504 2505 # wild bdims 2506 test(vmap(unsqueeze_0), (torch.rand(B1, 2, B0),), in_dims=2) 2507 test(vmap(unsqueeze_0, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2) 2508 test(vmap(unsqueeze_last), (torch.rand(B1, 2, B0),), in_dims=2) 2509 test(vmap(unsqueeze_last, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2) 2510 2511 def test_movedim(self): 2512 op = torch.movedim 2513 test = self._vmap_view_test 2514 B0, B1, B2 = 7, 11, 13 2515 2516 # movedim(tensor, int, int) variant 2517 test(op, (torch.rand(B0, 2, 5), 0, 1), in_dims=(0, None, None)) 2518 test(op, (torch.rand(2, B0, 5), 0, 1), in_dims=(1, None, None)) 2519 test( 2520 vmap(op, in_dims=(0, None, None)), 2521 (torch.rand(B1, 2, B0, 5), 0, 1), 2522 in_dims=(2, None, None), 2523 ) 2524 test( 2525 vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)), 2526 (torch.rand(B1, 2, B0, 5, B2), 0, 1), 2527 in_dims=(2, None, None), 2528 ) 2529 2530 # movedim(tensor, intlist, intlist) variant 2531 test(op, (torch.rand(B0, 2, 3, 5), [1, 0], [0, 2]), in_dims=(0, None, None)) 2532 test(op, (torch.rand(2, 3, B0, 5), [1, 0], [0, 2]), in_dims=(1, None, None)) 2533 test( 2534 vmap(op, in_dims=(0, None, None)), 2535 (torch.rand(B1, 2, B0, 5), [0, 1], [1, 0]), 2536 in_dims=(2, None, None), 2537 ) 2538 test( 2539 vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)), 2540 (torch.rand(B1, 2, B0, 5, B2), [0, 1], [1, 0]), 2541 in_dims=(2, None, None), 2542 ) 2543 2544 def test_mm(self): 2545 op = torch.mm 2546 test = self._vmap_test 2547 B0, B1 = 7, 11 2548 2549 # shape mismatch 2550 msg = "Shape mismatch" 2551 with self.assertRaisesRegex(RuntimeError, msg): 2552 vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2)) 2553 with self.assertRaisesRegex(RuntimeError, msg): 2554 vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2)) 2555 with self.assertRaisesRegex(RuntimeError, msg): 2556 vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2)) 2557 2558 # left arg is vmapped 2559 test(op, (torch.rand(B0, 2, 5), torch.rand(5, 2)), in_dims=(0, None)) 2560 test( 2561 vmap(op, in_dims=(0, None)), 2562 (torch.rand(B1, B0, 2, 5), torch.rand(5, 2)), 2563 in_dims=(1, None), 2564 ) 2565 2566 # right arg is vmapped 2567 test(op, (torch.rand(2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0)) 2568 test( 2569 vmap(op, in_dims=(None, 0)), 2570 (torch.rand(2, 5), torch.rand(B1, B0, 5, 2)), 2571 in_dims=(None, 1), 2572 ) 2573 2574 # both args are vmapped 2575 test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5, 2))) 2576 test( 2577 vmap(op), 2578 (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5, 2)), 2579 in_dims=(1, 0), 2580 ) 2581 test( 2582 vmap(op, in_dims=(0, None)), 2583 (torch.rand(B1, 2, 5), torch.rand(B0, 5, 2)), 2584 in_dims=(None, 0), 2585 ) 2586 2587 def test_mv(self): 2588 op = torch.mv 2589 test = self._vmap_test 2590 B0, B1 = 7, 11 2591 2592 # shape mismatch 2593 msg = "" 2594 with 
self.assertRaisesRegex(RuntimeError, msg): 2595 vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2)) 2596 with self.assertRaisesRegex(RuntimeError, msg): 2597 vmap(op, in_dims=(0, None))(torch.randn(B0, 2, 2), torch.randn(2, 2)) 2598 with self.assertRaisesRegex(RuntimeError, msg): 2599 vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2)) 2600 2601 # left arg is vmapped 2602 test(op, (torch.rand(B0, 2, 5), torch.rand(5)), in_dims=(0, None)) 2603 test( 2604 vmap(op, in_dims=(0, None)), 2605 (torch.rand(B1, B0, 2, 5), torch.rand(5)), 2606 in_dims=(1, None), 2607 ) 2608 2609 # right arg is vmapped 2610 test(op, (torch.rand(2, 5), torch.rand(B0, 5)), in_dims=(None, 0)) 2611 test( 2612 vmap(op, in_dims=(None, 0)), 2613 (torch.rand(2, 5), torch.rand(B1, B0, 5)), 2614 in_dims=(None, 1), 2615 ) 2616 2617 # both args are vmapped 2618 test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5))) 2619 test( 2620 vmap(op), (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0) 2621 ) 2622 test( 2623 vmap(op, in_dims=(0, None)), 2624 (torch.rand(B1, 2, 5), torch.rand(B0, 5)), 2625 in_dims=(None, 0), 2626 ) 2627 2628 def test_narrow(self): 2629 op = torch.narrow 2630 test = self._vmap_view_test 2631 B0, B1, B2 = 7, 11, 13 2632 2633 test(op, (torch.rand(B0, 2, 5), -1, 1, 3), in_dims=(0, None, None, None)) 2634 test(op, (torch.rand(2, B0, 5), 1, 1, 3), in_dims=(1, None, None, None)) 2635 test( 2636 vmap(op, in_dims=(0, None, None, None)), 2637 (torch.rand(B1, 2, B0, 5), 1, 0, 0), 2638 in_dims=(2, None, None, None), 2639 ) 2640 test( 2641 vmap( 2642 vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None) 2643 ), 2644 (torch.rand(B1, 2, B0, 5, B2), -1, 2, 3), 2645 in_dims=(2, None, None, None), 2646 ) 2647 2648 def test_new_empty(self): 2649 # Empty is non-deterministic so we just check that the shape of the 2650 # output tensor is what we expect and that the vmap fallback isn't used. 
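        # (Illustrative note) Under vmap, x.new_empty([2, 3]) should produce a
        # physical tensor of shape [B0, 2, 3] -- one uninitialized [2, 3]
        # buffer per example -- so only the metadata is compared below; the
        # values themselves are arbitrary.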
2651 op = Tensor.new_empty 2652 2653 B0, B1 = 7, 11 2654 2655 result = vmap(lambda x: op(x, [2, 3]))(torch.randn(B0)) 2656 self.assertEqual(result.shape, [B0, 2, 3]) 2657 2658 result = vmap(lambda x: op(x, []))(torch.randn(B0)) 2659 self.assertEqual(result.shape, [B0]) 2660 2661 result = vmap(vmap(lambda x: op(x, [2, 3])))(torch.randn(B0, B1)) 2662 self.assertEqual(result.shape, [B0, B1, 2, 3]) 2663 2664 def test_new_empty_strided(self): 2665 # Empty is non-deterministic so we just check that the size and shape 2666 # of the output are what we expect and that the vmap fallback isn't used 2667 B0, B1 = 7, 11 2668 2669 def _test_single_vmap(size, stride, B0): 2670 x = torch.randn(B0) 2671 result = vmap(lambda x: x.new_empty_strided(size, stride))(x) 2672 S = torch.empty_strided(size, stride).storage().size() 2673 self.assertEqual(result.shape, [B0] + size) 2674 self.assertEqual(result.stride(), [S] + stride) 2675 2676 def _test_double_vmap(size, stride, B0, B1): 2677 x = torch.randn(B0, B1) 2678 result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)))(x) 2679 S = torch.empty_strided(size, stride).storage().size() 2680 self.assertEqual(result.shape, [B0, B1] + size) 2681 self.assertEqual(result.stride(), [B1 * S, S] + stride) 2682 2683 x = torch.randn(B1, B0) 2684 result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)), in_dims=1)( 2685 x 2686 ) 2687 S = x.new_empty_strided(size, stride).storage().size() 2688 self.assertEqual(result.shape, [B0, B1] + size) 2689 self.assertEqual(result.stride(), [B1 * S, S] + stride) 2690 2691 # contiguous case 2692 _test_single_vmap([2, 3, 5], [3 * 5, 5, 1], B0) 2693 _test_double_vmap([2, 3, 5], [3 * 5, 5, 1], B0, B1) 2694 2695 # expanded 2696 _test_single_vmap([2, 3, 5], [0, 5, 1], B0) 2697 _test_double_vmap([2, 3, 5], [0, 5, 1], B0, B1) 2698 2699 # some of these cases are pretty strange, just verifying that if 2700 # empty_strided allows them then BatchedTensor.new_empty_strided 2701 # can as well 2702 for shape in [[2, 3, 4], [0, 2, 0]]: 2703 for strides in [[12, 4, 1], [2, 4, 6], [0, 0, 0]]: 2704 _test_single_vmap(shape, strides, B0) 2705 _test_double_vmap(shape, strides, B0, B1) 2706 2707 def test_new_zeros(self): 2708 op = Tensor.new_zeros 2709 test = functools.partial(self._vmap_test, check_propagates_grad=False) 2710 B0, B1 = 7, 11 2711 2712 test(lambda x: op(x, 2, 3), (torch.rand(B0),)) 2713 test(lambda x: op(x, []), (torch.rand(B0),)) 2714 test(vmap(lambda x: op(x, 3, 5)), (torch.rand(B0, B1),)) 2715 2716 def test_select(self): 2717 op = torch.select 2718 test = self._vmap_view_test 2719 B0, B1, B2 = 7, 11, 13 2720 test(op, (torch.rand(B0, 2, 5), 0, 0), in_dims=(0, None, None)) 2721 test(op, (torch.rand(2, B0, 5), 1, 1), in_dims=(1, None, None)) 2722 test(vmap(lambda t: op(t, 1, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2) 2723 test( 2724 vmap(vmap(lambda t: op(t, 1, 1), in_dims=1)), 2725 (torch.rand(B1, 2, B0, B2, 5),), 2726 in_dims=2, 2727 ) 2728 2729 def test_roll_no_dims(self): 2730 op = torch.roll 2731 test = self._vmap_test 2732 B0, B1, B2 = 7, 11, 13 2733 test(op, (torch.rand(B0, 2, 5), 2), in_dims=(0, None)) 2734 test(op, (torch.rand(2, B0, 5), 3), in_dims=(1, None)) 2735 test(vmap(lambda t: op(t, 3)), (torch.rand(B1, 2, B0, 5),), in_dims=2) 2736 test( 2737 vmap(vmap(lambda t: op(t, 3), in_dims=1)), 2738 (torch.rand(B1, 2, B0, B2, 5),), 2739 in_dims=2, 2740 ) 2741 2742 def test_stack(self): 2743 test = self._vmap_test 2744 B0, B1 = 5, 7 2745 2746 # Quick hack b/c vmap can't accept a list of tensors as an argument 2747 def 
get_op(dim): 2748 def op(*tensors): 2749 return torch.stack(tensors, dim=dim) 2750 2751 return op 2752 2753 test(get_op(0), (torch.rand(B0, 3), torch.rand(B0, 3))) 2754 test(get_op(0), (torch.rand(3), torch.rand(B0, 3)), in_dims=(None, 0)) 2755 test(get_op(0), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2)) 2756 test(get_op(-1), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2)) 2757 test( 2758 vmap(get_op(0), in_dims=(0, None)), 2759 (torch.rand(B1, 2), torch.rand(B0, 2)), 2760 in_dims=(None, 0), 2761 ) 2762 test( 2763 vmap(get_op(0), in_dims=(0, 0)), 2764 (torch.rand(B1, 2), torch.rand(B0, B1, 2)), 2765 in_dims=(None, 0), 2766 ) 2767 2768 def test_slice(self): 2769 test = self._vmap_view_test 2770 B0, B1, B2 = 7, 11, 13 2771 test(lambda t: t[0:1], (torch.rand(B0, 3, 5),)) 2772 test(lambda t: t[:, 1:3], (torch.rand(3, 5, B0),), in_dims=2) 2773 test( 2774 vmap(lambda t: t[:, 0:1], in_dims=2), (torch.rand(3, 5, B0, B1),), in_dims=2 2775 ) 2776 test( 2777 vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2), 2778 (torch.rand(3, 5, B0, B1, B2),), 2779 in_dims=2, 2780 ) 2781 2782 @xfailIfTorchDynamo 2783 def test_squeeze(self): 2784 def verify_behavior(op, min_ndim=1): 2785 test = self._vmap_view_test 2786 B0, B1 = 1, 11 2787 # These tests cannot be used with an operator that requires more 2788 # than 1 dimension after batching. 2789 if min_ndim <= 1: 2790 test(op, (torch.rand(B0),)) 2791 test(op, (torch.rand(B1),)) 2792 test(vmap(op), (torch.rand(B0, B1, 1),)) 2793 test(vmap(op), (torch.rand(B1, 1, B0),), in_dims=2) 2794 test(op, (torch.rand(B0, 3, 5),)) 2795 test(op, (torch.rand(1, B0, 5),), in_dims=1) 2796 test(op, (torch.rand(B0, 0, 1, 5, 1),)) 2797 test(op, (torch.rand(B0, 1, 1, 1, 1),)) 2798 test(vmap(op), (torch.rand(B0, B1, 1, 3, 4),)) 2799 test(vmap(op), (torch.rand(B1, 1, B0, 4, 5),), in_dims=2) 2800 2801 verify_behavior(torch.squeeze) 2802 verify_behavior(lambda x: torch.squeeze(x, dim=0), min_ndim=1) 2803 verify_behavior(lambda x: torch.squeeze(x, dim=1), min_ndim=2) 2804 verify_behavior(lambda x: torch.squeeze(x, dim=-1), min_ndim=2) 2805 verify_behavior(lambda x: torch.squeeze(x, dim=-2), min_ndim=3) 2806 2807 msg = "" 2808 try: 2809 torch.squeeze(torch.rand(10), dim=1) 2810 except IndexError as err: 2811 msg = str(err) 2812 with self.assertRaises(RuntimeError, msg=msg): 2813 vmap(lambda x: torch.squeeze(x, dim=1))(torch.rand(10)) 2814 2815 def _test_mean_sum_dim(self, op): 2816 test = self._vmap_test 2817 B0, B1 = 5, 7 2818 2819 # Single vmap, various in_dims / out_dims 2820 test(lambda x: op(x, 0), [torch.randn([B0])]) 2821 test(lambda x: op(x, -1), [torch.randn([B0])]) 2822 test(lambda x: op(x, 0), [torch.randn([B0, 3])]) 2823 test(lambda x: op(x, -1), [torch.randn([2, 5, B0, 3])], in_dims=2) 2824 test(lambda x: op(x, 2), [torch.randn([2, 5, B0, 3])], in_dims=2, out_dims=2) 2825 2826 # Doubly nested vmap 2827 test(vmap(lambda x: op(x, 0)), [torch.randn([B0, B1])]) 2828 test(vmap(lambda x: op(x, -1)), [torch.randn([B0, B1])]) 2829 test(vmap(lambda x: op(x, -2)), [torch.randn([B1, 2, 5, B0, 3])], in_dims=2) 2830 test( 2831 vmap(lambda x: op(x, 2), in_dims=2), 2832 [torch.randn([2, 5, B0, B1, 3])], 2833 in_dims=2, 2834 out_dims=2, 2835 ) 2836 2837 def test_sum_dim(self): 2838 self._test_mean_sum_dim(torch.sum) 2839 2840 def test_mean_dim(self): 2841 self._test_mean_sum_dim(torch.mean) 2842 2843 def test_argmax_dim(self): 2844 def test(f, args): 2845 for loop_out, batched_out in get_fallback_and_vmap_exhaustive(f, args, {}): 2846 self.assertEqual(loop_out, 
batched_out) 2847 2848 B0 = 5 2849 test(lambda x: torch.argmax(x), [torch.randn(B0)]) 2850 test(lambda x: torch.argmax(x), [torch.randn(B0, 2, 3)]) 2851 test(lambda x: torch.argmax(x, 0), [torch.randn(B0, 2, 3)]) 2852 test(lambda x: torch.argmax(x, -1), [torch.randn(B0, 2, 3)]) 2853 test(lambda x: torch.argmax(x, 2), [torch.randn(B0, 2, 3)]) 2854 2855 def _test_sum_mean(self, op): 2856 test = self._vmap_test 2857 B0, B1 = 5, 7 2858 2859 # Single vmap, various in_dims / out_dims 2860 test(op, [torch.randn([B0])]) 2861 test(op, [torch.randn([B0, 3])]) 2862 test(op, [torch.randn([2, 5, B0, 3])], in_dims=2) 2863 test(op, [torch.randn([2, 5, B0, 3])], in_dims=2) 2864 2865 # Doubly nested vmap 2866 test(vmap(op), [torch.randn([B0, B1])]) 2867 test(vmap(op), [torch.randn([B1, 2, 5, B0, 3])]) 2868 test(vmap(op), [torch.randn([2, 5, B0, B1, 3])], in_dims=2) 2869 2870 def test_sum(self): 2871 self._test_sum_mean(torch.sum) 2872 2873 def test_mean(self): 2874 self._test_sum_mean(torch.mean) 2875 2876 def test_repeat(self): 2877 test = self._vmap_test 2878 B0 = 7 2879 op = Tensor.repeat 2880 test(lambda x: op(x, (2, 3)), (torch.rand(B0, 1, 1),)) 2881 test(lambda x: op(x, (2, 3)), (torch.rand(1, B0, 1),), in_dims=1) 2882 2883 @skipIfTorchDynamo() 2884 def test_slogdet(self): 2885 test = functools.partial(self._vmap_test, check_propagates_grad=False) 2886 B0 = 7 2887 op = torch.linalg.slogdet 2888 test(op, (torch.rand(B0, 1, 1),)) 2889 test(op, (torch.rand(B0, 2, 2),)) 2890 test(op, (torch.rand(B0, 3, 2, 2),)) 2891 test(op, (torch.rand(3, 2, 2, B0),), in_dims=3) 2892 2893 def test_reshape(self): 2894 test = self._vmap_test 2895 B0, B1, B2 = 7, 11, 13 2896 op = torch.reshape 2897 test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None), check_view=True) 2898 test( 2899 op, (torch.rand(2, B0, 5), [1, 1, 10]), in_dims=(1, None), check_view=False 2900 ) 2901 test( 2902 vmap(lambda t: t.reshape([-1])), 2903 (torch.rand(B0, B1, 2, 5),), 2904 check_view=True, 2905 ) 2906 test( 2907 vmap(vmap(lambda t: t.reshape([-1]), in_dims=2), in_dims=1), 2908 (torch.rand(3, B1, 2, B2, 5, B0),), 2909 in_dims=5, 2910 check_view=False, 2911 ) 2912 2913 def test_reshape_as(self): 2914 test = self._vmap_test 2915 B0, B1, B2 = 7, 11, 13 2916 op = torch.Tensor.reshape_as 2917 test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)), check_view=True) 2918 test( 2919 op, 2920 (torch.rand(2 * 5), torch.rand(B0, 2, 5)), 2921 in_dims=(None, 0), 2922 check_view=True, 2923 ) 2924 test( 2925 op, 2926 (torch.rand(B0, 2 * 5), torch.rand(2, 5)), 2927 in_dims=(0, None), 2928 check_view=True, 2929 ) 2930 2931 test( 2932 op, 2933 (torch.rand(2, B0, 5), torch.rand(1, 1, 10)), 2934 in_dims=(1, None), 2935 check_view=False, 2936 ) 2937 2938 test( 2939 vmap(op), 2940 (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)), 2941 check_view=True, 2942 ) 2943 test( 2944 vmap(vmap(op, in_dims=(2, None)), in_dims=(1, None)), 2945 (torch.rand(3, B1, 2, B2, 5, B0), torch.rand(B0, 3 * 2 * 5)), 2946 in_dims=(5, 0), 2947 check_view=False, 2948 ) 2949 2950 def test_result_type(self): 2951 def scalar_tensor_with_dtype(op): 2952 def wrapped(*args, **kwargs): 2953 dtype = op(*args, **kwargs) 2954 return torch.ones([], dtype=dtype) 2955 2956 return wrapped 2957 2958 test = self._vmap_test 2959 op = scalar_tensor_with_dtype(torch.result_type) 2960 2961 B0 = 2 2962 2963 test( 2964 op, 2965 (torch.randn(B0), torch.randn(B0, dtype=torch.float64)), 2966 check_propagates_grad=False, 2967 ) 2968 test( 2969 op, 2970 (torch.randn(B0), torch.randint(10, [B0], 
dtype=torch.int64)), 2971 check_propagates_grad=False, 2972 ) 2973 2974 test(lambda x: op(x, 1), (torch.randn(B0),), check_propagates_grad=False) 2975 test(lambda x: op(x, 1.6), (torch.randn(B0),), check_propagates_grad=False) 2976 2977 test( 2978 lambda x: op(x, torch.tensor(1)), 2979 (torch.randn(B0),), 2980 check_propagates_grad=False, 2981 ) 2982 test( 2983 lambda x: op(x, torch.tensor(1.6, dtype=torch.double)), 2984 (torch.randn(B0),), 2985 check_propagates_grad=False, 2986 ) 2987 2988 test( 2989 op, 2990 (torch.randn(B0, 2), torch.randn(B0, 2, dtype=torch.float64)), 2991 check_propagates_grad=False, 2992 ) 2993 test( 2994 op, 2995 (torch.randn(B0, 2), torch.randint(10, [B0, 2], dtype=torch.int64)), 2996 check_propagates_grad=False, 2997 ) 2998 2999 test(lambda x: op(x, 1), (torch.randn(B0, 2),), check_propagates_grad=False) 3000 test(lambda x: op(x, 1.6), (torch.randn(B0, 2),), check_propagates_grad=False) 3001 3002 test( 3003 lambda x: op(x, torch.tensor(1)), 3004 (torch.randn(B0, 2),), 3005 check_propagates_grad=False, 3006 ) 3007 test( 3008 lambda x: op(x, torch.tensor(1.6, dtype=torch.double)), 3009 (torch.randn(B0, 2),), 3010 check_propagates_grad=False, 3011 ) 3012 3013 test( 3014 op, 3015 (torch.randn(B0, 2), torch.randn(B0, dtype=torch.float64)), 3016 check_propagates_grad=False, 3017 ) 3018 test( 3019 op, 3020 (torch.randn(B0, 2), torch.randint(10, [B0], dtype=torch.int64)), 3021 check_propagates_grad=False, 3022 ) 3023 3024 def test_tensor_split(self): 3025 test = self._vmap_view_test 3026 op = torch.tensor_split 3027 B0, B1, B2 = 7, 11, 13 3028 3029 # tests for torch.tensor_split(self, indices_or_sections: int, dim) 3030 test(op, (torch.rand(B0, 2, 1024), 5, -1), in_dims=(0, None, None)) 3031 test(op, (torch.rand(2, B0, 1024), 150, 1), in_dims=(1, None, None)) 3032 test( 3033 vmap(op, in_dims=(0, None, None)), 3034 (torch.rand(B1, 1023, B0, 5), 256, 0), 3035 in_dims=(2, None, None), 3036 ) 3037 test( 3038 vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)), 3039 (torch.rand(B1, 2, B0, 64, B2),), 3040 in_dims=2, 3041 ) 3042 3043 # tests for torch.tensor_split(self, indices_or_sections: List[int], dim) 3044 test( 3045 op, 3046 (torch.rand(B0, 2, 1024), [50, 100, 378, 890], -1), 3047 in_dims=(0, None, None), 3048 ) 3049 test( 3050 op, 3051 (torch.rand(2, B0, 1024), [50, 100, 212, 345, 0, 378, 890], 1), 3052 in_dims=(1, None, None), 3053 ) 3054 test( 3055 vmap(op, in_dims=(0, None, None)), 3056 (torch.rand(B1, 1023, B0, 5), [50, 100, 212, 345, 0, 378, 890], 0), 3057 in_dims=(2, None, None), 3058 ) 3059 test( 3060 vmap(vmap(lambda t: op(t, [4, 8, 9, 34, 29], 1), in_dims=2)), 3061 (torch.rand(B1, 2, B0, 64, B2),), 3062 in_dims=2, 3063 ) 3064 3065 @skipIfTorchDynamo("really slow") 3066 def test_split(self): 3067 test = self._vmap_view_test 3068 op = torch.split 3069 B0, B1, B2 = 7, 11, 13 3070 3071 # tests for torch.split(self, split_size: int, dim) 3072 test(op, (torch.rand(B0, 2, 1024), 101, -1), in_dims=(0, None, None)) 3073 test(op, (torch.rand(2, B0, 1024), 130, 1), in_dims=(1, None, None)) 3074 test( 3075 vmap(op, in_dims=(0, None, None)), 3076 (torch.rand(B1, 1023, B0, 5), 256, 0), 3077 in_dims=(2, None, None), 3078 ) 3079 test( 3080 vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)), 3081 (torch.rand(B1, 2, B0, 64, B2),), 3082 in_dims=2, 3083 ) 3084 3085 # tests for torch.split(self, split_size: List[int], dim) 3086 test(op, (torch.rand(B0, 2, 1024), [1, 1020, 3], -1), in_dims=(0, None, None)) 3087 test( 3088 op, (torch.rand(2, B0, 1024), [100] * 10 + [24], 1), in_dims=(1, None, None) 
3089 ) 3090 test( 3091 vmap(op, in_dims=(0, None, None)), 3092 (torch.rand(B1, 1023, B0, 5), [256] * 3 + [255], 0), 3093 in_dims=(2, None, None), 3094 ) 3095 test( 3096 vmap(vmap(lambda t: op(t, [4] * 8 + [8] * 4, 1), in_dims=2)), 3097 (torch.rand(B1, 2, B0, 64, B2),), 3098 in_dims=2, 3099 ) 3100 3101 def test_trace(self): 3102 op = torch.trace 3103 test = self._vmap_test 3104 B0, B1, B2 = 7, 11, 13 3105 test(op, (torch.rand(B0, 2, 5),)) 3106 test(op, (torch.rand(2, B0, 5),), in_dims=1) 3107 test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2) 3108 test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2) 3109 3110 def test_transpose(self): 3111 op = torch.transpose 3112 test = self._vmap_view_test 3113 3114 B0, B1, B2 = 7, 11, 13 3115 test(lambda x: op(x, 0, 1), (torch.rand(B0, 2, 5),)) 3116 test(lambda x: op(x, -1, -2), (torch.rand(B0, 2, 5),)) 3117 test(lambda x: op(x, 3, 1), (torch.rand(B0, 2, 5, 4, 6),)) 3118 test(lambda x: op(x, 1, 0), (torch.rand(2, B0, 5),), in_dims=1) 3119 test(vmap(lambda x: op(x, 0, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2) 3120 test( 3121 vmap(vmap(lambda x: op(x, 0, 1), in_dims=2)), 3122 (torch.rand(B1, 2, B0, 5, B2),), 3123 in_dims=2, 3124 ) 3125 3126 # Special case: scalar tensor 3127 for dim1, dim2 in itertools.product([0, -1], [0, -1]): 3128 x = torch.rand(B0) 3129 result = vmap(lambda x: op(x, dim1, dim2))(x) 3130 self.assertTrue(result is x) 3131 3132 def test_t(self): 3133 op = torch.t 3134 test = self._vmap_view_test 3135 B0, B1, B2 = 7, 11, 13 3136 test(op, (torch.rand(B0, 2, 5),)) 3137 test(op, (torch.rand(2, B0, 5),), in_dims=1) 3138 test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2) 3139 test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2) 3140 3141 def test_T_numpy(self): 3142 def op(t): 3143 return t.T 3144 3145 test = self._vmap_view_test 3146 B0, B1, B2 = 7, 11, 13 3147 test(op, (torch.rand(B0, 2, 3, 5),)) 3148 test(op, (torch.rand(2, B0, 3, 5),), in_dims=1) 3149 test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2) 3150 test(vmap(op), (torch.rand(B1, 2, B0, 3, 5),), in_dims=2) 3151 test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 3, B2, 5),), in_dims=2) 3152 3153 def test_to(self): 3154 test = self._vmap_test 3155 B0, B1 = 7, 11 3156 3157 test(lambda t: t.to("cpu"), (torch.rand(B0),)) 3158 test(lambda t: t.to(torch.double), (torch.rand(B0),)) 3159 test( 3160 lambda t, o: t.to(o), (torch.rand(B0), torch.randn(B0, dtype=torch.float64)) 3161 ) 3162 test( 3163 lambda t, o: t.to(o), 3164 (torch.rand(B0), torch.randn(B0, dtype=torch.float64)), 3165 in_dims=(0, None), 3166 ) 3167 test(vmap(lambda t: t.to(torch.double)), (torch.rand(B0, B1, 3),)) 3168 3169 # also test some casting methods 3170 test(lambda t: t.double(), (torch.rand(B0),)) 3171 test(lambda t: t.float(), (torch.rand(B0),)) 3172 test(lambda t: t.int(), (torch.rand(B0),), check_propagates_grad=False) 3173 test(lambda t: t.long(), (torch.rand(B0),), check_propagates_grad=False) 3174 3175 def test_unfold(self): 3176 op = torch.Tensor.unfold 3177 test = self._vmap_view_test 3178 B0, B1, B2 = 3, 2, 5 3179 3180 test(op, (torch.rand(B0, 7, 11), 0, 2, 1), in_dims=(0, None, None, None)) 3181 test(op, (torch.rand(7, B0, 11), 1, 4, 2), in_dims=(1, None, None, None)) 3182 test( 3183 vmap(op, in_dims=(0, None, None, None)), 3184 (torch.rand(B1, 7, B0, 11), 1, 5, 1), 3185 in_dims=(2, None, None, None), 3186 ) 3187 test( 3188 vmap( 3189 vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None) 3190 ), 3191 (torch.rand(B1, 7, B0, 11, 
B2), -1, 2, 4), 3192 in_dims=(2, None, None, None), 3193 ) 3194 3195 def test_unbind(self): 3196 test = self._vmap_view_test 3197 op = torch.unbind 3198 B0, B1, B2 = 7, 11, 13 3199 3200 test(op, (torch.rand(B0, 2, 1024), -1), in_dims=(0, None)) 3201 test(op, (torch.rand(B0, 2, 0),)) 3202 test(op, (torch.rand(2, B0, 7), 0), in_dims=(1, None)) 3203 test( 3204 vmap(op, in_dims=(0, None)), 3205 (torch.rand(B1, 1023, B0, 5), 1), 3206 in_dims=(2, None), 3207 ) 3208 test( 3209 vmap(vmap(lambda t: op(t, dim=1), in_dims=2)), 3210 (torch.rand(B1, 2, B0, 32, B2),), 3211 in_dims=2, 3212 ) 3213 3214 def test_view(self): 3215 test = self._vmap_view_test 3216 B0, B1, B2 = 7, 11, 13 3217 op = torch.Tensor.view 3218 3219 # We should error out if the view would produce an incorrect result 3220 with self.assertRaises(RuntimeError): 3221 vmap(op, in_dims=(1, None))(torch.rand(2, B0, 5), [10]) 3222 3223 test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None)) 3224 test(op, (torch.rand(B0, 4, 5), [1, 2, 1, 10]), in_dims=(0, None)) 3225 test(vmap(lambda t: t.view([-1])), (torch.rand(B0, B1, 2, 5, 3),)) 3226 test( 3227 vmap(vmap(lambda t: t.reshape([-1])), in_dims=1), 3228 (torch.rand(B2, B0, B1, 3, 2, 5),), 3229 in_dims=1, 3230 ) 3231 3232 def test_view_as(self): 3233 test = self._vmap_view_test 3234 B0, B1, B2 = 7, 11, 13 3235 op = torch.Tensor.view_as 3236 3237 # We should error out if the view would produce an incorrect result 3238 with self.assertRaises(RuntimeError): 3239 vmap(op, in_dims=(1, 0))(torch.rand(2, B0, 5), torch.rand(B0, 10)) 3240 3241 test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5))) 3242 test(op, (torch.rand(2 * 5), torch.rand(B0, 2, 5)), in_dims=(None, 0)) 3243 test(op, (torch.rand(B0, 2 * 5), torch.rand(2, 5)), in_dims=(0, None)) 3244 3245 test(op, (torch.rand(B0, 4, 5), torch.rand(2, 1, 1, 10)), in_dims=(0, None)) 3246 3247 test(vmap(op), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10))) 3248 test( 3249 vmap(vmap(op, in_dims=(0, None)), in_dims=(0, None)), 3250 (torch.rand(B1, B2, B0, 3, 2, 5), torch.rand(B0, 3 * 2 * 5)), 3251 in_dims=(2, 0), 3252 ) 3253 3254 def test_conv2d(self): 3255 conv_setups = [ 3256 (torch.nn.Conv1d, torch.conv1d, [2, 4, 15]), 3257 (torch.nn.Conv2d, torch.conv2d, [2, 4, 15, 20]), 3258 (torch.nn.Conv3d, torch.conv3d, [2, 4, 15, 20, 25]), 3259 # (torch.nn.ConvTranspose2d, torch.conv_transpose2d, [2, 4, 15, 20]) 3260 ] 3261 for conv_mod, conv_fn, inp_shape in conv_setups: 3262 mod = conv_mod(4, 8, kernel_size=3) 3263 arg_values = [torch.randn(inp_shape), mod.weight, mod.bias] 3264 kwarg_values = {} 3265 for loop_out, batched_out in get_fallback_and_vmap_exhaustive( 3266 conv_fn, arg_values, kwarg_values 3267 ): 3268 self.assertEqual(loop_out, batched_out) 3269 3270 arg_values = [torch.randn(inp_shape), mod.weight, None] 3271 for loop_out, batched_out in get_fallback_and_vmap_exhaustive( 3272 conv_fn, arg_values, kwarg_values 3273 ): 3274 self.assertEqual(loop_out, batched_out) 3275 3276 mod2 = conv_mod( 3277 4, 8, kernel_size=3, groups=2, stride=3, padding=1, dilation=2 3278 ) 3279 arg_values = [torch.randn(inp_shape), mod2.weight, mod2.bias] 3280 kwarg_values = dict(groups=2, stride=3, padding=1, dilation=2) 3281 for loop_out, batched_out in get_fallback_and_vmap_exhaustive( 3282 conv_fn, arg_values, kwarg_values 3283 ): 3284 self.assertEqual(loop_out, batched_out) 3285 3286 arg_values = [torch.randn(inp_shape), mod2.weight, None] 3287 for loop_out, batched_out in get_fallback_and_vmap_exhaustive( 3288 conv_fn, arg_values, kwarg_values 3289 ): 3290 
self.assertEqual(loop_out, batched_out) 3291 3292 def test_one_hot(self): 3293 sample_inputs = [ 3294 (torch.randint(0, 3, []), 3), 3295 (torch.randint(0, 3, [2, 3, 4]), 4), 3296 ] 3297 for args in sample_inputs: 3298 for loop_out, batched_out in get_fallback_and_vmap_exhaustive( 3299 F.one_hot, args, {} 3300 ): 3301 self.assertEqual(loop_out, batched_out) 3302 3303 def test_conj_bit(self): 3304 x = torch.tensor([1 + 1j, 2 + 1j]) 3305 3306 def foo(x): 3307 assert not x.is_conj() 3308 y = x.conj() 3309 assert y.is_conj() 3310 return y 3311 3312 res = vmap(foo)(x) 3313 self.assertEqual(res, x.conj()) 3314 3315 def test_mode_key(self): 3316 def vmap_f(x): 3317 return x + torch.randn(()) 3318 3319 def naive_f(x, shape): 3320 return x + torch.randn(shape) 3321 3322 torch.manual_seed(0) 3323 out1 = vmap(vmap(vmap_f, randomness="different"), randomness="different")( 3324 torch.ones(2, 3) 3325 ) 3326 3327 torch.manual_seed(0) 3328 out2 = naive_f(torch.ones(2, 3), (2, 3)) 3329 self.assertEqual(out1, out2) 3330 3331 torch.manual_seed(0) 3332 out1 = vmap(vmap(vmap_f, randomness="different"), randomness="different")( 3333 torch.ones(2, 3, 4) 3334 ) 3335 3336 torch.manual_seed(0) 3337 out2 = naive_f(torch.ones(2, 3, 4), (2, 3, 1)) 3338 self.assertEqual(out1, out2) 3339 3340 self.assertTrue(torch.randn(()).dim() == 0) 3341 3342 @parametrize("in_dim", [0, 1, 2]) 3343 @parametrize("out_dim", [0, 1, 2]) 3344 @parametrize("randomness", ["error", "same"]) 3345 def test_chunk_vmap(self, in_dim, out_dim, randomness): 3346 x = torch.randn(4, 5, 6) 3347 3348 def f(x): 3349 y = x.sin() 3350 if randomness != "error": 3351 y = y + torch.rand_like(x) 3352 return y 3353 3354 rs = torch.get_rng_state() 3355 expected = vmap(f, in_dims=in_dim, out_dims=out_dim, randomness=randomness)(x) 3356 3357 for chunks in [1, 2, 3, 4, 7, 10, 16]: 3358 torch.set_rng_state(rs) 3359 output = chunk_vmap( 3360 f, 3361 in_dims=in_dim, 3362 out_dims=out_dim, 3363 randomness=randomness, 3364 chunks=chunks, 3365 )(x) 3366 self.assertEqual(output, expected) 3367 3368 @parametrize("in_dim", [0, 1, 2]) 3369 @parametrize("out_dim", [0, 1, 2]) 3370 @parametrize("randomness", ["error", "same"]) 3371 def test_vmap_chunksize(self, in_dim, out_dim, randomness): 3372 x = torch.randn(4, 5, 6) 3373 y = torch.randn_like(x) 3374 3375 # fn: Single Input/Single Output 3376 def f(x): 3377 y = x.sin() 3378 if randomness != "error": 3379 y = y + torch.rand_like(x) 3380 return y 3381 3382 f_args = (x,) 3383 f_kwargs = {"in_dims": in_dim, "out_dims": out_dim, "randomness": randomness} 3384 3385 # fn: Nested Input/Single Output 3386 def f1(pair): 3387 x, y = pair 3388 z = x.sin() + y.cos() 3389 if randomness != "error": 3390 z = z + torch.rand_like(z) 3391 return z 3392 3393 f1_args = ((x, y),) 3394 f1_kwargs = { 3395 "in_dims": ((in_dim,) * 2,), 3396 "out_dims": out_dim, 3397 "randomness": randomness, 3398 } 3399 3400 # fn: Single Input/Nested Output 3401 def f2(x): 3402 y = x.sin() 3403 if randomness != "error": 3404 y = y + torch.rand_like(x) 3405 return {"out": y, "out1": y + 2} 3406 3407 f2_args = (x,) 3408 f2_kwargs = {"in_dims": in_dim, "out_dims": out_dim, "randomness": randomness} 3409 3410 # fn: Nested Input/Nested Output (first tensor is not vmapped). 
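        # (Illustrative note) "not vmapped" here means the "inp" entry is
        # marked with None in in_dims (see f3_kwargs below); its tensor is
        # first reduced to a single example via index_select(...).squeeze(...)
        # so it broadcasts against each vmapped slice of "inp1".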
3411 def f3(inp_dict): 3412 x = inp_dict["inp"] 3413 y = inp_dict["inp1"] 3414 z = x.sin() + y.cos() 3415 if randomness != "error": 3416 z = z + torch.rand_like(z) 3417 return {"z": z, "tuple": (z, z + 1)} 3418 3419 f3_args = ( 3420 { 3421 "inp": x.index_select(in_dim, torch.tensor([0])).squeeze(in_dim), 3422 "inp1": y, 3423 }, 3424 ) 3425 f3_kwargs = { 3426 "in_dims": ({"inp": None, "inp1": in_dim},), 3427 "out_dims": out_dim, 3428 "randomness": randomness, 3429 } 3430 3431 # fn: Nested Input/Nested Output (first argument is not a Tensor). 3432 def f4(inp_dict): 3433 x = inp_dict["inp"] 3434 y = inp_dict["inp1"] 3435 z = x + y.cos() 3436 if randomness != "error": 3437 z = z + torch.rand_like(z) 3438 return {"z": z, "tuple": (z, z + 1)} 3439 3440 f4_args = ({"inp": 2.0, "inp1": y},) 3441 f4_kwargs = { 3442 "in_dims": ({"inp": None, "inp1": in_dim},), 3443 "out_dims": out_dim, 3444 "randomness": randomness, 3445 } 3446 3447 fns_and_args = ( 3448 (f, f_args, f_kwargs), 3449 (f1, f1_args, f1_kwargs), 3450 (f2, f2_args, f2_kwargs), 3451 (f3, f3_args, f3_kwargs), 3452 (f4, f4_args, f4_kwargs), 3453 ) 3454 for fn, args, kwargs in fns_and_args: 3455 rs = torch.get_rng_state() 3456 expected_vmap = vmap(fn, **kwargs)(*args) 3457 for chunk_size in (1, 2, 3, 4, 7, 10, 16, 100): 3458 torch.set_rng_state(rs) 3459 output = vmap(fn, chunk_size=chunk_size, **kwargs)(*args) 3460 self.assertEqual(output, expected_vmap) 3461 3462 @parametrize("in_dim", [0, 1]) 3463 @parametrize("out_dim", [0, 1]) 3464 @parametrize("randomness", ["error", "same"]) 3465 def test_vmap_chunksize_error(self, in_dim, out_dim, randomness): 3466 x = torch.randn(4, 5, 6) 3467 3468 def f(x): 3469 y = x.sin() 3470 if randomness != "error": 3471 y = y + torch.rand_like(x) 3472 return y 3473 3474 # Incorrect `chunk_size` 3475 for chunk_size in (-1, 0): 3476 with self.assertRaisesRegex( 3477 ValueError, "vmap: chunk_size should be None or greater than 0." 
3478 ): 3479 vmap( 3480 f, 3481 in_dims=in_dim, 3482 out_dims=out_dim, 3483 randomness=randomness, 3484 chunk_size=chunk_size, 3485 )(x) 3486 3487 # Incorrect `out_dims` 3488 msg = "out_dims is not compatible with the structure of `outputs`" 3489 with self.assertRaisesRegex(ValueError, msg): 3490 vmap( 3491 f, 3492 in_dims=in_dim, 3493 out_dims=(out_dim, out_dim), 3494 randomness=randomness, 3495 chunk_size=2, 3496 )(x) 3497 3498 @parametrize("in_dim", [0, 1]) 3499 @parametrize("out_dim", [0, 1]) 3500 @parametrize("randomness", ["error", "same"]) 3501 def test_vmap_chunksize_composition(self, in_dim, out_dim, randomness): 3502 x = torch.randn(4, 5, 6) 3503 y = torch.randn_like(x) 3504 3505 # fn: Single Input/Single Output 3506 def f(x): 3507 y = x.sin() 3508 if randomness != "error": 3509 y = y + torch.rand_like(x) 3510 return y 3511 3512 f_args = (x,) 3513 3514 # fn: Nested Input/Single Output 3515 def f1(pair): 3516 x, y = pair 3517 z = x.sin() + y.cos() 3518 if randomness != "error": 3519 z = z + torch.rand_like(z) 3520 return z 3521 3522 f1_args = ((x, y),) 3523 3524 # fn: Single Input/Nested Output 3525 def f2(x): 3526 y = x.sin() 3527 if randomness != "error": 3528 y = y + torch.rand_like(x) 3529 return {"out": y, "out1": y + 2} 3530 3531 f2_args = (x,) 3532 3533 # fn: Nested Input/Nested Output 3534 def f3(inp_dict): 3535 x = inp_dict["inp"] 3536 y = inp_dict["inp1"] 3537 z = x.sin() + y.cos() 3538 if randomness != "error": 3539 z = z + torch.rand_like(z) 3540 return {"z": z, "tuple": (z, z + 1)} 3541 3542 f3_args = ({"inp": x, "inp1": y},) 3543 3544 for fn, args in ((f, f_args), (f1, f1_args), (f2, f2_args), (f3, f3_args)): 3545 rs = torch.get_rng_state() 3546 expected = vmap( 3547 vmap(fn, in_dims=in_dim, out_dims=out_dim, randomness=randomness), 3548 in_dims=in_dim, 3549 out_dims=out_dim, 3550 randomness=randomness, 3551 )(*args) 3552 for chunk_size in (1, 2, 3, 4, 7, 10, 16, 100): 3553 torch.set_rng_state(rs) 3554 actual = vmap( 3555 vmap( 3556 fn, 3557 in_dims=in_dim, 3558 out_dims=out_dim, 3559 randomness=randomness, 3560 chunk_size=chunk_size, 3561 ), 3562 in_dims=in_dim, 3563 out_dims=out_dim, 3564 randomness=randomness, 3565 chunk_size=chunk_size, 3566 )(*args) 3567 self.assertEqual(actual, expected) 3568 3569 3570instantiate_parametrized_tests(TestVmapOperators) 3571 3572 3573def construct_v(output, batch_size, contig=False): 3574 if contig: 3575 return torch.randn( 3576 batch_size, *output.shape, dtype=output.dtype, device=output.device 3577 ) 3578 result = torch.randn( 3579 *output.shape, batch_size, dtype=output.dtype, device=output.device 3580 ) 3581 return result.movedim(-1, 0) 3582 3583 3584def as_tuple(x): 3585 if isinstance(x, tuple): 3586 return x 3587 elif isinstance(x, list): 3588 return tuple(x) 3589 else: 3590 return (x,) 3591 3592 3593def differentiable(args): 3594 return tuple( 3595 arg 3596 for arg in as_tuple(args) 3597 if isinstance(arg, torch.Tensor) and arg.requires_grad 3598 ) 3599 3600 3601def _get_rand_no_zeros(*args, **kwargs): 3602 requires_grad = kwargs.get("requires_grad", False) 3603 kwargs_without_requires_grad = kwargs.copy() 3604 kwargs_without_requires_grad["requires_grad"] = False 3605 result = torch.rand(*args, **kwargs_without_requires_grad) 3606 return result.clamp_min_(0.1).requires_grad_(requires_grad) 3607 3608 3609@markDynamoStrictTest 3610class TestVmapBatchedGradient(Namespace.TestVmapBase): 3611 def _vmap_test(self, *args, **kwargs): 3612 return _vmap_test(self, *args, **kwargs) 3613 3614 # Tests batched gradient computation of 
outputs = op(*args, **kwargs) 3615 # by comparing it to a sequential map+stack fallback. 3616 # 3617 # output_process_fn: a function that maps the outputs to the part 3618 # that should be differentiated. 3619 # batch_size: the batch dim size for the batched grad 3620 def _batched_grad_test( 3621 self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3 3622 ): 3623 if kwargs is None: 3624 kwargs = {} 3625 outputs = op(*args, **kwargs) 3626 outputs = differentiable(output_process_fn(outputs)) 3627 for contig in [True, False]: 3628 batched_vectors = tuple( 3629 construct_v(out, batch_size, contig) for out in outputs 3630 ) 3631 3632 def vector_jacobian_product(*vectors): 3633 return torch.autograd.grad( 3634 outputs, differentiable(args), vectors, retain_graph=True 3635 ) 3636 3637 self._vmap_test( 3638 vector_jacobian_product, batched_vectors, check_propagates_grad=False 3639 ) 3640 3641 # Tests batched second grad computation of outputs = op(*args, **kwargs). 3642 # by comparing it to a sequential map+stack fallback. 3643 # 3644 # output_process_fn: a function that maps the outputs to the part 3645 # that should be differentiated. 3646 # batch_size: the batch dim size for the batched grad 3647 # 3648 # NB: we only test computing batched gradients in the second gradient 3649 # computation. One specific use case that does this is computing the hessian 3650 # matrix of a scalar-valued function; this is useful in Bayesian Logistic 3651 # Regression. 3652 # It might be useful to have a test that computes batched first gradients and 3653 # then uses those to compute batched second gradients in the future. 3654 def _batched_grad_grad_test( 3655 self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3 3656 ): 3657 if kwargs is None: 3658 kwargs = {} 3659 outputs = op(*args, **kwargs) 3660 outputs = differentiable(output_process_fn(outputs)) 3661 ones = tuple(torch.ones_like(out) for out in outputs) 3662 # Same thing as summing together all of the outputs and calling .backward() 3663 first_grads = torch.autograd.grad( 3664 outputs, differentiable(args), ones, create_graph=True 3665 ) 3666 first_grads = differentiable(first_grads) 3667 self.assertNotEqual( 3668 len(first_grads), 0, "None of the first grads depend on the input!" 
3669 ) 3670 3671 for contig in [True, False]: 3672 batched_vectors = tuple( 3673 construct_v(grad, batch_size, contig) for grad in first_grads 3674 ) 3675 3676 def vector_hessian_product(*vectors): 3677 outputs = torch.autograd.grad( 3678 first_grads, 3679 differentiable(args), 3680 vectors, 3681 retain_graph=True, 3682 allow_unused=True, 3683 ) 3684 outputs = tuple(out for out in outputs if out is not None) 3685 assert len(outputs) > 0 3686 return outputs 3687 3688 self._vmap_test( 3689 vector_hessian_product, batched_vectors, check_propagates_grad=False 3690 ) 3691 3692 def _test_arithmetic(self, op, device, test_grad_grad=True): 3693 x = torch.randn(2, 3, requires_grad=True, device=device) 3694 y = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) 3695 scalar = 3.14 3696 self._batched_grad_test(op, (x, y)) 3697 self._batched_grad_test(op, (scalar, y)) 3698 self._batched_grad_test(op, (x, scalar)) 3699 3700 if test_grad_grad: 3701 self._batched_grad_grad_test(op, (x, y)) 3702 3703 def test_add(self, device): 3704 self._test_arithmetic(torch.add, device, test_grad_grad=False) 3705 self._test_arithmetic(lambda x, y: x + y, device, test_grad_grad=False) 3706 3707 def test_sub(self, device): 3708 self._test_arithmetic(torch.sub, device, test_grad_grad=False) 3709 self._test_arithmetic(lambda x, y: x - y, device, test_grad_grad=False) 3710 3711 def test_mul(self, device): 3712 self._test_arithmetic(torch.mul, device) 3713 self._test_arithmetic(lambda x, y: x * y, device) 3714 3715 def test_div(self, device): 3716 self._test_arithmetic(torch.div, device) 3717 self._test_arithmetic(lambda x, y: x / y, device) 3718 3719 def test_binary_cross_entropy(self, device): 3720 x = F.sigmoid(torch.randn(3, 2, device=device, requires_grad=True)) 3721 target = torch.rand(3, 2, device=device) 3722 3723 op = functools.partial(F.binary_cross_entropy, target=target) 3724 3725 self._batched_grad_test(op, (x,), {}) 3726 self._batched_grad_grad_test(op, (x,), {}) 3727 3728 def test_log_softmax(self, device): 3729 op = functools.partial(torch.log_softmax, dim=-1) 3730 x = torch.randn(3, 2, device=device, requires_grad=True) 3731 3732 self._batched_grad_test(op, (x,), {}) 3733 self._batched_grad_grad_test(op, (x,), {}) 3734 3735 def test_expand(self, device): 3736 x = torch.randn(2, 3, device=device, requires_grad=True) 3737 3738 def op(x): 3739 return x.expand(5, 5, 2, 3) 3740 3741 self._batched_grad_test(op, (x,)) 3742 3743 @allowVmapFallbackUsage 3744 def test_index(self, device): 3745 x = torch.randn(2, 3, requires_grad=True, device=device) 3746 index = torch.tensor([[0, 0], [1, 1]], device=device) 3747 3748 def op(x): 3749 y = x * x 3750 return y[index] 3751 3752 self._batched_grad_test(op, (x,)) 3753 self._batched_grad_grad_test(op, (x,)) 3754 3755 def test_lgamma(self, device): 3756 x = torch.randn(2, 3, requires_grad=True, device=device) 3757 self._batched_grad_test(Tensor.lgamma, (x,)) 3758 self._batched_grad_grad_test(Tensor.lgamma, (x,)) 3759 3760 def test_log(self, device): 3761 x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) 3762 self._batched_grad_test(torch.log, (x,)) 3763 self._batched_grad_grad_test(torch.log, (x,)) 3764 3765 def test_logsumexp(self, device): 3766 x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) 3767 3768 def op(x): 3769 return torch.logsumexp(x, -1) 3770 3771 self._batched_grad_test(op, (x,)) 3772 self._batched_grad_grad_test(op, (x,)) 3773 3774 def test_log1p(self, device): 3775 x = _get_rand_no_zeros(2, 3, device=device, 
requires_grad=True) 3776 self._batched_grad_test(torch.log1p, (x,)) 3777 self._batched_grad_grad_test(torch.log1p, (x,)) 3778 3779 @allowVmapFallbackUsage 3780 def test_max(self, device): 3781 x = torch.randn(2, 3, requires_grad=True, device=device) 3782 self._batched_grad_test(torch.max, (x,)) 3783 3784 @allowVmapFallbackUsage 3785 def test_median(self, device): 3786 x = torch.randn(2, 3, requires_grad=True, device=device) 3787 self._batched_grad_test(torch.median, (x,)) 3788 3789 @allowVmapFallbackUsage 3790 def test_min(self, device): 3791 x = torch.randn(2, 3, requires_grad=True, device=device) 3792 self._batched_grad_test(torch.min, (x,)) 3793 3794 def test_permute(self, device): 3795 x = torch.randn(2, 3, 5, requires_grad=True, device=device) 3796 3797 def op(x): 3798 return x.permute(2, 0, 1) 3799 3800 self._batched_grad_test(op, (x,)) 3801 3802 def test_reshape(self, device): 3803 x = torch.randn(2, 3, 5, requires_grad=True, device=device) 3804 3805 def op(x): 3806 return x.reshape([2 * 3, 5]) 3807 3808 self._batched_grad_test(op, (x,)) 3809 3810 def test_sigmoid(self, device): 3811 x = torch.randn(2, 3, requires_grad=True, device=device) 3812 self._batched_grad_test(Tensor.sigmoid, (x,)) 3813 self._batched_grad_grad_test(Tensor.sigmoid, (x,)) 3814 3815 def test_stack(self, device): 3816 x = torch.randn(2, 3, device=device, requires_grad=True) 3817 y = torch.randn(2, 3, device=device, requires_grad=True) 3818 3819 def op(x, y): 3820 return torch.stack([x, y]) 3821 3822 self._batched_grad_test(op, (x, y)) 3823 3824 def test_select(self, device): 3825 x = torch.randn(2, 3, device=device, requires_grad=True) 3826 self._batched_grad_test(lambda x: x[1], (x,)) 3827 self._batched_grad_test(lambda x: x.select(1, 2), (x,)) 3828 self._batched_grad_test(lambda x: x.select(-1, 0), (x,)) 3829 3830 def test_slice(self, device): 3831 x = torch.randn(2, 3, 5, device=device, requires_grad=True) 3832 self._batched_grad_test(lambda x: x[0:1], (x,)) 3833 self._batched_grad_test(lambda x: x[:, 1:3], (x,)) 3834 self._batched_grad_test(lambda x: x[..., 1:3], (x,)) 3835 3836 def test_trace(self, device): 3837 x = torch.randn(2, 3, device=device, requires_grad=True) 3838 self._batched_grad_test(Tensor.trace, (x,)) 3839 3840 x = torch.randn(3, 2, 2, device=device) 3841 3842 def sum_grad_trace(x): 3843 return grad(torch.trace)(x).sum() 3844 3845 output = vmap(grad(sum_grad_trace))(x) 3846 self.assertEqual(output, torch.zeros_like(output)) 3847 3848 def test_where(self, device): 3849 x = torch.randn(3, 2, device=device) 3850 y = torch.ones(3, 2, device=device) 3851 3852 def f(x, y): 3853 return torch.where(x > 0, x, y) 3854 3855 # Check that there is no runtime error, exactness tests are done with opinfo 3856 vmap(f)(x, y) 3857 3858 x = torch.randint(0, 2, size=(4, 3), dtype=torch.float) 3859 3860 def f(t): 3861 return torch.where(t) 3862 3863 with self.assertRaisesRegex( 3864 RuntimeError, r"Attempted to vmap over aten::where" 3865 ): 3866 vmap(f)(x) 3867 3868 def test_threshold(self, device): 3869 x = torch.randn(2, 3, device=device, requires_grad=True) 3870 self._batched_grad_test(lambda x: F.threshold(x, 0.5, 0.0), (x,)) 3871 3872 @parametrize("backend", PLATFORM_SPECIFIC_SDPA) 3873 def test_sdpa(self, device, backend): 3874 if device == "cpu": 3875 raise unittest.SkipTest("This test is only for CUDA for now") 3876 3877 def T(*args): 3878 return torch.randn(*args, dtype=torch.float16, device=device) 3879 3880 backend_ctx = sdpa_kernel([backend]) 3881 with backend_ctx: 3882 for batching in [ 3883 (True, 
True, True), 3884 (True, False, False), 3885 (False, True, True), 3886 ]: 3887 size = [8, 4, 128, 64] 3888 if batching[0]: 3889 query = T(3, *size) 3890 else: 3891 query = T(*size) 3892 if batching[1]: 3893 key = T(3, *size) 3894 else: 3895 key = T(*size) 3896 if batching[2]: 3897 value = T(3, *size) 3898 else: 3899 value = T(*size) 3900 in_dims = tuple(0 if b else None for b in batching) 3901 attention = F.scaled_dot_product_attention 3902 3903 self._vmap_test( 3904 attention, 3905 (query, key, value), 3906 in_dims=in_dims, 3907 ) 3908 # Backwards test doesn't work yet 3909 # self._batched_grad_test( 3910 # lambda query, key, value: F.scaled_dot_product_attention( 3911 # query, key, value 3912 # ), 3913 # (query, key, value), 3914 # ) 3915 3916 B = 4 3917 query = torch.rand(4, 32, B, 8, 128, dtype=torch.float16, device=device) 3918 key = torch.rand(4, B, 32, 8, 128, dtype=torch.float16, device=device) 3919 value = torch.rand(4, 32, 8, 128, dtype=torch.float16, device=device) 3920 self._vmap_test( 3921 F.scaled_dot_product_attention, 3922 (query, key, value), 3923 in_dims=(2, 1, None), 3924 ) 3925 3926 @parametrize("backend", PLATFORM_SPECIFIC_SDPA) 3927 @parametrize("randomness", ["error", "same", "different"]) 3928 def test_randomness(self, device, randomness, backend): 3929 if device == "cpu": 3930 raise unittest.SkipTest("This test is only for CUDA for now") 3931 backend_ctx = sdpa_kernel([backend]) 3932 with backend_ctx: 3933 B = 4 3934 query = torch.rand(B, 4, 32, 8, 128, dtype=torch.float16, device=device) 3935 key = torch.rand(B, 4, 32, 8, 128, dtype=torch.float16, device=device) 3936 value = torch.rand(B, 4, 32, 8, 128, dtype=torch.float16, device=device) 3937 3938 def f(q, k, v, dropout): 3939 return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout) 3940 3941 # No matter the randomness mode, dropout=0.0 should pass 3942 vmap( 3943 functools.partial(f, dropout=0.0), 3944 in_dims=(0, 0, 0), 3945 randomness=randomness, 3946 )(query, key, value) 3947 3948 fail_with_randomness = randomness == "error" 3949 if backend != SDPBackend.MATH: 3950 fail_with_randomness |= randomness == "same" 3951 context = ( 3952 self.assertRaises(RuntimeError) 3953 # We currently don't support randomness == "same", and "error" should always error with randomness 3954 if fail_with_randomness 3955 else contextlib.nullcontext() 3956 ) 3957 with context: 3958 vmap( 3959 functools.partial(f, dropout=0.5), 3960 in_dims=(0, 0, 0), 3961 randomness=randomness, 3962 )(query, key, value) 3963 3964 @allowVmapFallbackUsage 3965 def test_inplace_view(self, device): 3966 leaf = torch.randn(4, 5, requires_grad=True) 3967 3968 def func(leaf): 3969 # Make sure the function is non-trivially twice differentiable 3970 base = leaf * leaf 3971 view = base[0] 3972 view.cos_() 3973 return view 3974 3975 self._batched_grad_test(func, (leaf,), {}) 3976 self._batched_grad_grad_test(func, (leaf,), {}) 3977 3978 @allowVmapFallbackUsage 3979 def test_inplace_manyview(self, device): 3980 leaf = torch.randn(4, 4, 5, requires_grad=True) 3981 3982 def func(leaf): 3983 # Make sure the function is non-trivially twice differentiable 3984 base = leaf * leaf 3985 view = base.transpose(0, 2) 3986 view = view[1] 3987 view = view.diagonal() 3988 view = view[::2] 3989 view.cos_() 3990 return view 3991 3992 self._batched_grad_test(func, (leaf,), {}) 3993 self._batched_grad_grad_test(func, (leaf,), {}) 3994 3995 def test_diagonal(self, device): 3996 x = torch.randn(4, 5, device=device, requires_grad=True) 3997 self._batched_grad_test(lambda x: 
x.diagonal(1, 0, 1), (x,)) 3998 3999 x = torch.randn(3, 4, 5, device=device, requires_grad=True) 4000 self._batched_grad_test(lambda x: x.diagonal(0, -1, -2), (x,)) 4001 4002 @allowVmapFallbackUsage 4003 def test_unrelated_output(self, device): 4004 B0 = 3 4005 x = torch.randn([], requires_grad=True) 4006 y = torch.randn([], requires_grad=True) 4007 gy = torch.randn(B0, requires_grad=True) 4008 4009 def vjp(v): 4010 (res,) = torch.autograd.grad(y, x, v, allow_unused=True) 4011 return torch.zeros_like(x) if res is None else res 4012 4013 result = vmap(vjp)(gy) 4014 self.assertEqual(result, torch.zeros(B0, *x.shape, device=device)) 4015 4016 @allowVmapFallbackUsage 4017 def test_unrelated_output_multiple_grad(self, device): 4018 B0 = 3 4019 x = torch.randn([], requires_grad=True) 4020 y = torch.randn([], requires_grad=True) 4021 gy = torch.randn(B0, requires_grad=True) 4022 4023 def vjp(v): 4024 (res,) = torch.autograd.grad(y, x, v, allow_unused=True) 4025 return torch.zeros_like(x) if res is None else res 4026 4027 _ = vjp(gy[0]) 4028 result = vmap(vjp)(gy) 4029 self.assertEqual(result, torch.zeros(B0, *x.shape, device=device)) 4030 4031 4032def discover_variants(opinfo): 4033 aliases = [] 4034 inplace_variants = [] 4035 4036 if opinfo.inplace_variant: 4037 inplace_variants.append(opinfo.inplace_variant) 4038 4039 aliases.append(opinfo.op) 4040 for alias in opinfo.aliases: 4041 aliases.append(alias.op) 4042 if alias.inplace_variant: 4043 inplace_variants.append(alias.inplace_variant) 4044 return aliases, inplace_variants 4045 4046 4047# TODO: enable this when we get a bit closer to getting torch.vmap x torch.compile working. 4048# @markDynamoStrictTest 4049@unMarkDynamoStrictTest 4050class TestVmapOperatorsOpInfo(TestCase): 4051 def vmap_outplace_test( 4052 self, 4053 func, 4054 args, 4055 kwargs, 4056 in_dims, 4057 check_shape_only=False, 4058 postprocess_fn=None, 4059 out_dim=0, 4060 ): 4061 for vmap_out, loop_out in compute_quantities_for_vmap_test( 4062 func, args, kwargs, in_dims, out_dim=out_dim 4063 ): 4064 if postprocess_fn is not None: 4065 loop_out = postprocess_fn(loop_out) 4066 vmap_out = postprocess_fn(vmap_out) 4067 if check_shape_only: 4068 self.assertEqual(vmap_out.shape, loop_out.shape) 4069 continue 4070 self.assertEqual(vmap_out, loop_out) 4071 4072 def vmap_inplace_test( 4073 self, func, args, kwargs, in_dims, postprocess_fn=None, out_dim=0 4074 ): 4075 # NB: This test assumes that the first argument is being modified. 4076 # This is OK because it's what every other OpInfo-based test assumes, 4077 # but it is going to need a more robust solution eventually. 
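        # If the mutated argument is unbatched (in_dims[0] is None) while some other
        # input is batched, the in-place op would have to write a batched result into
        # an unbatched tensor, so vmap is expected to raise a RuntimeError; the branch
        # below asserts exactly that.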
4078 if in_dims[0] is None: 4079 # Check that we correctly raise an error when vmap is impossible 4080 # on the in-place operation 4081 with self.assertRaises(RuntimeError): 4082 for _ in compute_quantities_for_vmap_test( 4083 func, 4084 args, 4085 kwargs, 4086 in_dims, 4087 out_dim=out_dim, 4088 compute_loop_out=False, 4089 clone_inputs=True, 4090 ): 4091 pass 4092 return 4093 for vmap_out, loop_out in compute_quantities_for_vmap_test( 4094 func, 4095 args, 4096 kwargs, 4097 in_dims, 4098 clone_inputs=True, 4099 out_dim=out_dim, 4100 ): 4101 if postprocess_fn is not None: 4102 loop_out = postprocess_fn(loop_out) 4103 vmap_out = postprocess_fn(vmap_out) 4104 self.assertEqual(vmap_out, loop_out) 4105 4106 def opinfo_vmap_test( 4107 self, 4108 device, 4109 dtype, 4110 op, 4111 check_has_batch_rule, 4112 skip_inplace=(), 4113 postprocess_fn=None, 4114 ): 4115 def test(): 4116 # Error inputs check 4117 if op.error_inputs_func is not None: 4118 error_inputs = op.error_inputs(device) 4119 for error_input in error_inputs: 4120 sample_input = error_input.sample_input 4121 args = (sample_input.input,) + tuple(sample_input.args) 4122 kwargs = sample_input.kwargs 4123 for batched_args, in_dims, _ in generate_vmap_inputs(args, {}): 4124 with self.assertRaises(Exception): 4125 vmap(op, in_dims)(*batched_args, **kwargs) 4126 4127 # Sample inputs check 4128 sample_inputs_op = { 4129 # Take too long with reference inputs 4130 "special.chebyshev_polynomial_t", 4131 "special.chebyshev_polynomial_u", 4132 "special.chebyshev_polynomial_v", 4133 "special.chebyshev_polynomial_w", 4134 "special.hermite_polynomial_he", 4135 "special.laguerre_polynomial_l", 4136 "special.legendre_polynomial_p", 4137 "special.shifted_chebyshev_polynomial_t", 4138 "special.shifted_chebyshev_polynomial_u", 4139 "special.shifted_chebyshev_polynomial_v", 4140 "special.shifted_chebyshev_polynomial_w", 4141 } 4142 if op.name in sample_inputs_op: 4143 sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) 4144 else: 4145 sample_inputs_itr = op.reference_inputs( 4146 device, dtype, requires_grad=False 4147 ) 4148 aliases, inplace_aliases = discover_variants(op) 4149 check_shape_only = op.name in ("empty_like", "new_empty") 4150 for sample_input in sample_inputs_itr: 4151 args = (sample_input.input,) + sample_input.args 4152 if not any(isinstance(arg, torch.Tensor) for arg in args): 4153 # Atleast one tensor required for vmap. 
4154 continue 4155 kwargs = sample_input.kwargs 4156 is_batch_norm_and_training = is_batch_norm_training(op.name, kwargs) 4157 out_dim = 0 4158 if op.name == "NumpySplitCopyWithIntCustomOp": 4159 # special case for this custom op 4160 def sample_vmap_out_dim_numpy_split_copy_with_int(x, splits, dim): 4161 return [0 for _ in range(len(splits) + 1)], None 4162 4163 out_dim = sample_vmap_out_dim_numpy_split_copy_with_int(*args) 4164 for batched_args, in_dims, _ in generate_vmap_inputs( 4165 args, {}, is_batch_norm_and_training=is_batch_norm_and_training 4166 ): 4167 for func in aliases: 4168 self.vmap_outplace_test( 4169 func, 4170 batched_args, 4171 kwargs, 4172 in_dims, 4173 check_shape_only, 4174 postprocess_fn, 4175 out_dim=out_dim, 4176 ) 4177 if op.name in skip_inplace: 4178 continue 4179 if not is_valid_inplace_sample_input( 4180 sample_input, op, op.inplace_variant 4181 ): 4182 continue 4183 for func in inplace_aliases: 4184 self.vmap_inplace_test( 4185 func, batched_args, kwargs, in_dims, postprocess_fn 4186 ) 4187 4188 if check_has_batch_rule: 4189 check_vmap_fallback(self, test, op) 4190 else: 4191 test() 4192 4193 vmap_fail = { 4194 # -------------------- ALLOWED FAILURES -------------------------------- 4195 # These are things that we either cannot fix or are not actually problems 4196 xfail("resize_"), 4197 xfail("resize_as_"), 4198 xfail("to_sparse"), 4199 xfail("__getitem__"), # dynamic mask 4200 xfail("index_put"), # dynamic mask 4201 xfail( 4202 "nn.functional.dropout" 4203 ), # works, can't check against for loop because of randomness inconsistency 4204 xfail("nn.functional.scaled_dot_product_attention"), # randomness 4205 xfail("nn.functional.multi_head_attention_forward"), # randomness 4206 xfail("masked_select"), # dynamic op 4207 xfail("nonzero"), # dynamic op 4208 xfail("unique", ""), # dynamic op 4209 xfail("unique_consecutive", ""), # dynamic op 4210 xfail("allclose"), # returns a boolean 4211 xfail("uniform"), # randomness is tested separately 4212 xfail("rand_like"), # randomness is tested separately 4213 xfail("randint_like"), # randomness is tested separately 4214 xfail("randn_like"), # randomness is tested separately 4215 xfail("bernoulli", ""), # randomness is tested separately 4216 xfail("normal", ""), # randomness is tested separately 4217 xfail("normal", "number_mean"), # randomness is tested separately 4218 xfail("multinomial", ""), # randomness 4219 xfail("nn.functional.embedding", ""), # we only support some cases 4220 xfail("nn.functional.rrelu"), # randomness 4221 xfail("nn.functional.dropout2d", ""), # randomness 4222 xfail("nn.functional.dropout3d", ""), # randomness 4223 xfail("nn.functional.alpha_dropout", ""), # randomness 4224 xfail("nn.functional.feature_alpha_dropout", "with_train"), # randomness 4225 xfail("as_strided"), # Our test runner can't handle this; manual test exists 4226 xfail("as_strided_copy"), 4227 xfail( 4228 "as_strided_scatter" 4229 ), # no batching rule implemented, default doesnt work 4230 skip( 4231 "new_empty_strided" 4232 ), # empty tensor data is garbage so it's hard to make comparisons with it 4233 xfail("nn.functional.fractional_max_pool3d"), # randomness 4234 xfail("nn.functional.fractional_max_pool2d"), # randomness 4235 xfail("pca_lowrank", ""), # random operation 4236 xfail("svd_lowrank", ""), # random operation 4237 xfail("sparse.sampled_addmm"), # sparse 4238 xfail("sparse.mm", "reduce"), # sparse 4239 xfail( 4240 "NumpyCubeNotComposableAutogradFunction" 4241 ), # Not composable autograd.Function 4242 
skip("_softmax_backward_data"), 4243 skip( 4244 "linalg.eigh", "" 4245 ), # not always return the same result for the same input, see test_linalg_eigh for manual test 4246 skip("to"), # RuntimeError: required rank 4 tensor to use channels_last format 4247 # UnimplementedError: data-dependent operators cannot be vmapped 4248 xfail("NumpyNonzeroCustomOp"), 4249 xfail("NumpyNMSCustomOp"), 4250 # ---------------------------------------------------------------------- 4251 # ---------------------------- BUGS ------------------------------------ 4252 # entries in here don't work and need to be fixed. 4253 # Each one of these is a bug 4254 decorate("frexp", decorator=skipIfTorchDynamo()), 4255 xfail("clamp_min", ""), # Exception not raised on error input 4256 xfail("clamp_max", ""), # Exception not raised on error input 4257 xfail( 4258 "view_as_complex" 4259 ), # RuntimeError: Tensor must have a last dimension with stride 1 4260 xfail("tensor_split"), # data_ptr 4261 xfail( 4262 "histogramdd" 4263 ), # expected Tensor as element 0 in argument 0, but got tuple 4264 xfail("nn.functional.gaussian_nll_loss"), # data-dependent control flow error 4265 xfail( 4266 "nn.functional.embedding_bag" 4267 ), # embedding renorm vmap inplace incompatible 4268 xfail("narrow"), # Batching rule not implemented for aten::narrow.Tensor 4269 # required rank 4 tensor to use channels_last format 4270 xfail("bfloat16"), 4271 xfail("bool"), 4272 xfail("byte"), 4273 xfail("char"), 4274 xfail("double"), 4275 xfail("float"), 4276 xfail("half"), 4277 xfail("int"), 4278 xfail("long"), 4279 xfail("short"), 4280 xfail("cdouble"), 4281 xfail("cfloat"), 4282 xfail( 4283 "jiterator_binary", device_type="cuda" 4284 ), # NYI: querying is_contiguous inside of vmap 4285 xfail( 4286 "jiterator_binary_return_by_ref", device_type="cuda" 4287 ), # NYI: querying is_contiguous inside of vmap 4288 xfail( 4289 "jiterator_4inputs_with_extra_args", device_type="cuda" 4290 ), # NYI: querying is_contiguous inside of vmap 4291 xfail( 4292 "equal", "" 4293 ), # TypeError: object of type 'bool' has no len(); likely testrunner problem 4294 xfail( 4295 "jiterator_unary", device_type="cuda" 4296 ), # NYI: querying is_contiguous inside of vmap 4297 xfail( 4298 "jiterator_2inputs_2outputs", device_type="cuda" 4299 ), # NYI: querying is_contiguous inside of vmap 4300 # --------------------------------------------------------------------- 4301 # TypeError: expected Tensor as element 0 in argument 0, but got NotImplementedType 4302 xfail("__rsub__"), 4303 # RuntimeError: Batching rule not implemented for aten::moveaxis.int; 4304 # the fallback path doesn't work on out= or view ops. 4305 xfail("movedim"), 4306 # RuntimeError: NYI: querying is_contiguous inside of vmap for 4307 # memory_format other than torch.contiguous_format 4308 xfail("contiguous"), 4309 # RuntimeError: NYI: Tensor.clone(memory_format) inside vmap is only supported 4310 # with memory_format torch.preserve_format or torch.contiguous_format (got ChannelsLast) 4311 xfail("clone"), 4312 # RuntimeError: When vmap-ing torch.nn.functional.one_hot, 4313 # please provide an explicit positive num_classes argument. 4314 xfail("nn.functional.one_hot"), 4315 # RuntimeError: Expected all tensors to be on the same device, 4316 # but found at least two devices, cuda:0 and cpu! 
4317 xfail("eq", device_type="cuda"), 4318 xfail("ge", device_type="cuda"), 4319 xfail("gt", device_type="cuda"), 4320 xfail("le", device_type="cuda"), 4321 xfail("lt", device_type="cuda"), 4322 xfail("ne", device_type="cuda"), 4323 # RuntimeError: aten::_flash_attention_forward hit the vmap fallback which is currently disabled 4324 xfail("torch.ops.aten._flash_attention_forward"), 4325 } 4326 4327 @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 4328 @ops( 4329 op_db + additional_op_db + autograd_function_db + custom_op_db, 4330 dtypes=OpDTypes.any_one, 4331 ) 4332 @opsToleranceOverride( 4333 "TestVmapOperatorsOpInfo", 4334 "test_vmap_exhaustive", 4335 ( 4336 tol1( 4337 "linalg.det", 4338 {torch.float32: tol(atol=1e-04, rtol=1e-04)}, 4339 device_type="cuda", 4340 ), 4341 # The following is often flaky, but just on windows. 4342 # We should investigate if it's actually a problem or not. 4343 tol1( 4344 "nn.functional.conv_transpose3d", 4345 {torch.float32: tol(atol=1e-04, rtol=1e-02)}, 4346 device_type="cuda", 4347 ), 4348 ), 4349 ) 4350 @toleranceOverride( 4351 { 4352 torch.float32: tol(atol=1e-04, rtol=1e-04), 4353 torch.complex64: tol(atol=1e-04, rtol=1e-04), 4354 } 4355 ) 4356 @skipOps( 4357 "TestVmapOperatorsOpInfo", 4358 "test_vmap_exhaustive", 4359 vmap_fail.union( 4360 { 4361 # RuntimeError: Batch norm got a batched tensor as input while the running_mean or running_var, 4362 # which will be updated in place, were not batched. 4363 xfail("native_batch_norm"), 4364 xfail("_native_batch_norm_legit"), 4365 # TODO: implement batching rule 4366 xfail("_batch_norm_with_update"), 4367 xfail("tril"), # Exception not raised on error input 4368 xfail("triu"), # Exception not raised on error input 4369 xfail("as_strided", "partial_views"), 4370 # RuntimeError: output with shape [4, 4] doesn't match the broadcast shape [1, 4, 4] 4371 xfail("addcdiv"), 4372 xfail("addcmul"), 4373 xfail("clamp"), 4374 xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints 4375 # TypeError: expected Tensor as element 0 in argument 0, but got float 4376 xfail("item"), 4377 } 4378 ), 4379 ) 4380 def test_vmap_exhaustive(self, device, dtype, op): 4381 # needs to be fixed 4382 inplace_failure_list = () 4383 self.opinfo_vmap_test( 4384 device, 4385 dtype, 4386 op, 4387 check_has_batch_rule=False, 4388 skip_inplace=inplace_failure_list, 4389 ) 4390 4391 @with_tf32_off 4392 @ops( 4393 op_db + additional_op_db + autograd_function_db + custom_op_db, 4394 dtypes=OpDTypes.any_one, 4395 ) 4396 @opsToleranceOverride( 4397 "TestVmapOperatorsOpInfo", 4398 "test_op_has_batch_rule", 4399 ( 4400 tol1( 4401 "linalg.det", 4402 {torch.float32: tol(atol=1e-04, rtol=1e-04)}, 4403 device_type="cuda", 4404 ), 4405 ), 4406 ) 4407 @toleranceOverride( 4408 { 4409 torch.float32: tol(atol=1e-04, rtol=1e-04), 4410 torch.complex64: tol(atol=1e-04, rtol=1e-04), 4411 } 4412 ) 4413 @skipOps( 4414 "TestVmapOperatorsOpInfo", 4415 "test_op_has_batch_rule", 4416 vmap_fail.union( 4417 { 4418 xfail("as_strided", "partial_views"), 4419 skip( 4420 "to" 4421 ), # RuntimeError: required rank 4 tensor to use channels_last format 4422 xfail("fill"), 4423 # Batch norm got a batched tensor as input while the running_mean or running_var, 4424 # which will be updated in place, were not batched. 
4425 xfail("native_batch_norm"), 4426 xfail("_native_batch_norm_legit"), 4427 # TODO: implement batching rule 4428 xfail("_batch_norm_with_update"), 4429 xfail("histogram"), 4430 xfail("scatter_reduce", "sum"), 4431 xfail("scatter_reduce", "mean"), 4432 xfail("scatter_reduce", "amax"), 4433 xfail("scatter_reduce", "amin"), 4434 # `index_put` OpInfo in pytorch/pytorch has 4435 # masked index as input which is not supported 4436 xfail("index_put", ""), 4437 xfail("isin"), 4438 xfail("masked_fill"), 4439 xfail("masked_scatter"), 4440 xfail("masked_select"), 4441 xfail("nanquantile"), 4442 xfail("ormqr"), 4443 xfail("put"), 4444 xfail("quantile"), 4445 xfail("renorm"), 4446 xfail("resize_as_"), 4447 xfail("take"), 4448 xfail("tensor_split"), 4449 xfail("to_sparse"), 4450 # TypeError: expected Tensor as element 0 in argument 0, but got float 4451 xfail("item"), 4452 xfail("tril"), # Exception not raised on error input 4453 xfail("triu"), # Exception not raised on error input 4454 xfail("__getitem__", ""), 4455 xfail("count_nonzero"), 4456 xfail( 4457 "nn.functional.dropout" 4458 ), # works, can't check against for loop because of randomness inconsistency 4459 xfail("nn.functional.scaled_dot_product_attention"), # randomness 4460 xfail("nn.functional.multi_head_attention_forward"), # randomness 4461 xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints 4462 xfail("resize_"), 4463 xfail("view_as_complex"), 4464 xfail("matrix_exp"), 4465 xfail("fft.ihfft2"), 4466 xfail("fft.ihfftn"), 4467 xfail("allclose"), 4468 xfail("argwhere"), 4469 xfail("unique_consecutive"), 4470 xfail("unique"), 4471 xfail("nn.functional.ctc_loss"), 4472 xfail("nn.functional.gaussian_nll_loss"), 4473 xfail("histc"), 4474 xfail("as_strided"), 4475 xfail("as_strided_copy"), 4476 xfail("t_copy"), 4477 xfail("unsqueeze_copy"), 4478 xfail("istft"), 4479 xfail("nonzero"), 4480 xfail("nn.functional.fractional_max_pool2d"), 4481 xfail("stft"), 4482 xfail("isclose"), 4483 xfail("nn.functional.fractional_max_pool3d"), 4484 xfail("nn.functional.bilinear"), 4485 xfail("nn.functional.embedding_bag"), 4486 xfail("linalg.tensorsolve"), 4487 xfail("bernoulli", ""), 4488 xfail("nn.functional.feature_alpha_dropout", "with_train"), 4489 xfail("native_dropout_backward"), 4490 xfail("nn.functional.kl_div", ""), 4491 xfail("multinomial", ""), 4492 xfail("pca_lowrank", ""), 4493 xfail("normal", ""), 4494 xfail("nn.functional.dropout2d", ""), 4495 xfail("normal", "number_mean"), 4496 xfail("svd_lowrank", ""), 4497 xfail("diagflat", ""), 4498 xfail("special.log_ndtr"), 4499 xfail( 4500 "narrow" 4501 ), # Batching rule not implemented for aten::narrow.Tensor 4502 xfail("nn.functional.triplet_margin_loss", ""), 4503 xfail("nn.functional.pdist", ""), 4504 xfail("scatter_reduce", "sum"), 4505 xfail("scatter_reduce", "amax"), 4506 xfail("nn.functional.max_unpool1d", "grad"), 4507 xfail("nn.functional.multi_margin_loss", ""), 4508 xfail("scatter_reduce", "prod"), 4509 xfail("nn.functional.multilabel_margin_loss", ""), 4510 xfail("scatter_reduce", "amin"), 4511 xfail("nn.functional.max_unpool3d", "grad"), 4512 xfail("nn.functional.max_unpool2d", ""), 4513 xfail("nn.functional.max_unpool2d", "grad"), 4514 xfail("nn.functional.margin_ranking_loss", ""), 4515 xfail("nn.functional.max_unpool1d", ""), 4516 xfail("nn.functional.soft_margin_loss", ""), 4517 xfail("scatter_reduce", "mean"), 4518 xfail("nn.functional.max_unpool3d", ""), 4519 xfail("linalg.ldl_solve", "", device_type="cpu"), 4520 xfail("chalf", ""), 4521 xfail("clamp_max", ""), 4522 
xfail("jiterator_binary_return_by_ref", device_type="cuda"), 4523 xfail("jiterator_unary", device_type="cuda"), 4524 xfail("jiterator_2inputs_2outputs", device_type="cuda"), 4525 xfail("special.airy_ai"), 4526 xfail("clamp_min", ""), 4527 xfail("sparse.sampled_addmm"), 4528 xfail("sparse.mm", "reduce"), 4529 xfail("special.chebyshev_polynomial_u"), 4530 xfail("_segment_reduce", "offsets"), 4531 xfail("index_reduce", "prod"), 4532 xfail("index_reduce", "mean"), 4533 xfail("index_reduce", "amin"), 4534 xfail("index_reduce", "amax"), 4535 xfail("special.laguerre_polynomial_l"), 4536 xfail("special.hermite_polynomial_h"), 4537 xfail("jiterator_binary", device_type="cuda"), 4538 xfail("jiterator_4inputs_with_extra_args", device_type="cuda"), 4539 xfail("_segment_reduce", "lengths"), 4540 xfail("lu_solve", ""), 4541 xfail("special.hermite_polynomial_he"), 4542 xfail("nn.functional.dropout3d", ""), 4543 xfail("special.chebyshev_polynomial_t"), 4544 xfail("as_strided_scatter", ""), 4545 xfail("equal", ""), 4546 xfail("linalg.lu", ""), 4547 skip("linalg.ldl_solve", ""), 4548 skip("_softmax_backward_data"), 4549 # One or more of the overload doesn't have a Batch rule. 4550 xfail("bincount"), 4551 # RuntimeError: Expected all tensors to be on the same device, 4552 # but found at least two devices, cuda:0 and cpu! 4553 xfail("ge", device_type="cuda"), 4554 xfail( 4555 "searchsorted" 4556 ), # aten::searchsorted.Scalar hit the vmap fallback which is currently disabled 4557 } 4558 ), 4559 ) 4560 def test_op_has_batch_rule(self, device, dtype, op): 4561 # needs to be fixed 4562 inplace_failures = ( 4563 "addbmm", 4564 "addcdiv", 4565 "addcmul", 4566 "addmm", 4567 "addmv", 4568 "addr", 4569 "baddbmm", 4570 "clamp", 4571 "conj_physical", 4572 "cumprod", 4573 "cumsum", 4574 "floor_divide", 4575 "fmod", 4576 "heaviside", 4577 "hypot", 4578 "igamma", 4579 "igammac", 4580 "index_copy", 4581 "ldexp", 4582 "lerp", 4583 "neg", 4584 "nextafter", 4585 "polygamma", 4586 "pow", 4587 "remainder", 4588 "scatter_add", 4589 "scatter", 4590 "square", 4591 "sub", 4592 "trunc", 4593 "xlogy", 4594 ) 4595 self.opinfo_vmap_test( 4596 device, dtype, op, check_has_batch_rule=True, skip_inplace=inplace_failures 4597 ) 4598 4599 def test_linalg_svd(self, device): 4600 # linalg_svd returns a tuple of three tensors, (U, S, Vh). 4601 # Given the same input, it may return different tensors, 4602 # because svd isn't unique. To test that the svd is correct, we multiply 4603 # U @ diag(S) @ Vh and check that the output from vmap matches the 4604 # output from a for-loop. 4605 def compute_A(out): 4606 U, S, Vh = out 4607 m = U.shape[-1] 4608 n = Vh.shape[-2] 4609 diag_S = S.new_zeros(*S.shape[:-1], m, n) 4610 diag_S.diagonal(offset=0, dim1=-2, dim2=-1).copy_(S) 4611 return U @ diag_S @ Vh 4612 4613 opinfos = [op for op in op_db if op.name == "linalg.svd"] 4614 assert len(opinfos) > 0 4615 4616 for op in opinfos: 4617 self.opinfo_vmap_test( 4618 device, 4619 torch.float, 4620 op, 4621 check_has_batch_rule=True, 4622 postprocess_fn=compute_A, 4623 ) 4624 4625 def test_linalg_eigh(self, device): 4626 # linalg_svd returns two tensors, (Q, L). 4627 # Given the same input, it may return different tensors, 4628 # because the eig decomposition isn't unique. 4629 # To test that eigh is correct, we multiply 4630 # Q @ diag(L) @ Qh and check that the output from vmap matches the 4631 # output from a for-loop. 
4632 def compute_A(out): 4633 L, Q = out 4634 n = Q.shape[-1] 4635 diag_L = L.new_zeros(*L.shape[:-1], n, n) 4636 diag_L.diagonal(offset=0, dim1=-2, dim2=-1).copy_(L) 4637 Qh = Q.transpose(-2, -1).conj() 4638 return Q @ diag_L @ Qh 4639 4640 opinfos = [op for op in op_db if op.name == "linalg.eigh"] 4641 assert len(opinfos) > 0 4642 4643 for op in opinfos: 4644 self.opinfo_vmap_test( 4645 device, 4646 torch.float, 4647 op, 4648 check_has_batch_rule=True, 4649 postprocess_fn=compute_A, 4650 ) 4651 4652 @skipIfTorchDynamo() 4653 def test_slogdet(self, device): 4654 # There's no OpInfo for this 4655 def test(): 4656 B = 2 4657 x = torch.randn(B, 5, 5, device=device) 4658 self.vmap_outplace_test(torch.slogdet, (x,), {}, (0,)) 4659 4660 check_vmap_fallback(self, test, torch.slogdet) 4661 4662 def test_index_fill(self, device): 4663 # There's no OpInfo for these tests 4664 4665 B = 2 4666 4667 def test1(): 4668 # negative dim 4669 x = torch.randn(B, 5, 5, device=device) 4670 dim = -2 4671 index = torch.tensor([[2, 3], [0, 4]], device=device) 4672 value = 5.0 4673 self.vmap_outplace_test( 4674 torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None) 4675 ) 4676 4677 def test2(): 4678 # self batched, self logical rank 1, index logical rank 1 4679 x = torch.zeros(B, 3, device=device) 4680 dim = 0 4681 index = torch.tensor([[0], [1]], device=device) 4682 for value in (1.0, torch.rand((), device=device)): 4683 self.vmap_outplace_test( 4684 torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None) 4685 ) 4686 4687 def test3(): 4688 # self batched, self logical rank 1, index logical rank 0 4689 x = torch.zeros(B, 3, device=device) 4690 dim = 0 4691 index = torch.tensor([0, 1], device=device) 4692 for value in (1.0, torch.rand((), device=device)): 4693 self.vmap_outplace_test( 4694 torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None) 4695 ) 4696 4697 def test4(): 4698 # self not batched, self logical rank 0, index logical rank 1 4699 x = torch.zeros([], device=device) 4700 dim = 0 4701 index = torch.tensor([[0], [0]], device=device) 4702 for value in (1.0, torch.rand((), device=device)): 4703 self.vmap_outplace_test( 4704 torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None) 4705 ) 4706 4707 def test5(): 4708 # self not batched, self logical rank 0, index logical rank 0 4709 x = torch.zeros([], device=device) 4710 dim = 0 4711 index = torch.tensor([0, 0], device=device) 4712 for value in (1.0, torch.rand((), device=device)): 4713 self.vmap_outplace_test( 4714 torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None) 4715 ) 4716 4717 def test6(): 4718 # self not batched, self logical rank 0, index logical rank 1 4719 x = torch.zeros(3, device=device) 4720 dim = 0 4721 index = torch.tensor([[0], [1]], device=device) 4722 for value in (1.0, torch.rand((), device=device)): 4723 self.vmap_outplace_test( 4724 torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None) 4725 ) 4726 4727 def test7(): 4728 # self not batched, self logical rank 0, index logical rank 0 4729 x = torch.zeros(3, device=device) 4730 dim = 0 4731 index = torch.tensor([0, 1], device=device) 4732 for value in (1.0, torch.rand((), device=device)): 4733 self.vmap_outplace_test( 4734 torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None) 4735 ) 4736 4737 def test8(): 4738 # self batched, self logical rank > 1, index logical rank 0 4739 x = torch.zeros(B, 3, 3, device=device) 4740 dim = 0 4741 index = torch.tensor([0, 1], device=device) 4742 for value in (1.0, 
torch.rand((), device=device)): 4743 self.vmap_outplace_test( 4744 torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None) 4745 ) 4746 4747 for test in (test1, test2, test3, test4, test5, test6, test7, test8): 4748 check_vmap_fallback(self, test, torch.index_fill) 4749 4750 def test_fill__Tensor(self, device): 4751 # There's no OpInfo for fill_.Tensor, so here's an extra test for it. 4752 def test(): 4753 B = 2 4754 args = (torch.randn(B, 3, device=device), torch.randn(B)) 4755 self.vmap_inplace_test(Tensor.fill_, args, {}, (0, 0)) 4756 4757 args = (torch.randn(3, B, device=device), torch.randn(B)) 4758 self.vmap_inplace_test(Tensor.fill_, args, {}, (-1, 0)) 4759 4760 args = (torch.randn(3, device=device), torch.randn(B)) 4761 self.vmap_inplace_test(Tensor.fill_, args, {}, (None, 0)) 4762 4763 args = (torch.randn(3, B, device=device), torch.randn([])) 4764 self.vmap_inplace_test(Tensor.fill_, args, {}, (1, None)) 4765 4766 check_vmap_fallback(self, test, Tensor.fill_) 4767 4768 def test_conv_double_backward(self, device): 4769 images = torch.randn(2, 1, 5, 5, device=device) 4770 weight = torch.randn(2, 1, 2, 2, device=device) 4771 bias = torch.randn(2, device=device) 4772 ggI = torch.randn_like(images) 4773 ggW = torch.randn_like(weight) 4774 ggb = torch.randn_like(bias) 4775 stride = (1, 1) 4776 padding = (0, 0) 4777 dilation = (1, 1) 4778 transposed = False 4779 output_padding = (0, 0) 4780 groups = 1 4781 output_mask = (True, True, True) 4782 gO = torch.randn_like( 4783 F.conv2d(images, weight, bias, stride, padding, dilation, groups) 4784 ) 4785 4786 args = ( 4787 ggI, 4788 ggW, 4789 ggb, 4790 gO, 4791 weight, 4792 images, 4793 stride, 4794 padding, 4795 dilation, 4796 transposed, 4797 output_padding, 4798 groups, 4799 output_mask, 4800 ) 4801 op = torch.ops.aten._convolution_double_backward 4802 4803 generator = get_fallback_and_vmap_exhaustive(op, args, {}) 4804 is_cuda_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability( 4805 0 4806 ) == (8, 6) 4807 atol, rtol = (1e-3, 1e-3) if is_cuda_sm86 else (1e-4, 1e-4) 4808 4809 def test(): 4810 for loop_out, batched_out in generator: 4811 self.assertEqual(loop_out, batched_out, atol=atol, rtol=rtol) 4812 4813 check_vmap_fallback(self, test, op) 4814 4815 def test_isnan(self, device): 4816 test = functools.partial(_vmap_test, check_propagates_grad=False) 4817 4818 B, N, C, H, W = 2, 3, 24, 5, 7 4819 op = torch.isnan 4820 4821 x = torch.randn(B, N, C, H, W) 4822 x[x > 0] = float("nan") 4823 test(self, op, (x,), in_dims=(0)) 4824 4825 def test_sum_scalar(self, device): 4826 x = torch.tensor([10.0], device=device) 4827 y = vmap(torch.sum)(x) 4828 self.assertEqual(y, x) 4829 4830 y = vmap(lambda x: x.sum(0))(x) 4831 self.assertEqual(y, x) 4832 4833 y = vmap(lambda x: x.sum(-1))(x) 4834 self.assertEqual(y, x) 4835 4836 def test_isinf(self, device): 4837 test = functools.partial(_vmap_test, check_propagates_grad=False) 4838 4839 B, N, C, H, W = 2, 3, 24, 5, 7 4840 op = torch.isinf 4841 4842 x = torch.randn(B, N, C, H, W) 4843 x[x > 0] = float("inf") 4844 test(self, op, (x,), in_dims=(0)) 4845 4846 def test_foo_like(self, device): 4847 # vfdev-5: Probably, we can remove this line. 
Flake8 reported as unused 4848 # test = functools.partial(_vmap_test, check_propagates_grad=False) 4849 4850 B, N, C, H, W = 2, 3, 24, 5, 7 4851 for op in [torch.ones_like, torch.zeros_like]: 4852 x = torch.randn(B, N, C, H, W) 4853 # todo(chilli): test these better 4854 # Not testing correctness, just that they run 4855 vmap(op, in_dims=(0,))( 4856 x, 4857 ) 4858 4859 def test_flatten(self, device): 4860 test = functools.partial(_vmap_test, check_propagates_grad=False) 4861 4862 op = torch.flatten 4863 4864 x = torch.randn(2, 3, 4, 5) 4865 test(self, op, (x, 1, 2), in_dims=(0, None, None)) 4866 4867 def test_group_norm(self, device): 4868 test = functools.partial(_vmap_test, check_propagates_grad=False) 4869 4870 B, N, C, H, W = 2, 3, 24, 5, 7 4871 op = F.group_norm 4872 4873 x = torch.randn(B, N, C, H, W) 4874 weight = torch.randn(C) 4875 bias = torch.randn(C) 4876 test(self, op, (x, 3, weight, bias), in_dims=(0, None, None, None)) 4877 4878 x = torch.randn(B, N, C, H, W) 4879 weight = torch.randn(B, C) 4880 bias = torch.randn(B, C) 4881 test(self, op, (x, 4, weight, bias), in_dims=(0, None, 0, 0)) 4882 4883 def test_index_put(self, device): 4884 def test(f, t, idx, values): 4885 base = f(t[0], idx[0], values[0]) 4886 self.assertEqual(vmap(f, in_dims=(0, 0, 0))(t, idx, values)[0], base) 4887 self.assertEqual( 4888 vmap(f, in_dims=(0, None, None))(t, idx[0], values[0])[0], base 4889 ) 4890 self.assertEqual(vmap(f, in_dims=(0, None, 0))(t, idx[0], values)[0], base) 4891 self.assertEqual(vmap(f, in_dims=(0, 0, None))(t, idx, values[0])[0], base) 4892 4893 def f(x, y, z): 4894 x[y] = z 4895 return x 4896 4897 x = torch.randn(3, 4, 5, device=device) 4898 y = torch.zeros((3, 2), device=device).long() 4899 z = torch.randn(3, 2, 5, device=device) 4900 test(f, x, y, z) 4901 4902 # indexing innermost dim 4903 def f(t, idx, values): 4904 t[:, idx] = values 4905 return t 4906 4907 t = torch.zeros((3, 2, 3)) 4908 values = torch.ones((3, 1, 2)) 4909 idx = torch.tensor([[1, 2]]).expand((3, 2)) 4910 test(f, t, idx, values) 4911 4912 # indexing middle dim 4913 def f(t, idx, values): 4914 t[:, idx, :] = values 4915 return t 4916 4917 t = torch.zeros((3, 2, 3, 3)) 4918 values = torch.ones((3, 1, 2, 3)) 4919 idx = torch.tensor([[0, 2]]).expand((3, 2)) 4920 test(f, t, idx, values) 4921 4922 # indexing with slices 4923 def f(t, values): 4924 t[:, :2, :] = values 4925 return t 4926 4927 base = f(t[0], values[0]) 4928 self.assertEqual(vmap(f, in_dims=(0, 0))(t, values)[0], base) 4929 self.assertEqual(vmap(f, in_dims=(0, None))(t, values[0])[0], base) 4930 4931 # index_put_ 4932 tensor = torch.zeros(3, 3, 4) 4933 value = torch.ones(3, 2) 4934 idxs = ( 4935 torch.tensor([[0], [1], [2]]), 4936 torch.tensor([[0]]), 4937 torch.tensor([1, 2]), 4938 ) 4939 expected = torch.index_put_(tensor.clone(), idxs, value) 4940 4941 def f(t, idx, v): 4942 torch.index_put_(t, idx, v) 4943 return t 4944 4945 self.assertEqual( 4946 vmap(f, in_dims=(0, (None, None), 0))(tensor, idxs[1:], value), expected 4947 ) 4948 self.assertEqual( 4949 vmap(f, in_dims=(0, (None, None), None))(tensor, idxs[1:], value[0]), 4950 expected, 4951 ) 4952 4953 # boolean mask 4954 B = 2 4955 x = torch.randn(1, 3, 3) 4956 gy = torch.randn(B, 1, 3, 3) 4957 4958 def f(x, gy): 4959 mask = x < 1e-09 4960 zeros = torch.zeros([]) 4961 index_put = torch.ops.aten.index_put.default(gy, [mask], zeros) 4962 return index_put 4963 4964 self.vmap_outplace_test(f, (x, gy), {}, in_dims=(None, 0)) 4965 4966 @onlyCUDA 4967 @parametrize("inplace", [True, False]) 4968 def 
test_0d_tensor_index_put(self, device, inplace): 4969 def f(t, idx, v): 4970 fn = torch.index_put_ if inplace else torch.index_put 4971 return fn(t, idx, v) 4972 4973 N = 2 4974 t = torch.zeros((N, 5), device="cuda") 4975 idx = torch.tensor([1, 3]) 4976 v = torch.tensor(1, dtype=t.dtype, device="cpu") 4977 4978 expected = torch.tensor([[0, 1, 0, 1, 0], [0, 1, 0, 1, 0]], dtype=t.dtype) 4979 self.assertEqual(expected, vmap(f, in_dims=(0, None, None))(t, (idx,), v)) 4980 4981 @parametrize("training", [True, False]) 4982 @parametrize("track_running_stats", [True, False]) 4983 @parametrize("affine", [True, False]) 4984 def test_batch_norm(self, device, affine, track_running_stats, training): 4985 if not track_running_stats and not training: 4986 return 4987 4988 test = functools.partial(_vmap_test, check_propagates_grad=False) 4989 BN = torch.nn.BatchNorm2d 4990 ensemble_size = 10 4991 hidden_dim = 3 4992 4993 weights, buffers, _, _, _ = functional_init_with_buffers(BN, [ensemble_size])( 4994 hidden_dim, affine=affine, track_running_stats=track_running_stats 4995 ) 4996 4997 inputs = [torch.randn(ensemble_size, 32, hidden_dim, 16, 16, device=device)] 4998 in_dims = [0] 4999 5000 def append(inp, in_dim): 5001 inputs.append(inp) 5002 in_dims.append(in_dim) 5003 5004 if track_running_stats: 5005 running_mean, running_var, _ = buffers 5006 append(running_mean.to(device), 0) 5007 append(running_var.to(device), 0) 5008 else: 5009 append(None, None) 5010 append(None, None) 5011 5012 if affine: 5013 weight, bias = weights 5014 append(weight.to(device), 0) 5015 append(bias.to(device), 0) 5016 else: 5017 append(None, None) 5018 append(None, None) 5019 5020 append(training, None) 5021 5022 def op(inp, running_mean, running_var, weight, bias, training): 5023 res = F.batch_norm(inp, running_mean, running_var, weight, bias, training) 5024 if track_running_stats: 5025 return res, running_mean, running_var 5026 return res 5027 5028 test(self, op, tuple(inputs), in_dims=tuple(in_dims)) 5029 5030 def test_torch_return_types_returns(self, device): 5031 t = torch.randn(3, 2, 2, device=device) 5032 self.assertTrue( 5033 isinstance(vmap(torch.min, (0, None))(t, 0), torch.return_types.min) 5034 ) 5035 self.assertTrue( 5036 isinstance(vmap(torch.max, (0, None))(t, 0), torch.return_types.max) 5037 ) 5038 self.assertTrue( 5039 isinstance( 5040 vmap(torch.topk, (0, None, None))(t, 1, 0), torch.return_types.topk 5041 ) 5042 ) 5043 self.assertTrue( 5044 isinstance(vmap(torch.linalg.eig, (0))(t), torch.return_types.linalg_eig) 5045 ) 5046 5047 def test_namedtuple_returns(self, device): 5048 Point = namedtuple("Point", ["x", "y"]) 5049 5050 def f(x, y): 5051 return Point(x=x, y=y) 5052 5053 x = torch.randn(2, 5, device=device) 5054 y = torch.randn(2, 3, device=device) 5055 self.assertTrue(isinstance(vmap(f)(x, y), Point)) 5056 5057 def test_inplace_on_view(self, device): 5058 def func(leaf): 5059 base = leaf * leaf 5060 view = base.transpose(0, 1) 5061 view[2:4, 2:4] *= 2 5062 view[0:2, 0:2].diagonal().sin_() 5063 view = view[1:3, 1:3] 5064 view.cos_() 5065 return view 5066 5067 def push_vjp(leaf, gout): 5068 _, vjp_fn = vjp(func, leaf) 5069 (result,) = vjp_fn(gout) 5070 return result 5071 5072 leaf = torch.randn(4, 4, device=device) 5073 gout = torch.randn(2, 2, device=device) 5074 args = (leaf, gout) 5075 5076 for ( 5077 batched_args, 5078 in_dims, 5079 _, 5080 ) in generate_vmap_inputs(args, {}): 5081 if in_dims[1] is None: 5082 # triggers some composite compliance problem 5083 continue 5084 
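            # For the remaining (leaf, gout) batching combinations, check that
            # vmap of push_vjp (the vjp of a function doing in-place ops on views)
            # matches the per-sample for-loop reference used by vmap_outplace_test.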
self.vmap_outplace_test(push_vjp, batched_args, {}, in_dims) 5085 5086 def test_advanced_indexing(self, device): 5087 def test(f, args): 5088 for loop_out, batched_out in get_fallback_and_vmap_exhaustive(f, args, {}): 5089 self.assertEqual(loop_out, batched_out) 5090 5091 def f(x, idx): 5092 return x[:, idx] 5093 5094 def f2(x, idx): 5095 return x[idx, :] 5096 5097 def f3(x, idx): 5098 return x[:, :, idx] 5099 5100 inps = ( 5101 torch.randn(5, 5, 5, device=device), 5102 torch.randn(5, 5, 5, 5, device=device), 5103 torch.randn(5, 5, 5, 5, 5, device=device), 5104 ) 5105 idxes = ( 5106 torch.tensor([0, 1, 2], device=device), 5107 torch.tensor([0, 1, 2], device=device).reshape(3, 1), 5108 torch.tensor([0, 1, 2], device=device).reshape(3, 1, 1), 5109 ) 5110 for inp, idx in itertools.product(inps, idxes): 5111 test(f, (inp, idx)) 5112 test(f2, (inp, idx)) 5113 test(f3, (inp, idx)) 5114 5115 def test_nested_advanced_indexing(self, device): 5116 e = torch.rand(7, 4, device=device) 5117 idx = torch.tensor([0, 1], device=device).view(2, 1) 5118 5119 # simple reference implementation for comparison 5120 def _fake_vmap(f, in_dims=0, out_dims=0): 5121 def w(input): 5122 r = [f(input.select(in_dims, i)) for i in range(input.size(in_dims))] 5123 return torch.stack(r, out_dims) 5124 5125 return w 5126 5127 def with_vmap(_vmap): 5128 def g(idx_): 5129 def f(e_): 5130 return e_[idx_] 5131 5132 return _vmap(f, in_dims=1)(e) 5133 5134 r = _vmap(g)(idx) 5135 return r 5136 5137 a = with_vmap(vmap) 5138 b = with_vmap(_fake_vmap) 5139 self.assertEqual(a, b) 5140 5141 @ops( 5142 filter(lambda op: "linalg" in op.name, op_db + additional_op_db), 5143 allowed_dtypes=(torch.float,), 5144 ) 5145 @skipOps( 5146 "TestVmapOperatorsOpInfo", 5147 "test_vmap_linalg_failure_1D_input", 5148 { 5149 xfail("linalg.vector_norm"), # can accept vector inputs 5150 xfail("linalg.norm"), # can accept vector inputs 5151 xfail("linalg.norm", "subgradients_at_zero"), # can accept vector inputs 5152 xfail("linalg.vander"), # can accept vector inputs 5153 skip( 5154 "linalg.multi_dot" 5155 ), # accepts list of tensor inputs, has its own special test 5156 xfail("linalg.vecdot"), 5157 # throws in vmap on CUDA 5158 # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got -2) 5159 # https://github.com/pytorch/pytorch/runs/8110653462?check_suite_focus=true 5160 # but it passes locally 5161 xfail("linalg.diagonal"), 5162 skip("linalg.matrix_norm", ""), 5163 skip("linalg.ldl_solve", ""), 5164 }, 5165 ) 5166 def test_vmap_linalg_failure_1D_input(self, device, dtype, op): 5167 for sample in op.sample_inputs(device, dtype, requires_grad=False): 5168 if sample.input.dim() != 2 or sample.input.shape[0] == 0: 5169 continue 5170 test_input = sample.input[ 5171 0 5172 ] # using the sample input avoids numerical inconsistency issues 5173 with self.assertRaisesRegex(RuntimeError, "dimension"): 5174 op(test_input, *sample.args, **sample.kwargs) 5175 5176 def op_wrapper(inp): 5177 return op(inp, *sample.args, **sample.kwargs) 5178 5179 # square inputs are more likely to pass linalg checks 5180 test_input = test_input.expand(test_input.shape[0], test_input.shape[0]) 5181 with self.assertRaisesRegex(RuntimeError, "dimension"): 5182 return vmap(op_wrapper)(test_input) 5183 5184 def test_vmap_multi_dot_failure_1D_input(self): 5185 # special exception for first and last tensors so making giving 3 items avoids special cases 5186 inputs = (torch.randn(3, 3), torch.randn(3), torch.randn(3, 3)) 5187 with self.assertRaisesRegex(RuntimeError, 
"tensor 1 must be 2D but got 1D"): 5188 torch.linalg.multi_dot(inputs) 5189 5190 # square inputs are more likely to pass linalg checks 5191 inputs = tuple(i.expand(i.shape[0], i.shape[0]) for i in inputs) 5192 with self.assertRaisesRegex(RuntimeError, "tensor 1 must be 2D but got 1D"): 5193 return vmap(torch.linalg.multi_dot)(inputs) 5194 5195 def test_vmap_escaped_error(self): 5196 escaped = None 5197 5198 def f(x): 5199 nonlocal escaped 5200 escaped = x 5201 return x**2 5202 5203 x = torch.randn([3, 3, 3, 3, 3]) 5204 vmap(f)(x) 5205 5206 common_message = ( 5207 r"your tensor may have escaped from inside a function being vmapped.*{0}.*" 5208 ) 5209 5210 # Note: These are not a complete set of tests for all possible functions calling 'vmap_check_escaped' 5211 5212 with self.assertRaisesRegex( 5213 RuntimeError, common_message.format("gen_vmap_plumbing") 5214 ): 5215 escaped.sin() 5216 5217 with self.assertRaisesRegex( 5218 RuntimeError, common_message.format("boxed_tensor_inputs_batch_rule") 5219 ): 5220 escaped.sin_() 5221 5222 with self.assertRaisesRegex( 5223 RuntimeError, common_message.format("gen_vmap_inplace_plumbing") 5224 ): 5225 escaped.mul_(1) 5226 5227 with self.assertRaisesRegex( 5228 RuntimeError, common_message.format("binary_cross_entropy_plumbing") 5229 ): 5230 torch.nn.functional.binary_cross_entropy(escaped, torch.zeros([3, 3, 3, 3])) 5231 5232 with self.assertRaisesRegex( 5233 RuntimeError, common_message.format("boxed_existing_bdim_all_batch_rule") 5234 ): 5235 torch.nn.functional.adaptive_max_pool2d(escaped, output_size=(1, 1)) 5236 5237 with self.assertRaisesRegex( 5238 RuntimeError, common_message.format("boxed_reduction_batch_rule") 5239 ): 5240 escaped.argmin() 5241 5242 a = torch.zeros([4, 4, 4, 4]) 5243 b = torch.zeros([4, 4, 4, 4], dtype=torch.long) 5244 with self.assertRaisesRegex( 5245 RuntimeError, common_message.format("boxed_all_tensors_have_optional_bdim") 5246 ): 5247 torch.ops.aten.adaptive_max_pool2d_backward(escaped, a, b) 5248 5249 vmap(f)(torch.tensor([[0, 0], [0, 0]], dtype=torch.int)) 5250 with self.assertRaisesRegex( 5251 RuntimeError, common_message.format("gen_vmap_plumbing_no_returns") 5252 ): 5253 torch.ops.aten._linalg_check_errors(escaped, "linalg.inv", is_matrix=False) 5254 5255 def test_vmap_with_anomaly_detection(self): 5256 with torch.autograd.set_detect_anomaly(True): 5257 x = torch.zeros(3) - 1 5258 5259 def fn(x): 5260 return x.sum() 5261 5262 per_sample_grad = vmap(grad(fn))(x) 5263 self.assertEqual(per_sample_grad, torch.ones_like(x)) 5264 5265 def bad_fn(x): 5266 return x.sqrt().sum() 5267 5268 err_msg = "Function 'SqrtBackward0' returned nan values in its 0th output." 5269 with self.assertRaisesRegex(RuntimeError, err_msg): 5270 vmap(grad(bad_fn))(x) 5271 5272 def test_searchsorted_bucketize(self, device): 5273 # OpInfo generates test with repeated samples in batch dim. 5274 # Thus we test explicitly with different samples across a batch. 
5275 5276 def test(): 5277 boundaries = torch.tensor( 5278 [[1, 4, 5, 7, 9], [1, 2, 6, 8, 10]], device=device 5279 ) 5280 v = torch.tensor(3, device=device) 5281 self.vmap_outplace_test(torch.searchsorted, (boundaries, v), {}, (0, None)) 5282 self.vmap_outplace_test(torch.bucketize, (v, boundaries), {}, (None, 0)) 5283 boundaries = torch.tensor([[1, 4, 5, 7, 9], [1, 2, 4, 8, 9]], device=device) 5284 v = torch.tensor([3, 4], device=device) 5285 self.vmap_outplace_test(torch.searchsorted, (boundaries, v), {}, (0, 0)) 5286 self.vmap_outplace_test(torch.bucketize, (v, boundaries), {}, (0, 0)) 5287 5288 test() 5289 5290 5291@markDynamoStrictTest 5292class TestRandomness(TestCase): 5293 def _reset_random(self, generator, orig_state, use_generator, seed): 5294 return ( 5295 generator.set_state(orig_state) 5296 if use_generator 5297 else torch.manual_seed(seed) 5298 ) 5299 5300 def _get_image(self, batched_input, batch_size, device): 5301 if batched_input == "first": 5302 return torch.ones([batch_size, 3, 3, 14, 14], device=device) 5303 if batched_input == "last": 5304 return torch.ones([3, 3, 14, 14, batch_size], device=device) 5305 assert batched_input == "none" 5306 return torch.ones([3, 3, 14, 14], device=device) 5307 5308 def _assert_all_slices_equal(self, tensor): 5309 expected = tensor[0] 5310 self.assertTrue((tensor == expected).all()) 5311 5312 def _assert_all_slices_unique(self, tensor): 5313 B0 = tensor.shape[0] 5314 slices_equal = vmap(vmap(lambda x, y: (x == y).all(), (0, None)), (None, 0))( 5315 tensor, tensor 5316 ) 5317 assert slices_equal.shape == (B0, B0) 5318 slices_equal.diagonal().zero_() 5319 self.assertEqual(slices_equal, torch.zeros_like(slices_equal)) 5320 5321 def _assert_throws_in_error_mode(self, fn, args, in_dims): 5322 with self.assertRaisesRegex( 5323 RuntimeError, r"called random operation while in randomness error mode" 5324 ): 5325 vmap(fn, in_dims=in_dims, randomness="error")(*args) 5326 5327 def _assert_throws_in_different_mode_inplace(self, fn, args, in_dims): 5328 with self.assertRaisesRegex( 5329 RuntimeError, r"different inplace randomness on an unbatched tensor" 5330 ): 5331 vmap(fn, in_dims=in_dims, randomness="different")(*args) 5332 5333 def _assert_throws_in_same_mode_batched(self, fn, args, in_dims): 5334 with self.assertRaisesRegex( 5335 RuntimeError, 5336 r"Vmap does not currently support same randomness with a batched tensor input", 5337 ): 5338 vmap(fn, in_dims=in_dims, randomness="same")(*args) 5339 5340 def _in_dims(self, *batched_strings): 5341 def get_in_dim(batched_string): 5342 if batched_string == "first": 5343 return 0 5344 if batched_string == "last": 5345 return -1 5346 assert batched_string == "none" 5347 return None 5348 5349 batched_strings = batched_strings + ( 5350 "first", 5351 ) # for the always batched as first dim dummy argument 5352 return tuple(get_in_dim(batched_string) for batched_string in batched_strings) 5353 5354 @parametrize("randomness", ["same", "different", "error"]) 5355 @parametrize("use_generator", [True, False]) 5356 def test_factory_ops(self, device, randomness, use_generator): 5357 generator = torch.Generator(device=device) 5358 orig_state = generator.get_state() 5359 kwargs = ( 5360 {"device": device, "generator": generator} 5361 if use_generator 5362 else {"device": device} 5363 ) 5364 ops = [ 5365 lambda _, shape: torch.randn(shape, **kwargs), 5366 lambda _, shape: torch.rand(shape, **kwargs), 5367 lambda _, shape: torch.randint(100, shape, **kwargs), 5368 lambda _, shape: torch.randint(5, 100, shape, 

    @parametrize("randomness", ["same", "different", "error"])
    @parametrize("use_generator", [True, False])
    def test_factory_ops(self, device, randomness, use_generator):
        generator = torch.Generator(device=device)
        orig_state = generator.get_state()
        kwargs = (
            {"device": device, "generator": generator}
            if use_generator
            else {"device": device}
        )
        ops = [
            lambda _, shape: torch.randn(shape, **kwargs),
            lambda _, shape: torch.rand(shape, **kwargs),
            lambda _, shape: torch.randint(100, shape, **kwargs),
            lambda _, shape: torch.randint(5, 100, shape, **kwargs),
            lambda _, shape: torch.normal(0.0, 1.0, shape, **kwargs),
        ]
        B0 = 4
        shape = (3, 3)
        seed = 1234567

        for op in ops:
            passed = torch.randn(B0, device=device)
            if randomness == "error":
                self._assert_throws_in_error_mode(
                    op, (passed, shape), in_dims=(0, None)
                )
                return

            generator = self._reset_random(generator, orig_state, use_generator, seed)
            vmap_result = vmap(op, in_dims=(0, None), randomness=randomness)(
                passed, shape
            )

            generator = self._reset_random(generator, orig_state, use_generator, seed)
            if randomness == "different":
                expected = op(passed, [B0, *shape])
                self._assert_all_slices_unique(vmap_result)
                self.assertEqual(vmap_result, expected)
            else:
                expected = op(passed, shape)
                self._assert_all_slices_equal(vmap_result)
                for i in range(B0):
                    self.assertEqual(vmap_result[i], expected)

    @parametrize("randomness", ["same", "different", "error"])
    @parametrize("use_generator", [True, False])
    def test_randperm(self, device, randomness, use_generator):
        # randperm needs a special case because it doesn't take a batch size
        B0 = 4
        seed = 1234567
        passed = torch.randn(B0, device=device)

        torch.manual_seed(seed)
        generator = torch.Generator(device=device)
        orig_state = generator.get_state()

        kwargs = (
            {"device": device, "generator": generator}
            if use_generator
            else {"device": device}
        )

        if randomness == "error":
            with self.assertRaisesRegex(
                RuntimeError, r"called random operation while in randomness error mode"
            ):
                vmap(lambda _: torch.randperm(10, **kwargs), randomness=randomness)(
                    passed
                )
            return

        vmap_result = vmap(
            lambda _: torch.randperm(10, **kwargs), randomness=randomness
        )(passed)
        generator = generator.set_state(orig_state)
        torch.manual_seed(seed)
        if randomness == "different":
            for i in range(B0):
                expected = torch.randperm(10, **kwargs)
                # RNG differs between eager and the dynamo trace on CUDA
                if TEST_WITH_TORCHDYNAMO and torch.device(device).type == "cuda":
                    self._assert_all_slices_unique(vmap_result)
                else:
                    self.assertEqual(vmap_result[i], expected)
        else:
            expected = torch.randperm(10, **kwargs)
            # RNG differs between eager and the dynamo trace on CUDA
            if TEST_WITH_TORCHDYNAMO and torch.device(device).type == "cuda":
                self._assert_all_slices_equal(vmap_result)
            else:
                for i in range(B0):
                    self.assertEqual(vmap_result[i], expected)
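
    # The pattern used throughout these tests: reset the RNG (generator state
    # or global seed) to the same point before the vmap call and again before
    # computing the eager `expected` value, so the two draws can be compared
    # elementwise whenever eager and vmap are expected to consume the RNG
    # identically.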

    @parametrize("randomness", ["error", "same", "different"])
    @parametrize("batched_input", ["first", "last", "none"])
    def test_dropout(self, device, randomness, batched_input):
        def op(t, ignored):
            return torch.nn.functional.dropout(torch.ones_like(t), training=True)

        B0 = 4
        always_batched = torch.randn((B0,))
        passed = self._get_image(batched_input, B0, device)
        in_dims = self._in_dims(batched_input)

        if randomness == "error":
            with self.assertRaisesRegex(
                RuntimeError, r"called random operation while in randomness error mode"
            ):
                vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
            return

        vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(
            passed, always_batched
        )

        # Check that the dropout rate is within bounds: dropout with p=0.5
        # scales surviving ones by 1 / (1 - p) = 2, so mean() / 2 estimates the
        # keep probability and should ideally be close to 0.5.
        p_estimate = vmap_result.mean() / 2
        self.assertTrue(p_estimate < 0.75)
        self.assertTrue(p_estimate > 0.25)

        if randomness == "different":
            self._assert_all_slices_unique(vmap_result)
            return

        assert randomness == "same"
        self._assert_all_slices_equal(vmap_result)

    @parametrize("randomness", ["error", "same", "different"])
    @parametrize("batched_input", ["first", "last", "none"])
    def test_alpha_dropout(self, device, randomness, batched_input):
        def op(t, ignored):
            return torch.nn.functional.alpha_dropout(torch.ones_like(t), training=True)

        B0 = 4
        always_batched = torch.randn((B0,))
        passed = self._get_image(batched_input, B0, device)
        in_dims = self._in_dims(batched_input)

        if randomness == "error":
            with self.assertRaisesRegex(
                RuntimeError, r"called random operation while in randomness error mode"
            ):
                vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
            return

        # It is unclear how to test the exact values of alpha dropout because
        # the docs seem wrong: https://github.com/pytorch/pytorch/issues/74004
        vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(
            passed, always_batched
        )
        if randomness == "different":
            self._assert_all_slices_unique(vmap_result)
            return

        assert randomness == "same"
        self._assert_all_slices_equal(vmap_result)

    @parametrize("randomness", ["error", "same", "different"])
    @parametrize("batched_input", ["first", "last", "none"])
    @parametrize("dim", [2, 3])
    def test_feature_dropout(self, device, randomness, batched_input, dim):
        def op(t, ignored):
            f = (
                torch.nn.functional.dropout2d
                if dim == 2
                else torch.nn.functional.dropout3d
            )
            return f(torch.ones_like(t), training=True)

        B0 = 4
        always_batched = torch.randn((B0,))
        passed = self._get_image(batched_input, B0, device)
        if dim == 3:
            unsqueeze_dim = -2 if batched_input == "last" else -1
            passed = passed.unsqueeze(unsqueeze_dim)
        in_dims = self._in_dims(batched_input)

        if randomness == "error":
            with self.assertRaisesRegex(
                RuntimeError, r"called random operation while in randomness error mode"
            ):
                vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
            return

        vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(
            passed, always_batched
        )

        # Check the "feature" pattern
        dims = [-1, -2] if dim == 2 else [-1, -2, -3]
        planes_numel = (
            2
            * vmap_result.numel()
            / (vmap_result.shape[0] * vmap_result.shape[1] * vmap_result.shape[2])
        )
        planes = vmap_result.sum(dims)
        result = (planes == 0) ^ (planes == planes_numel)
        self.assertEqual(result, torch.ones_like(result, dtype=torch.bool))

        if randomness == "different":
            self._assert_all_slices_unique(vmap_result)
            return

        assert randomness == "same"
        self._assert_all_slices_equal(vmap_result)
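
    # Feature dropout (dropout2d/3d and the alpha variant below) zeroes entire
    # channels ("planes") rather than individual elements, so each plane of
    # the result should sum to either the dropped value or the full kept
    # value; the XOR checks in these feature-dropout tests assert that exactly
    # one of the two holds for every plane.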

    @parametrize("randomness", ["error", "same", "different"])
    @parametrize("batched_input", ["first", "last", "none"])
    def test_feature_alpha_dropout(self, device, randomness, batched_input):
        def op(t, ignored):
            return torch.nn.functional.feature_alpha_dropout(
                torch.ones_like(t), training=True
            )

        B0 = 4
        always_batched = torch.randn((B0,))
        passed = self._get_image(batched_input, B0, device)
        unsqueeze_dim = -2 if batched_input == "last" else -1
        passed = passed.unsqueeze(unsqueeze_dim)
        in_dims = self._in_dims(batched_input)

        if randomness == "error":
            with self.assertRaisesRegex(
                RuntimeError, r"called random operation while in randomness error mode"
            ):
                vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
            return

        vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(
            passed, always_batched
        )

        # It is unclear how to test the exact values of alpha dropout because
        # the docs seem wrong: https://github.com/pytorch/pytorch/issues/74004

        # Check the "feature" pattern
        dims = [-1, -2, -3]
        planes = vmap_result.sum(dims)
        max_elt = planes.max()
        min_elt = planes.min()
        result = (planes == min_elt) ^ (planes == max_elt)
        self.assertEqual(result, torch.ones_like(result, dtype=torch.bool))

        if randomness == "different":
            self._assert_all_slices_unique(vmap_result)
            return

        assert randomness == "same"
        self._assert_all_slices_equal(vmap_result)

    @parametrize("randomness", ["error", "same", "different"])
    @parametrize("batched_input", ["first", "last", "none"])
    def test_like_functions(self, device, randomness, batched_input):
        seed = 1234567
        supported_ops = [
            lambda t, _: torch.randint_like(t, 20),
            lambda t, _: torch.randint_like(t, 0, 20),
            lambda t, _: torch.rand_like(t),
            lambda t, _: torch.randn_like(t),
        ]
        B0 = 4

        for op in supported_ops:
            always_batched = torch.randn(B0)
            passed = self._get_image(batched_input, B0, device)
            in_dims = self._in_dims(batched_input)

            if randomness == "error":
                with self.assertRaisesRegex(
                    RuntimeError,
                    r"called random operation while in randomness error mode",
                ):
                    vmap(op, in_dims=in_dims, randomness=randomness)(
                        passed, always_batched
                    )
                return

            torch.manual_seed(seed)
            vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(
                passed, always_batched
            )

            torch.manual_seed(seed)

            if batched_input == "last":
                passed = passed.movedim(-1, 0)
            if randomness == "different":
                if batched_input == "none":
                    passed = passed.expand(B0, *passed.shape)
                expected = op(passed, 0)

                self._assert_all_slices_unique(vmap_result)
                # RNG differs between eager and the dynamo trace on CUDA
                if not (TEST_WITH_TORCHDYNAMO and torch.device(device).type == "cuda"):
                    self.assertEqual(expected, vmap_result)
                return

            assert randomness == "same"
            if batched_input != "none":
                passed = passed[0]
            expected = op(passed, 0)
            self._assert_all_slices_equal(vmap_result)
            # RNG differs between eager and the dynamo trace on CUDA
            if not (TEST_WITH_TORCHDYNAMO and torch.device(device).type == "cuda"):
                for i in range(B0):
                    self.assertEqual(expected, vmap_result[i])
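
    # For the in-place ops below, "different" randomness with an unbatched
    # input cannot work: B0 independent draws would all have to be written
    # into the single shared tensor, so vmap raises instead (see
    # _assert_throws_in_different_mode_inplace).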

    @parametrize("use_generator", [True, False])
    @parametrize("randomness", ["error", "same", "different"])
    @parametrize("batched_input", ["first", "last", "none"])
    def test_random_unary_inplace(
        self, device, use_generator, randomness, batched_input
    ):
        generator = torch.Generator(device=device)
        orig_state = generator.get_state()
        kwargs = {"generator": generator} if use_generator else {}
        ops = [
            lambda t, _: t.random_(**kwargs),
            lambda t, _: t.random_(100, **kwargs),
            lambda t, _: t.random_(-5, 100, **kwargs),
            lambda t, _: t.normal_(**kwargs),
            lambda t, _: t.bernoulli_(**kwargs),
            lambda t, _: t.cauchy_(**kwargs),
            lambda t, _: t.exponential_(**kwargs),
            lambda t, _: t.geometric_(0.5, **kwargs),
            lambda t, _: t.log_normal_(**kwargs),
            lambda t, _: t.uniform_(**kwargs),
        ]
        B0 = 4
        seed = 1234567
        in_dims = self._in_dims(batched_input)

        for op in ops:
            # because of in-place updates, clone the inputs
            always_batched = torch.randn(B0, device=device)
            passed = self._get_image(batched_input, B0, device)
            passed_expected = passed.clone()

            if randomness == "error":
                self._assert_throws_in_error_mode(
                    op, (passed, always_batched), in_dims=in_dims
                )
                return
            if randomness == "different" and batched_input == "none":
                self._assert_throws_in_different_mode_inplace(
                    op, (passed, always_batched), in_dims=in_dims
                )
                return

            generator = self._reset_random(generator, orig_state, use_generator, seed)
            vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(
                passed, always_batched
            )

            if batched_input == "last":
                passed_expected = passed_expected.movedim(-1, 0)
            generator = self._reset_random(generator, orig_state, use_generator, seed)
            if randomness == "different":
                expected = op(passed_expected, always_batched)
                self._assert_all_slices_unique(vmap_result)
                self.assertEqual(vmap_result, expected)
            else:
                if batched_input != "none":
                    passed_expected = passed_expected[
                        0
                    ].clone()  # bug in pytorch: normal_ on views doesn't work
                expected = op(passed_expected, always_batched)
                self._assert_all_slices_equal(vmap_result)
                for i in range(B0):
                    self.assertEqual(vmap_result[i], expected)
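
    # Expected values for the in-place ops are built on a clone of the input:
    # for a "last"-batched input the clone is moved so the batch dim comes
    # first (matching vmap's output layout) before applying the op in eager
    # mode.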

    @parametrize("use_generator", [True, False])
    @parametrize("randomness", ["error", "same", "different"])
    @parametrize("batched_input", ["first", "last", "none"])
    @parametrize("batched_probability", ["first", "last", "none"])
    def test_bernoulli_in_place(
        self, device, use_generator, randomness, batched_input, batched_probability
    ):
        B0 = 4
        seed = 1234567
        generator = torch.Generator(device=device)
        orig_state = generator.get_state()
        kwargs = {"generator": generator} if use_generator else {}
        in_dims = self._in_dims(batched_input, batched_probability)

        def op(t, p, ignored):
            return t.bernoulli_(p, **kwargs)

        # because of in-place updates, clone the inputs
        always_batched = torch.randn(B0, device=device)
        input = self._get_image(batched_input, B0, device)
        input_expected = input.clone()
        probability = self._get_image(batched_probability, B0, device) - 0.5

        if randomness == "error":
            self._assert_throws_in_error_mode(
                op, (input, probability, always_batched), in_dims=in_dims
            )
            return
        if randomness == "same" and batched_probability != "none":
            self._assert_throws_in_same_mode_batched(
                op, (input, probability, always_batched), in_dims=in_dims
            )
            return
        if batched_input == "none" and batched_probability != "none":
            regex = r"there exists a Tensor `other` in extra_args that has more elements than `self`"
            with self.assertRaisesRegex(RuntimeError, regex):
                vmap(op, in_dims=in_dims, randomness=randomness)(
                    input, probability, always_batched
                )
            return
        if randomness == "different" and batched_input == "none":
            self._assert_throws_in_different_mode_inplace(
                op, (input, probability, always_batched), in_dims=in_dims
            )
            return

        self._reset_random(generator, orig_state, use_generator, seed)
        vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(
            input, probability, always_batched
        )

        self._reset_random(generator, orig_state, use_generator, seed)
        if batched_input == "last":
            input_expected = input_expected.movedim(-1, 0)
        if batched_probability == "last":
            probability = probability.movedim(-1, 0)
        if randomness == "different":
            expected = op(input_expected, probability, always_batched)
            self._assert_all_slices_unique(vmap_result)
            self.assertEqual(vmap_result, expected)
        else:
            if batched_input != "none":
                input_expected = input_expected[0]
            expected = op(input_expected, probability, always_batched)
            self._assert_all_slices_equal(vmap_result)
            for i in range(B0):
                self.assertEqual(vmap_result[i], expected)

    @parametrize("use_generator", [True, False])
    @parametrize("randomness", ["error", "same", "different"])
    @parametrize("batched_input", ["first", "last", "none"])
    @parametrize("batched_other", ["first", "last", "none"])
    def test_random_binary_out_of_place(
        self, device, use_generator, randomness, batched_input, batched_other
    ):
        generator = torch.Generator(device=device)
        orig_state = generator.get_state()
        kwargs = {"generator": generator} if use_generator else {}
        ops = [
            lambda t, o, _: torch.normal(t, o, **kwargs),
            lambda t, o, _: torch.binomial(t, (o - 0.5), **kwargs),
        ]

        B0 = 4
        seed = 1234567
        in_dims = self._in_dims(batched_input, batched_other)

        for op in ops:
            always_batched = torch.randn(B0, device=device)
            input = self._get_image(batched_input, B0, device)
            other = self._get_image(batched_other, B0, device)

            if randomness == "error":
                self._assert_throws_in_error_mode(
                    op, (input, other, always_batched), in_dims=in_dims
                )
                return
            if randomness == "same" and (
                batched_input != "none" or batched_other != "none"
            ):
                self._assert_throws_in_same_mode_batched(
                    op, (input, other, always_batched), in_dims=in_dims
                )
                return

            generator = self._reset_random(generator, orig_state, use_generator, seed)
            vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(
                input, other, always_batched
            )

            if batched_input == "last":
                input = input.movedim(-1, 0)
            if batched_other == "last":
                other = other.movedim(-1, 0)

            generator = self._reset_random(generator, orig_state, use_generator, seed)
            if randomness == "different":
                if batched_input == "none":
                    input = input.expand(B0, *input.shape)
                expected = op(input, other, always_batched)
                self._assert_all_slices_unique(vmap_result)
                self.assertEqual(vmap_result, expected)
            else:
                assert batched_input == "none" and batched_other == "none"
                expected = op(input, other, always_batched)
                self._assert_all_slices_equal(vmap_result)
                for i in range(B0):
                    self.assertEqual(vmap_result[i], expected)
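
    # "same" randomness combined with a batched input is rejected by design:
    # per the error message asserted above, vmap does not currently support
    # same randomness with a batched tensor input, hence the
    # _assert_throws_in_same_mode_batched checks in the out-of-place tests.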

    @parametrize("use_generator", [True, False])
    @parametrize("randomness", ["error", "same", "different"])
    @parametrize("batched_input", ["first", "last", "none"])
    def test_random_unary_out_of_place(
        self, device, use_generator, randomness, batched_input
    ):
        generator = torch.Generator(device=device)
        orig_state = generator.get_state()
        kwargs = {"generator": generator} if use_generator else {}
        ops = [
            lambda t, _: torch.normal(0.0, torch.abs(t), **kwargs),
            lambda t, _: torch.normal(t, 1.0, **kwargs),
            lambda t, _: torch.bernoulli(t - 0.5, **kwargs),
            lambda t, _: torch.bernoulli(t, 0.5, **kwargs),
            lambda t, _: torch._standard_gamma(t, **kwargs),
            lambda t, _: torch._sample_dirichlet(t, **kwargs),
            lambda t, _: torch.poisson(t, **kwargs),
        ]

        B0 = 4
        seed = 1234567
        in_dims = self._in_dims(batched_input)

        for op in ops:
            always_batched = torch.randn(B0, device=device)
            passed = self._get_image(batched_input, B0, device)
            if randomness == "error":
                self._assert_throws_in_error_mode(
                    op, (passed, always_batched), in_dims=in_dims
                )
                return
            if randomness == "same" and batched_input != "none":
                self._assert_throws_in_same_mode_batched(
                    op, (passed, always_batched), in_dims=in_dims
                )
                return

            generator = self._reset_random(generator, orig_state, use_generator, seed)
            vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(
                passed, always_batched
            )

            generator = self._reset_random(generator, orig_state, use_generator, seed)
            if randomness == "different":
                if batched_input == "none":
                    passed = passed.expand(B0, *passed.shape)
                if batched_input == "last":
                    passed = passed.movedim(-1, 0)
                expected = op(passed, always_batched)
                self._assert_all_slices_unique(vmap_result)
                self.assertEqual(vmap_result, expected)
            else:
                expected = op(passed, always_batched)
                self._assert_all_slices_equal(vmap_result)
                for i in range(B0):
                    self.assertEqual(vmap_result[i], expected)
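
    # torch.multinomial only accepts 1D or 2D inputs, so the test below
    # flattens the image-shaped sample into [N], [B, N], [B0, N], or
    # [B0, B, N], depending on whether the call itself is 2D ("batched_call")
    # and whether the input carries a vmap batch dim.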

    @parametrize("use_generator", [True, False])
    @parametrize("randomness", ["error", "same", "different"])
    @parametrize("batched_call", [True, False])
    @parametrize("batched_input", ["first", "last", "none"])
    def test_multinomial(
        self, device, use_generator, randomness, batched_call, batched_input
    ):
        def flatten_input(input, batch_call, batch_location):
            if batch_call and batch_location != "none":
                final_size = 3  # [B0, B, N]
            elif not batch_call and batch_location == "none":
                final_size = 1  # [N]
            else:
                final_size = 2  # [B0, N] or [B, N]

            start_idx = final_size - 1
            end_idx = -1
            if batch_location == "last":
                start_idx -= 1
                end_idx -= 1  # gets to the correct final size because negative indices are used

            ret = input.flatten(start_idx, end_idx)
            assert ret.dim() == final_size
            return ret

        def op(input, _):
            return torch.multinomial(input, 10, **kwargs)

        generator = torch.Generator(device=device)
        orig_state = generator.get_state()
        kwargs = {"generator": generator} if use_generator else {}

        B0 = 4
        seed = 1234567
        in_dims = self._in_dims(batched_input)

        always_batched = torch.randn(B0, device=device)
        passed = self._get_image(batched_input, B0, device)
        passed = flatten_input(passed, batched_call, batched_input)
        if randomness == "error":
            self._assert_throws_in_error_mode(
                op, (passed, always_batched), in_dims=in_dims
            )
            return
        if randomness == "same" and batched_input != "none":
            self._assert_throws_in_same_mode_batched(
                op, (passed, always_batched), in_dims=in_dims
            )
            return

        generator = self._reset_random(generator, orig_state, use_generator, seed)
        vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(
            passed, always_batched
        )

        generator = self._reset_random(generator, orig_state, use_generator, seed)

        if randomness == "different":
            if batched_input == "none":
                passed = passed.expand(B0, *passed.shape)
            if batched_input == "last":
                passed = passed.movedim(-1, 0)
            orig_passed_size = passed.shape[:2] if batched_call else passed.shape[:1]
            passed = passed.flatten(0, 1) if batched_call else passed
            expected = op(passed, always_batched)
            expected = expected.reshape(*orig_passed_size, 10)
            self._assert_all_slices_unique(vmap_result)
            self.assertEqual(vmap_result, expected)
        else:
            expected = op(passed, always_batched)
            self._assert_all_slices_equal(vmap_result)
            for i in range(B0):
                self.assertEqual(vmap_result[i], expected)

    def test_unsupported_random(self, device):
        x = torch.randn(3, device=device)
        y = x.abs()
        z = x.abs()
        with self.assertRaisesRegex(RuntimeError, "calling out variants"):

            def f(x):
                return torch.randn(3, device=device, out=y)

            vmap(f, randomness="same")(x)
        with self.assertRaisesRegex(RuntimeError, "calling out variants"):

            def f(x0, x1):
                return torch.normal(x, y, out=x)

            vmap(f, randomness="same")(z, z)
        with self.assertRaisesRegex(RuntimeError, "do not yet support"):

            def f(z):
                return torch.rrelu(x)

            vmap(f, randomness="same")(z)

    @parametrize("in_dim", [0, 1, 2])
    @parametrize("out_dim", [0, 1, 2])
    def test_chunk_vmap(self, in_dim, out_dim):
        randomness = "different"

        x = torch.randn(4, 5, 6)

        def f(x):
            y = x.sin() + torch.rand_like(x)
            return y

        for chunks in [1, 2, 3, 4, 7, 10, 16]:
            output = chunk_vmap(
                f,
                in_dims=in_dim,
                out_dims=out_dim,
                randomness=randomness,
                chunks=chunks,
            )(x)
            self._assert_all_slices_unique(output)

    @parametrize("in_dim", [0, 1, 2])
    @parametrize("out_dim", [0, 1, 2])
    def test_vmap_chunksize(self, in_dim, out_dim):
        randomness = "different"

        x = torch.randn(4, 5, 6)

        def f(x):
            y = x.sin() + torch.rand_like(x)
            return y

        for chunk_size in [1, 2, 3, 4, 7, 10, 16, 100]:
            output = vmap(
                f,
                in_dims=in_dim,
                out_dims=out_dim,
                randomness=randomness,
                chunk_size=chunk_size,
            )(x)
            self._assert_all_slices_unique(output)

    def test_jacfwd_with_random(self):
        # behavior checks are above; this just checks that jacfwd respects
        # the randomness param

        x = torch.rand(3, 4)
        with self.assertRaisesRegex(
            RuntimeError, r"called random operation while in randomness error mode"
        ):
            jacfwd(torch.bernoulli)(x)

        # x isn't batched, so use bernoulli since it doesn't do in-place randomness
        jacfwd(torch.bernoulli, randomness="same")(x)
        jacfwd(torch.bernoulli, randomness="different")(x)
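
    # The next test covers randomness applied to an argument that is not
    # batched at all (in_dims=(0, None)); it only checks that the call does
    # not raise for "same"/"different" randomness, per the issue referenced
    # inside the test.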

    @parametrize("randomness", ["error", "same", "different"])
    def test_dropout_unbatched(self, device, randomness):
        x = torch.randn(3, device=device)
        y = torch.randn(1, 3, device=device)

        def fn(x, y):
            # output from dropout should be a Tensor[B, 1, 3] (B=3)
            return x + torch.nn.functional.dropout(y, p=0.5).mean(1)

        # We just verify that this doesn't raise an error for
        # `same` and `different` randomness.
        # Ref: https://github.com/pytorch/pytorch/issues/92283
        context = (
            self.assertRaises(RuntimeError)
            if randomness == "error"
            else contextlib.nullcontext()
        )
        with context:
            vmap(fn, in_dims=(0, None), randomness=randomness)(x, y)


@markDynamoStrictTest
class TestTransformFailure(TestCase):
    @skipIfTorchDynamo()
    @parametrize(
        "transform",
        ["vmap", "grad", "grad_and_value", "vjp", "jvp", "jacrev", "jacfwd"],
    )
    def test_fails_with_autograd_function(self, device, transform):
        failed_build_envs = ("linux-focal-py3.8-clang10", "linux-focal-py3.11-clang10")
        if (
            device == "cpu"
            and transform in ["grad", "vmap"]
            and TEST_WITH_TORCHDYNAMO
            and os.getenv("BUILD_ENVIRONMENT", "") in failed_build_envs
        ):
            raise unittest.SkipTest(
                "Unexpected successes on focal with dynamo,"
                + " see https://github.com/pytorch/pytorch/issues/107173"
            )

        class Test(torch.autograd.Function):
            @staticmethod
            def forward(_, input):
                return input

            @staticmethod
            def backward(_, grad_input):
                return grad_input

        transform = getattr(functorch, transform)

        def f(x):
            return Test.apply(x)

        if transform in (grad, grad_and_value):
            input = torch.tensor(4.0)
        else:
            input = torch.randn(5)

        if transform == vjp:
            transform = functools.partial(transform, f)
        elif transform == jvp:
            input = (input,)
            transform = functools.partial(transform, f, input)
        else:
            transform = transform(f)

        with self.assertRaisesRegex(RuntimeError, "autograd.Function"):
            transform(input)


@markDynamoStrictTest
class TestVmapDeviceType(Namespace.TestVmapBase):
    def _vmap_test(self, *args, **kwargs):
        return _vmap_test(self, *args, **kwargs)

    def test__is_all_true(self, device):
        def test():
            def f(x, *, expected_result):
                result = torch.ops.aten._is_all_true(x)
                self.assertFalse(torch._C._functorch.is_batchedtensor(result))
                self.assertEqual(result.shape, torch.Size([]))
                self.assertEqual(result.item(), expected_result)
                return result

            x = torch.rand(10, device=device)
            vmap(f)(x >= 0, expected_result=True)
            vmap(f)(x < 0, expected_result=False)

            x[random.choice(range(10))] *= -1
            vmap(f)(x >= 0, expected_result=False)
            vmap(f)(x < 0, expected_result=False)

            x = -torch.rand(10, device=device)
            vmap(f)(x > 0, expected_result=False)
            vmap(f)(x <= 0, expected_result=True)

        check_vmap_fallback(self, test, torch._is_all_true)

    def test__is_any_true(self, device):
        def test():
            def f(x, *, expected_result):
                result = torch.ops.aten._is_any_true(x)
                self.assertFalse(torch._C._functorch.is_batchedtensor(result))
                self.assertEqual(result.shape, torch.Size([]))
                self.assertEqual(result.item(), expected_result)
                return result

            x = torch.zeros(10, device=device, dtype=torch.bool)
            vmap(f)(x > 0, expected_result=False)

            x[5] = True
            vmap(f)(x > 0, expected_result=True)
            vmap(f)(x[1::2], expected_result=True)
            vmap(f)(x[0::2], expected_result=False)

        check_vmap_fallback(self, test, torch._is_any_true)
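
    # torch._test_check_tensor exercises TORCH_CHECK_TENSOR_ALL under vmap:
    # per the test below, the check passes only when every element of the
    # (batched) boolean input is True, and the asserted error message is the
    # one produced by that check.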

    def test_check_tensor(self, device):
        def test():
            test_sizes = [
                (1,),
                (10,),
                (1, 1),
                (1, 10),
                (10, 1),
                (10, 10),
                (1, 1, 1),
                (10, 1, 1),
                (1, 10, 1),
                (10, 10, 10),
            ]

            def check_gte_0(t):
                return torch._test_check_tensor(t >= 0)

            error_message = "Test message for TORCH_CHECK_TENSOR_ALL"

            for size in test_sizes:
                t_all_gte_0 = torch.rand(size, device=device)
                t_all_lt_0 = t_all_gte_0 - 1

                vmap(check_gte_0)(t_all_gte_0)

                if len(size) >= 2:
                    vmap(vmap(check_gte_0))(t_all_gte_0)

                with self.assertRaisesRegex(RuntimeError, error_message):
                    vmap(check_gte_0)(t_all_lt_0)

                if len(size) >= 2:
                    with self.assertRaisesRegex(RuntimeError, error_message):
                        vmap(vmap(check_gte_0))(t_all_lt_0)

                if t_all_gte_0.numel() > 1:
                    t_all_gte_0_but_one = t_all_gte_0.clone()
                    idx = (random.choice(range(dim_size)) for dim_size in size)
                    t_all_gte_0_but_one[(..., *idx)] = -1

                    with self.assertRaisesRegex(RuntimeError, error_message):
                        vmap(check_gte_0)(t_all_gte_0_but_one)

                    if len(size) >= 2:
                        with self.assertRaisesRegex(RuntimeError, error_message):
                            vmap(vmap(check_gte_0))(t_all_gte_0_but_one)

        check_vmap_fallback(self, test, torch._test_check_tensor)


@markDynamoStrictTest
class TestVmapNestedTensor(Namespace.TestVmapBase):
    def _vmap_test(self, *args, **kwargs):
        return _vmap_test(self, *args, **kwargs)

    # dims should be something like [5, None, 10], with None indicating that a
    # random ragged structure should be used
    def _create_nt(self, dims, device):
        sizes = [
            [
                d if d is not None else torch.randint(2, 10, size=(1,)).item()
                for d in dims[1:]
            ]
            for d in range(dims[0])
        ]
        return torch.nested.nested_tensor(
            [torch.randn(*size) for size in sizes], device=device
        )
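
    # For example, _create_nt([3, None, 2], device=device) builds a nested
    # tensor with 3 components, each of shape (r_i, 2), where every r_i is an
    # independently drawn ragged length in [2, 10).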

    # Creates an NT matching another NT's number of components and
    # shape / ragged structure for all dims specified to be -1.
    def _nt_from_similar(self, other, dims):
        assert len(dims) == other.dim()
        assert dims[0] == -1 or dims[0] == other.size(0)

        ret_sizes = []
        for t in other.unbind():
            other_size = t.shape
            ret_size = []
            for i, d in enumerate(dims[1:]):
                if d == -1:
                    ret_size.append(other_size[i])
                else:
                    ret_size.append(d)
            ret_sizes.append(ret_size)

        return torch.nested.nested_tensor(
            [torch.randn(*size) for size in ret_sizes], device=other.device
        )

    @allowVmapFallbackUsage
    def test_fallback_unary(self, device):
        def f(x):
            return x.sin() * 5.0 + 4.0

        nt = self._create_nt([4, None, 3], device=device)
        self._vmap_test(f, (nt,))

    @allowVmapFallbackUsage
    def test_fallback_binary(self, device):
        def f(x, y):
            return x @ y

        x = self._create_nt([5, None, 3], device=device)
        y = self._create_nt([5, 3, None], device=device)
        self._vmap_test(f, (x, y))

    @allowVmapFallbackUsage
    def test_fallback_binary_nt_and_unbatched_dense(self, device):
        def f(x, y):
            return x @ y

        x = self._create_nt([5, None, 3], device=device)
        y = torch.randn(3, 4, device=device)
        self._vmap_test(f, (x, y), in_dims=(0, None))

    @allowVmapFallbackUsage
    def test_fallback_binary_nt_and_batched_dense(self, device):
        def f(x, y):
            return x @ y

        x = self._create_nt([5, None, 3], device=device)
        y = torch.randn(5, 3, 4, device=device)
        self._vmap_test(f, (x, y))

    def test_nt_acts_as_dense_in_vmap(self, device):
        def f(x):
            assert not x.is_nested
            return x

        x = self._create_nt([5, None, 3], device=device)
        self._vmap_test(f, (x,))

    def test_cat_batching_rule(self, device):
        def f(x, y, dim):
            return torch.cat([x, y], dim=dim)

        # Different nested structure, same other dims
        x = self._create_nt([3, None, 2], device=device)
        y = self._create_nt([3, None, 2], device=device)
        self._vmap_test(functools.partial(f, dim=0), (x, y))

        x = self._create_nt([3, 2, None], device=device)
        y = self._create_nt([3, 2, None], device=device)
        self._vmap_test(functools.partial(f, dim=1), (x, y))

        # Same nested structure, different other dims
        x = self._create_nt([3, 2, None], device=device)
        y = self._nt_from_similar(x, [-1, 4, -1])
        self._vmap_test(functools.partial(f, dim=0), (x, y))

        x = self._create_nt([3, None, 2], device=device)
        y = self._nt_from_similar(x, [-1, -1, 4])
        self._vmap_test(functools.partial(f, dim=1), (x, y))

    # .shape calls don't work on NTs
    # TODO: Fix this somehow?
    @unittest.expectedFailure
    def test_shape_call(self, device):
        def f(x):
            x.shape[0]
            return x

        x = self._create_nt([3, None, 2], device=device)
        self._vmap_test(f, (x,))

    def test_nt_with_nonzero_in_dim_raises(self, device):
        def f(x):
            return x

        x = self._create_nt([3, None, 2], device=device)
        with self.assertRaisesRegex(
            RuntimeError, "Nested tensors can only be vmapped over dim=0"
        ):
            vmap(f, in_dims=2)(x)

    def test_nt_with_nonzero_out_dim_raises(self, device):
        def f(x):
            return x

        x = self._create_nt([3, None, 2], device=device)
        with self.assertRaisesRegex(
            RuntimeError, "Nested tensors can only be vmapped over dim=0"
        ):
            vmap(f, out_dims=2)(x)

    def test_fallback_with_nt_and_batched_dense_with_nonzero_bdim_raises(self, device):
        def f(x, y):
            return x @ y

        x = self._create_nt([5, None, 3], device=device)
        y = torch.randn(3, 5, 4, device=device)

        with self.assertRaisesRegex(
            RuntimeError,
            "Fallback not supported for mixed nested / non-nested arguments without bdim=0",
        ):
            vmap(f, in_dims=(0, 1))(x, y)

    def test_multilevel_vmap_raises(self, device):
        def f(x):
            return x.sin() * 4.0 + 3.0

        x = self._create_nt([2, 2, 2, None], device=device)

        with self.assertRaisesRegex(
            RuntimeError, "Only one level of vmap is supported"
        ):
            vmap(vmap(f))(x)

        with self.assertRaisesRegex(
            RuntimeError, "Only one level of vmap is supported"
        ):
            vmap(vmap(vmap(f)))(x)


only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestVmapOperatorsOpInfo, globals(), only_for=only_for)

instantiate_device_type_tests(
    TestVmapBatchedGradient,
    globals(),
    only_for=only_for,
)
instantiate_device_type_tests(TestTransformFailure, globals(), only_for=only_for)
instantiate_device_type_tests(TestRandomness, globals(), only_for=only_for)
instantiate_device_type_tests(TestVmapDeviceType, globals(), only_for=only_for)
instantiate_device_type_tests(TestVmapNestedTensor, globals(), only_for=only_for)

if __name__ == "__main__":
    run_tests()