1# Owner(s): ["module: vmap"] 2 3import functools 4import itertools 5import types 6import warnings 7 8import torch 9import torch.nn.functional as F 10from torch import Tensor 11from torch._vmap_internals import vmap 12from torch.testing._internal.common_device_type import instantiate_device_type_tests 13from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase 14 15 16FALLBACK_REGEX = r"There is a performance drop" 17 18 19class EnableVmapFallbackWarnings: 20 def __enter__(self): 21 self.prev_state = torch._C._debug_only_are_vmap_fallback_warnings_enabled() 22 torch._C._debug_only_display_vmap_fallback_warnings(True) 23 24 def __exit__(self, *ignored): 25 torch._C._debug_only_display_vmap_fallback_warnings(self.prev_state) 26 27 28class TestVmapAPILegacy(TestCase): 29 def test_non_tensor_output_raises(self): 30 with self.assertRaisesRegex( 31 ValueError, "got type <class 'float'> as the return" 32 ): 33 output = vmap(lambda x: 3.14)(torch.ones(3)) 34 35 def multiple_outputs(x): 36 return x, 3 37 38 with self.assertRaisesRegex(ValueError, "got type <class 'int'> for return 1"): 39 vmap(multiple_outputs)(torch.ones(3)) 40 41 def test_different_map_dim_size_raises(self): 42 x = torch.randn(2) 43 y = torch.randn(3) 44 expected_msg = ( 45 "Expected all tensors to have the same size in the mapped dimension" 46 ) 47 with self.assertRaisesRegex(ValueError, expected_msg): 48 vmap(torch.mul)(x, y) 49 with self.assertRaisesRegex(ValueError, expected_msg): 50 vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y)) 51 with self.assertRaisesRegex(ValueError, expected_msg): 52 vmap(lambda z: z["x"] + z["y"], in_dims=({"x": 0, "y": 0},))( 53 {"x": x, "y": y} 54 ) 55 56 def test_func_with_no_inputs(self): 57 expected_msg = "got no inputs" 58 59 def foo(): 60 return torch.randn(3) 61 62 def bar(x): 63 return torch.randn(3) 64 65 with self.assertRaisesRegex(ValueError, expected_msg): 66 vmap(foo)() 67 68 with self.assertRaisesRegex(ValueError, expected_msg): 69 vmap(bar)() 70 71 def test_constant_function(self): 72 output = vmap(lambda x: torch.tensor(3.14))(torch.ones(3)) 73 self.assertEqual(output, torch.tensor([3.14, 3.14, 3.14])) 74 75 def test_single_input(self): 76 x = torch.randn(2, 3) 77 78 def square(x): 79 return x * x 80 81 output = vmap(square)(x) 82 self.assertEqual(output, x * x) 83 84 def test_multiple_inputs(self): 85 x = torch.randn(2, 3) 86 y = torch.randn(2, 3) 87 output = vmap(torch.mul)(x, y) 88 self.assertEqual(output, x * y) 89 90 def test_multiple_outputs(self): 91 def foo(x): 92 return x * x, x * x * x 93 94 x = torch.randn(3) 95 outputs = vmap(foo)(x) 96 self.assertEqual(outputs[0], x * x) 97 self.assertEqual(outputs[1], x * x * x) 98 99 def test_multiple_outputs_error_cases(self): 100 # This is the same thing as 101 # def returns_tuple_of_tensors(x): 102 # return x, x 103 def returns_tuple_of_tensors(x): 104 return (x, x) 105 106 def returns_list_of_two_tensors(x): 107 return [x, x] 108 109 def returns_list_of_one_tensor(x): 110 return [x] 111 112 x = torch.randn(3) 113 114 # should not throw 115 vmap(returns_tuple_of_tensors)(x) 116 117 # jax supports these, but we don't yet 118 msg = "must only return Tensors, got type <class 'list'>" 119 with self.assertRaisesRegex(ValueError, msg): 120 vmap(returns_list_of_two_tensors)(x) 121 with self.assertRaisesRegex(ValueError, msg): 122 vmap(returns_list_of_one_tensor)(x) 123 124 def test_nested_with_same_map_dim(self): 125 x = torch.randn(2, 3, 5) 126 y = torch.randn(2, 3, 5) 127 output = vmap(vmap(torch.mul))(x, y) 128 self.assertEqual(output, x * y) 129 130 output = vmap(vmap(vmap(torch.mul)))(x, y) 131 self.assertEqual(output, x * y) 132 133 def test_nested_with_different_map_dim(self): 134 x = torch.randn(2, 3) 135 y = torch.randn(5, 3) 136 output = vmap(lambda x: vmap(lambda y: x * y)(y))(x) 137 self.assertEqual(output.shape, (2, 5, 3)) 138 self.assertEqual(output, x.view(2, 1, 3) * y) 139 140 z = torch.randn(7, 3) 141 output = vmap(lambda x: vmap(lambda y: vmap(lambda z: x * y * z)(z))(y))(x) 142 self.assertEqual(output.shape, (2, 5, 7, 3)) 143 self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z) 144 145 def test_noop_in_inner_vmap(self): 146 x = torch.randn(3) 147 y = torch.randn(5) 148 output = vmap(lambda x: vmap(lambda y: x)(y))(x) 149 self.assertEqual(output, x.view(3, 1).expand(3, 5)) 150 151 def test_unsupported_op_err_msg(self): 152 # Unsupported view op 153 tensor = torch.randn(2, 3) 154 msg = ( 155 r"Batching rule not implemented for aten::.+; the " 156 r"fallback path doesn't work on out= or view ops" 157 ) 158 with self.assertRaisesRegex(RuntimeError, msg): 159 vmap(torch.ravel)(tensor) 160 161 def out_op(x, y): 162 return torch.abs(x, out=y) 163 164 with self.assertRaisesRegex(RuntimeError, msg): 165 vmap(out_op)(tensor, tensor) 166 167 tensor = torch.randn(2) 168 # The fallback doesn't support TensorList 169 with self.assertRaisesRegex(RuntimeError, "Batching rule not implemented"): 170 vmap(lambda t: torch.atleast_1d([t]))(tensor) 171 172 # Don't support non-tensor returns. This is a limitation of vmap; 173 # functions that don't return tensors must be special cased 174 with self.assertRaisesRegex(RuntimeError, "Batching rule not implemented"): 175 vmap(torch.Tensor.item)(tensor) 176 177 def test_nonzero_out_dims(self): 178 # Basic test 179 tensor = torch.randn(2, 3) 180 result = vmap(lambda x: x, out_dims=1)(tensor) 181 self.assertEqual(result, tensor.permute(1, 0)) 182 self.assertEqual(result.data_ptr(), tensor.data_ptr()) 183 184 # Test that the batch dimension gets permuted to dim 2 185 tensor = torch.randn(2, 3, 5, 7) 186 result = vmap(lambda x: x, out_dims=2)(tensor) 187 self.assertEqual(result, tensor.permute(1, 2, 0, 3)) 188 self.assertEqual(result.data_ptr(), tensor.data_ptr()) 189 190 # negative out_dim 191 tensor = torch.randn(2, 3, 5, 7) 192 result = vmap(lambda x: x, out_dims=-1)(tensor) 193 self.assertEqual(result, tensor.permute(1, 2, 3, 0)) 194 self.assertEqual(result.data_ptr(), tensor.data_ptr()) 195 196 # check that out_dims works on ALL outputs 197 tensor = torch.randn(2, 3, 5, 7) 198 other = torch.randn(2, 3, 5, 7) 199 result = vmap(lambda x, y: (x, y), out_dims=2)(tensor, other) 200 self.assertEqual( 201 result, (tensor.permute(1, 2, 0, 3), other.permute(1, 2, 0, 3)) 202 ) 203 204 # use out_dims with the maximum vmap-able tensor dims (64 dims) 205 ndims = 64 206 shape = [2] + [1] * (ndims - 1) 207 expected_shape = [1, 1, 2] + [1] * (ndims - 3) 208 tensor = torch.randn(shape) 209 result = vmap(lambda x: x, out_dims=2)(tensor) 210 self.assertEqual(result.shape, expected_shape) 211 212 # test something that is not the identity function 213 def foo(x, y): 214 return x, x * y, x * y * y 215 216 x = torch.randn(2, 3, 5) 217 y = torch.randn(2, 3, 5) 218 result = vmap(foo, out_dims=1)(x, y) 219 self.assertEqual( 220 result, 221 ( 222 x.permute(1, 0, 2), 223 (x * y).permute(1, 0, 2), 224 (x * y * y).permute(1, 0, 2), 225 ), 226 ) 227 228 def test_multiple_out_dims(self): 229 def foo(x): 230 return x, x 231 232 def bar(x, y): 233 return x, x, x, x * y 234 235 x = torch.randn(2, 3, 5) 236 y = torch.randn(2, 3, 5) 237 result = vmap(foo, out_dims=(0, 1))(x) 238 self.assertEqual(result, (x, x.permute(1, 0, 2))) 239 240 result = vmap(bar, out_dims=(-1, 0, 1, 2))(x, y) 241 expected = ( 242 x.permute(1, 2, 0), 243 x, 244 x.permute(1, 0, 2), 245 (x * y).permute(1, 2, 0), 246 ) 247 self.assertEqual(result, expected) 248 249 def test_nested_out_dims(self): 250 y = torch.randn(2, 3, 5, 7) 251 252 # Inner vmap has non-zero out_dim 253 result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y) 254 self.assertEqual(result.shape, (2, 5, 3, 7)) 255 self.assertEqual(result, y.permute(0, 2, 1, 3)) 256 257 # all vmaps have non-zero out_dim 258 result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y), out_dims=1)(y) 259 self.assertEqual(result.shape, (5, 2, 3, 7)) 260 self.assertEqual(result, y.permute(2, 0, 1, 3)) 261 262 # throwing in some negative out_dims 263 result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y), out_dims=-1)(y) 264 self.assertEqual(result.shape, (5, 7, 3, 2)) 265 self.assertEqual(result, y.permute(2, 3, 1, 0)) 266 267 # testing fn that isn't the identity 268 x = torch.randn(2, 3) 269 y = torch.randn(5, 3) 270 result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x), out_dims=-1)(y) 271 self.assertEqual(result.shape, (3, 2, 5)) 272 self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0)) 273 274 def test_out_dims_edge_case(self): 275 def foo(x): 276 return x 277 278 # Test that we accept out_dims=(1,) for a function with one output. 279 tensor = torch.randn(2, 3) 280 expected = vmap(foo, out_dims=1)(tensor) 281 result = vmap(foo, out_dims=(1,))(tensor) 282 self.assertEqual(result, expected) 283 284 def test_out_dims_must_be_int_or_tuple_of_int_err_msg(self): 285 msg = "`out_dims` must be an int or a tuple of int" 286 tensor = torch.randn(2, 3) 287 with self.assertRaisesRegex(ValueError, msg): 288 vmap(lambda x: x, out_dims="lol")(tensor) 289 with self.assertRaisesRegex(ValueError, msg): 290 vmap(lambda x: x, out_dims=("lol",))(tensor) 291 with self.assertRaisesRegex(ValueError, msg): 292 vmap(lambda x: x, out_dims=None)(tensor) 293 with self.assertRaisesRegex(ValueError, msg): 294 vmap(lambda x: x, out_dims=(None,))(tensor) 295 296 def test_out_dims_and_num_outputs_mismatch_err_msg(self): 297 msg = "`out_dims` must have one dim per output" 298 x = torch.randn(2, 3, 5) 299 300 # Too many out_dims 301 with self.assertRaisesRegex(ValueError, msg): 302 vmap(lambda x: x, out_dims=(0, 0))(x) 303 with self.assertRaisesRegex(ValueError, msg): 304 vmap(lambda x: (x, x, x), out_dims=(0, 0, 0, 0))(x) 305 306 # Too few out_dims 307 with self.assertRaisesRegex(ValueError, msg): 308 vmap(lambda x: (x, x), out_dims=(0,))(x) 309 with self.assertRaisesRegex(ValueError, msg): 310 vmap(lambda x: (x, x, x), out_dims=(0, 0))(x) 311 312 def test_out_dim_out_of_bounds_err_msg(self): 313 # TODO(rzou): This error message isn't that great. It comes straight 314 # from maybe_wrap_dim. Consider doing a try-catch-(add some context) to 315 # the error message in the future in C++ 316 msg = "Dimension out of range" 317 x = torch.randn(2, 3, 5) 318 with self.assertRaisesRegex(IndexError, msg): 319 vmap(lambda x: x, out_dims=3)(x) 320 with self.assertRaisesRegex(IndexError, msg): 321 vmap(lambda x: x, out_dims=-4)(x) 322 323 def test_non_zero_in_dims(self): 324 tensor = torch.randn(2, 3, 5) 325 326 # Implicit out_dims = 0; vmap will move the batch dim to the front. 327 output = vmap(lambda x: x, (1,))(tensor) 328 self.assertEqual(output, tensor.permute(1, 0, 2)) 329 self.assertEqual(output.data_ptr(), tensor.data_ptr()) 330 331 x = torch.randn(2, 3) 332 y = torch.randn(3, 2) 333 output = vmap(torch.mul, (0, 1))(x, y) 334 self.assertEqual(output, x * y.t()) 335 output = vmap(torch.mul, (1, 0))(x, y) 336 self.assertEqual(output, x.t() * y) 337 338 def test_none_in_dims(self): 339 x = torch.randn(2, 3) 340 y = torch.randn(2, 3) 341 342 # None in_dim for a Tensor means we don't map over it 343 output = vmap(torch.mul, (0, None))(x, y) 344 self.assertEqual(output.shape, (2, 2, 3)) 345 self.assertEqual(output, x.view(2, 1, 3) * y) 346 347 # None in_dim for non-tensor arguments 348 output = vmap(torch.mul, (0, None))(x, 2) 349 self.assertEqual(output, x * 2) 350 351 def test_nested_non_default_in_dims(self): 352 x = torch.rand(5, 2, 3) 353 y = torch.rand(3, 5, 2) 354 result = vmap(vmap(vmap(torch.mul), (1, 0)), (1, 2))(x, y) 355 self.assertEqual(result, x.permute(1, 2, 0) * y.permute(2, 0, 1)) 356 357 def test_non_default_in_dims_out_dims(self): 358 x = torch.randn(2, 3, 5) 359 360 # Same in_dim as out_dim, vmap over identity 361 result = vmap(lambda x: x, in_dims=1, out_dims=1)(x) 362 self.assertEqual(result, x) 363 self.assertEqual(result.data_ptr(), x.data_ptr()) 364 365 # Different in_dim from out_dim, vmap over identity 366 result = vmap(lambda x: x, in_dims=2, out_dims=1)(x) 367 self.assertEqual(result.shape, (2, 5, 3)) 368 self.assertEqual(result, x.transpose(1, 2)) 369 self.assertEqual(result.data_ptr(), x.data_ptr()) 370 371 def foo(x): 372 return x * 2 373 374 # Same in_dim as out_dim, vmap over operation 375 result = vmap(foo, in_dims=1, out_dims=1)(x) 376 self.assertEqual(result, x * 2) 377 378 # Different in_dim as out_dim, vmap over operation 379 result = vmap(foo, in_dims=2, out_dims=1)(x) 380 self.assertEqual(result.shape, (2, 5, 3)) 381 self.assertEqual(result, (x * 2).transpose(1, 2)) 382 383 # Basic nested test. 384 result = vmap(vmap(foo, 1, 1), 1, 1)(x) 385 self.assertEqual(result, x * 2) 386 387 def test_accepts_nested_inputs(self): 388 B0 = 2 389 x = torch.randn(2, 3) 390 y = torch.randn(2, 3) 391 392 # Single layer of nesting 393 out = vmap(lambda z: z[0] + z[1])((x, y)) 394 self.assertEqual(out, x + y) 395 out = vmap(lambda z: z[0] + z[1], in_dims=(0,))((x, y)) 396 self.assertEqual(out, x + y) 397 out = vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y)) 398 self.assertEqual(out, x + y) 399 400 out = vmap(lambda z: z[0] + z[1])([x, y]) 401 self.assertEqual(out, x + y) 402 out = vmap(lambda z: z[0] + z[1], in_dims=(0,))([x, y]) 403 self.assertEqual(out, x + y) 404 out = vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, y]) 405 self.assertEqual(out, x + y) 406 407 out = vmap(lambda z: z["x"] + z["y"])({"x": x, "y": y}) 408 self.assertEqual(out, x + y) 409 out = vmap(lambda z: z["x"] + z["y"], in_dims=(0,))({"x": x, "y": y}) 410 self.assertEqual(out, x + y) 411 out = vmap(lambda z: z["x"] + z["y"], in_dims=({"x": 0, "y": 0},))( 412 {"x": x, "y": y} 413 ) 414 self.assertEqual(out, x + y) 415 416 # Multiple layers of nesting 417 out_fn = vmap(lambda z: z["x"][0] + z["x"][1][0] + z["y"][0] + z["y"][1]) 418 out = out_fn({"x": [x, (x,)], "y": [y, y]}) 419 self.assertEqual(out, x + x + y + y) 420 421 def test_in_dims_wrong_type_err_msg(self): 422 x = torch.randn(3) 423 y = torch.randn(3) 424 msg = r"expected `in_dims` to be int or a \(potentially nested\) tuple" 425 with self.assertRaisesRegex(ValueError, msg): 426 vmap(torch.mul, [0, 0])(x, y) 427 with self.assertRaisesRegex(ValueError, msg): 428 vmap(torch.mul, set({0}))(x, y) 429 with self.assertRaisesRegex(ValueError, msg): 430 vmap(torch.mul, "lol")(x, y) 431 with self.assertRaisesRegex(ValueError, msg): 432 vmap(lambda z: z[0] + z[1], in_dims=[0, 0])([x, y]) 433 # The following should not throw 434 vmap(torch.mul, (0, 0))(x, y) 435 436 def test_not_enough_in_dims_err_msg(self): 437 x = torch.randn(3) 438 y = torch.randn(3) 439 msg = r"in_dims is not compatible with the structure of `inputs`" 440 441 with self.assertRaisesRegex(ValueError, msg): 442 vmap(torch.mul, (0,))(x, y) 443 with self.assertRaisesRegex(ValueError, msg): 444 vmap(torch.mul, (0, 0, 0))(x, y) 445 with self.assertRaisesRegex(ValueError, msg): 446 vmap(lambda z: z[0] + z[1], in_dims=([0],))([x, y]) 447 with self.assertRaisesRegex(ValueError, msg): 448 vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))([x, y]) 449 # The following should not throw 450 vmap(torch.mul, (0, 0))(x, y) 451 452 def test_integer_in_dim_but_not_tensor_input_err_msg(self): 453 def foo(xy): 454 return xy[0] * xy[1] 455 456 def bar(x, yz): 457 return x * yz[0] * yz[1] 458 459 x = torch.randn(2, 3) 460 y = torch.randn(2, 3) 461 462 # the following are errors in jax (and will always be errors) 463 msg = "Got in_dim=0 for an input but the input is of type" 464 with self.assertRaisesRegex(ValueError, msg): 465 vmap(torch.sum)(x, 0) 466 with self.assertRaisesRegex(ValueError, msg): 467 vmap(torch.sum, (0, 0))(x, 0) 468 with self.assertRaisesRegex(ValueError, msg): 469 vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, 1]) 470 # The following should not throw 471 vmap(torch.sum, (0, None))(x, 0) 472 473 def test_in_dim_not_in_tensor_err_msg(self): 474 def foo(x): 475 return x * x 476 477 x = torch.randn(2, 3) 478 y = torch.randn(2, 3) 479 480 msg = r"Got in_dim=-?\w for some input, but that input is a Tensor of dimensionality \w" 481 with self.assertRaisesRegex(ValueError, msg): 482 vmap(foo)(torch.randn([])) 483 with self.assertRaisesRegex(ValueError, msg): 484 vmap(foo, in_dims=(0,))(torch.randn([])) 485 with self.assertRaisesRegex(ValueError, msg): 486 vmap(foo, in_dims=(-1,))(x) 487 with self.assertRaisesRegex(ValueError, msg): 488 vmap(foo, in_dims=(2,))(y) 489 with self.assertRaisesRegex(ValueError, msg): 490 vmap(lambda z: z[0] + z[1], in_dims=([3, 0],))([x, y]) 491 # the following should not throw 492 vmap(foo, in_dims=(0,))(torch.randn(2, 3)) 493 vmap(foo, in_dims=(1,))(torch.randn(2, 3)) 494 495 def test_fallback_does_not_warn_by_default(self): 496 # NB: One day we will implement a batching rule for torch.atan2. 497 # If/when we do, this test should be replaced to test the fallback 498 # path on another operator to avoid bitrot. 499 op = torch.atan2 500 x = torch.randn(11) 501 y = torch.randn(11) 502 with warnings.catch_warnings(record=True) as wa: 503 result = vmap(op)(x, y) 504 # The single warning here is the "vmap is experimental" 505 # warning, not a warning from the vmap fallback path. 506 self.assertEqual(len(wa), 1) 507 508 def test_fallback_warns_when_warnings_are_enabled(self): 509 # NB: One day we will implement a batching rule for torch.atan2. 510 # If/when we do, this test should be replaced to test the fallback 511 # path on another operator to avoid bitrot. 512 op = torch.atan2 513 x = torch.randn(11) 514 y = torch.randn(11) 515 with warnings.catch_warnings(record=True) as wa: 516 with EnableVmapFallbackWarnings(): 517 result = vmap(op)(x, y) 518 self.assertEqual(len(wa), 2) 519 self.assertRegex(str(wa[-1].message), FALLBACK_REGEX) 520 521 def _assert_uses_vmap_fallback(self, vmap_args, inputs): 522 with warnings.catch_warnings(record=True) as wa: 523 with EnableVmapFallbackWarnings(): 524 result = vmap(*vmap_args)(*inputs) 525 self.assertEqual(len(wa), 2) 526 self.assertRegex(str(wa[-1].message), FALLBACK_REGEX) 527 528 def test_fallback_zero_dim(self): 529 # NB: One day we will implement a batching rule for torch.atan2. 530 # If/when we do, this test should be replaced to test the fallback 531 # path on another operator to avoid bitrot. 532 op = torch.atan2 533 x = torch.randn(11) 534 y = torch.randn(11) 535 self._assert_uses_vmap_fallback((op,), (x, y)) 536 537 B0, B1 = 0, 3 538 x = torch.randn(B0, 11) 539 y = torch.randn(11) 540 541 msg = "The fallback path does not support vmap over dims of size 0" 542 543 with self.assertRaisesRegex(RuntimeError, msg): 544 vmap(op, (0, None))(x, y) 545 with self.assertRaisesRegex(RuntimeError, msg): 546 vmap(op, (None, 0))(y, x) 547 with self.assertRaisesRegex(RuntimeError, msg): 548 vmap(op)(x, x) 549 550 x = torch.randn(B0, B1, 11) 551 y = torch.randn(B1, 11) 552 with self.assertRaisesRegex(RuntimeError, msg): 553 vmap(op, (0, None))(x, y) 554 with self.assertRaisesRegex(RuntimeError, msg): 555 vmap(op, (None, 0))(y, x) 556 with self.assertRaisesRegex(RuntimeError, msg): 557 vmap(op)(x, x) 558 559 def test_fallback_atan2(self): 560 # NB: One day we will implement a batching rule for torch.atan2. 561 # If/when we do, this test should be replaced to test the fallback 562 # path on another operator to avoid bitrot. 563 op = torch.atan2 564 565 x = torch.randn(5, 7, 11) 566 y = torch.randn(5, 7, 11) 567 568 self._assert_uses_vmap_fallback((op,), (x, y)) 569 570 # fallback on torch.atan2 571 x = torch.randn(7, 11, 5) 572 y = torch.randn(5, 7, 11) 573 result = vmap(op, (2, 0))(x, y) 574 self.assertEqual(result, op(x.permute(2, 0, 1), y)) 575 576 # fallback on torch.atan2, nested vmap 577 x = torch.randn(7, 11, 5) 578 y = torch.randn(5, 7, 11) 579 result = vmap(vmap(op), (2, 0))(x, y) 580 self.assertEqual(result, op(x.permute(2, 0, 1), y)) 581 582 # big batch size (total 10000) 583 x = torch.randn(100, 10, 10, 5) 584 y = torch.randn(100, 10, 10) 585 result = vmap(vmap(vmap(op)))(x, y) 586 self.assertEqual(result, op(x, y.view(100, 10, 10, 1))) 587 588 def test_fallback_masked_fill(self): 589 # NB: One day we will implement a batching rule for masked_fill 590 # If/when we do, this test should be replaced to test the fallback 591 # path on another operator to avoid bitrot. 592 def run_test(batch_size): 593 B0 = batch_size 594 x = torch.randn(B0, 7, 11, 13) 595 dim = 0 596 index = torch.tensor([0, 4, 2]) 597 values = torch.randn(B0, 3, 11, 13) 598 599 self._assert_uses_vmap_fallback( 600 (torch.index_add, (0, None, None, 0)), (x, dim, index, values) 601 ) 602 603 result = vmap(torch.index_add, (0, None, None, 0))(x, dim, index, values) 604 expected = torch.index_add(x, dim + 1, index, values.view(B0, 3, 11, 13)) 605 self.assertEqual(result, expected) 606 607 run_test(batch_size=5) 608 run_test(batch_size=1237) 609 610 def test_fallback_multiple_returns(self): 611 # NB: One day we will implement a batching rule for torch.var_mean 612 # If/when we do, this test should be replaced to test the fallback 613 # path on another operator to avoid bitrot. 614 B0, B1, B2 = 2, 3, 1237 615 tensor = torch.randn(B0, 10) 616 617 self._assert_uses_vmap_fallback((torch.var_mean,), (tensor,)) 618 619 # fallback correctness on torch.var_mean 620 result = vmap(torch.var_mean)(tensor) 621 expected = torch.var_mean(tensor, dim=1) 622 self.assertEqual(result, expected) 623 624 # nested vmap 625 tensor = torch.randn(B0, B1, 10) 626 result = vmap(vmap(torch.var_mean))(tensor) 627 expected = torch.var_mean(tensor, dim=2) 628 self.assertEqual(result, expected) 629 630 # big batch size, nested vmap 631 tensor = torch.randn(B0, B1, B2, 10) 632 result = vmap(vmap(vmap(torch.var_mean)))(tensor) 633 expected = torch.var_mean(tensor, dim=3) 634 self.assertEqual(result, expected) 635 636 def test_inplace_fallback_unary(self): 637 # Test the in-place fallback on an in-place method that takes no 638 # additional Tensor arguments. This is the simplest case of the fallback. 639 # NB: One day we will implement a batching rule for acos_. 640 # If/when we do, this test should be replaced to test the fallback 641 # path on another operator to avoid bitrot. 642 op = Tensor.acos_ 643 B0, B1, B2 = 2, 3, 10000 644 645 x = torch.randn(B0, 5) 646 self._assert_uses_vmap_fallback((op,), (x,)) 647 648 # Single vmap 649 x_orig = torch.rand(B0, 5) 650 x = x_orig.clone() 651 result = vmap(op)(x) 652 self.assertTrue(result is x) 653 self.assertEqual(result, x_orig.acos()) 654 655 # Single vmap + different out_dim produces a view(!) 656 x_orig = torch.rand(B0, 5) 657 x = x_orig.clone() 658 result = vmap(op, out_dims=(1,))(x) 659 self.assertTrue(result._base is x) 660 self.assertEqual(result, x_orig.t().acos()) 661 662 # Nested vmap 663 x_orig = torch.randn(B0, B1, 5) 664 x = x_orig.clone() 665 result = vmap(vmap(op))(x) 666 self.assertTrue(result is x) 667 self.assertEqual(result, x_orig.acos()) 668 669 # Nested vmap, large batch size 670 x_orig = torch.randn(B0, B1, B2, 5) 671 x = x_orig.clone() 672 result = vmap(vmap(vmap(op)))(x) 673 self.assertTrue(result is x) 674 self.assertEqual(result, x_orig.acos()) 675 676 def test_inplace_fallback_nary_same_levels(self): 677 # NB: One day we will implement a batching rule for atan2_ 678 # If/when we do, this test should be replaced to test the fallback 679 # path on another operator to avoid bitrot. 680 op = Tensor.atan2_ 681 outplace_op = torch.atan2 682 683 x = torch.randn(5, 7, 11) 684 y = torch.randn(5, 7, 11) 685 self._assert_uses_vmap_fallback((op,), (x, y)) 686 687 # Single vmap 688 B0 = 5 689 x_orig = torch.randn(7, 11, B0) 690 x = x_orig.clone() 691 y = torch.randn(B0, 7, 11) 692 vmap(op, (2, 0))(x, y) 693 self.assertEqual(x, outplace_op(x_orig, y.movedim(0, 2))) 694 695 # Nested vmap 696 B0, B1 = 5, 7 697 x_orig = torch.randn(B1, 11, B0) 698 x = x_orig.clone() 699 y = torch.randn(B0, B1, 11) 700 vmap(vmap(op), (2, 0))(x, y) 701 self.assertEqual(x, outplace_op(x_orig, y.movedim([0, 1], [2, 0]))) 702 703 # big batch size (total 10000) 704 B0, B1, B2 = 100, 10, 10 705 x_orig = torch.randn(B0, B1, B2, 5) 706 x = x_orig.clone() 707 y = torch.randn(B0, B1, B2) 708 result = vmap(vmap(vmap(op)))(x, y) 709 self.assertEqual(x, outplace_op(x_orig, y.view(B0, B1, B2, 1))) 710 711 def test_inplace_fallback_nary_different_levels(self): 712 # NB: One day we will implement a batching rule for atan2_ 713 # If/when we do, this test should be replaced to test the fallback 714 # path on another operator to avoid bitrot. 715 op = Tensor.atan2_ 716 outplace_op = torch.atan2 717 B0, B1, B2 = 2, 3, 5 718 719 x = torch.rand(B0, 7) 720 y = torch.rand(7) 721 self._assert_uses_vmap_fallback((op, (0, None)), (x, y)) 722 723 # op(left, right): All of the levels in right are found in left 724 x_orig = torch.rand(B0, 7) 725 x = x_orig.clone() 726 y = torch.rand(7) 727 vmap(op, in_dims=(0, None))(x, y) 728 self.assertEqual(x, outplace_op(x_orig, y)) 729 730 x_orig = torch.rand(B0, B1, 7) 731 x = x_orig.clone() 732 y = torch.rand(B0, 7) 733 vmap(vmap(op, in_dims=(0, None)))(x, y) 734 self.assertEqual(x, outplace_op(x_orig, y.view(B0, 1, 7))) 735 736 # op(left, right): Some of the levels in right are not found in left 737 msg = r"vmap: aten::atan2_\(self, \*extra_args\) is not possible" 738 x = torch.rand(7) 739 y = torch.rand(B0, 7) 740 with self.assertRaisesRegex(RuntimeError, msg): 741 vmap(op, in_dims=(None, 0))(x, y) 742 743 x = torch.rand(B1, 7) 744 y = torch.rand(B0, 7) 745 with self.assertRaisesRegex(RuntimeError, msg): 746 vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 0))(x, y) 747 748 x = torch.rand(B1, 7) 749 y = torch.rand(7, B0) 750 with self.assertRaisesRegex(RuntimeError, msg): 751 vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 1))(x, y) 752 753 x = torch.rand(B0, 7) 754 y = torch.rand(B0, B1, 7) 755 with self.assertRaisesRegex(RuntimeError, msg): 756 vmap(vmap(op, in_dims=(None, 0)))(x, y) 757 758 def test_backward_unsupported_interaction(self): 759 x = torch.randn(3, requires_grad=True) 760 y = torch.randn(5) 761 grad = torch.randn_like(x) 762 err_msg = r"backward\(\) called inside torch.vmap" 763 764 def backward_on_vmapped_tensor(x): 765 x.sum().backward() 766 767 with self.assertRaisesRegex(RuntimeError, err_msg): 768 vmap(backward_on_vmapped_tensor)(x) 769 770 def backward_with_vmapped_grad(x, grad): 771 x.backward(grad) 772 773 with self.assertRaisesRegex(RuntimeError, err_msg): 774 vmap(backward_with_vmapped_grad)(x, grad) 775 776 def completely_unrelated_backward(y): 777 x.sum().backward() 778 779 with self.assertRaisesRegex(RuntimeError, err_msg): 780 vmap(completely_unrelated_backward)(y) 781 782 def test_grad_unsupported_interaction(self): 783 input_tensor = torch.randn(3, requires_grad=True) 784 err_msg = "autograd.grad.* called inside torch.vmap" 785 786 captured = torch.randn(3, requires_grad=True) 787 788 def output_to_grad_is_vmapped(input_tensor): 789 output = (captured * input_tensor).sum() 790 return torch.autograd.grad([output], [captured])[0] 791 792 with self.assertRaisesRegex(RuntimeError, err_msg): 793 vmap(output_to_grad_is_vmapped)(input_tensor) 794 795 output = (input_tensor**2).sum() 796 797 def input_to_grad_is_vmapped(input_tensor): 798 return torch.autograd.grad([output], [input_tensor])[0] 799 800 with self.assertRaisesRegex(RuntimeError, err_msg): 801 vmap(input_to_grad_is_vmapped)(input_tensor) 802 803 def test_batched_gradient_basic(self): 804 N = 3 805 x = torch.randn(N, requires_grad=True) 806 y = torch.randn(N) 807 808 def vjp_mul(v): 809 return torch.autograd.grad([x * y], [x], grad_outputs=[v])[0] 810 811 batched_v = torch.eye(N) 812 jacobian = vmap(vjp_mul)(batched_v) 813 self.assertEqual(jacobian, torch.diagflat(y)) 814 815 def test_functools_partial(self): 816 x = torch.randn(3) 817 y = torch.randn(2, 3) 818 result = vmap(functools.partial(torch.mul, x))(y) 819 self.assertEqual(result, x * y) 820 821 def test_nn_module(self): 822 tensor = torch.randn(2, 3) 823 model = torch.nn.Linear(3, 3, bias=False) 824 result = vmap(model)(tensor) 825 self.assertEqual(result, model(tensor)) 826 827 def test_fallback_with_undefined_grad(self): 828 B0 = 7 829 x = torch.randn(2, 3, 4, 5, requires_grad=True) 830 weight = torch.randn(3, 3, 1, 1) 831 v = torch.randn(B0, 2, 3, 4, 5) 832 833 def get_vjp(v): 834 result = torch.nn.functional.conv2d(x, weight) 835 (grad_x,) = torch.autograd.grad(result, x, v) 836 return grad_x 837 838 # Runs vmap(get_vjp)(v), which should not error out. 839 # The backward formula for convolution returns an undefined 840 # Tensor for grad_bias because the original bias does not exist. 841 # 842 # In the future we'll probably add a batching rule for convolution 843 # backward. When this happens, we should modify this test to use a 844 # different op (and/or create and use a dummy operator) to avoid bitrot. 845 self._assert_uses_vmap_fallback([get_vjp], [v]) 846 847 848def slice_inputs(inputs, bdims, i): 849 result = [] 850 for inp, bdim in zip(inputs, bdims): 851 if bdim is None: 852 result.append(inp) 853 else: 854 result.append(inp.select(bdim, i)) 855 return tuple(result) 856 857 858def reference_vmap(op, inputs, in_dims=0, out_dims=0): 859 if isinstance(in_dims, int): 860 in_dims = (in_dims,) * len(inputs) 861 bdim_sizes = [inp.size(dim) for inp, dim in zip(inputs, in_dims) if dim is not None] 862 assert all(bdim_size == bdim_sizes[0] for bdim_size in bdim_sizes) 863 bdim_size = bdim_sizes[0] 864 results = tuple(op(*slice_inputs(inputs, in_dims, i)) for i in range(bdim_size)) 865 866 assert len(results) > 0 867 op_has_single_return = not isinstance(results[0], tuple) 868 if op_has_single_return: 869 assert all(isinstance(result, torch.Tensor) for result in results) 870 if isinstance(out_dims, int): 871 out_dims = (out_dims,) * 1 872 return torch.stack(results, dim=out_dims[0]) 873 874 assert all(isinstance(result, tuple) for result in results) 875 num_returns = len(results[0]) 876 assert all(len(result) == num_returns for result in results) 877 if isinstance(out_dims, int): 878 out_dims = (out_dims,) * num_returns 879 return tuple( 880 torch.stack(result_shards, out_dim) 881 for result_shards, out_dim in zip(zip(*results), out_dims) 882 ) 883 884 885class TensorFactory: 886 @staticmethod 887 def rand(size, device="cpu", dtype=torch.float): 888 return torch.rand(size, device=device, dtype=dtype) 889 890 @staticmethod 891 def randn(size, device="cpu", dtype=torch.float): 892 return torch.randn(size, device=device, dtype=dtype) 893 894 @staticmethod 895 def randp1(size, device="cpu", dtype=torch.float): 896 return torch.rand(size, device=device, dtype=dtype) + 1 897 898 899# Tests vmap(op, in_dims, out_dims)(*inputs) by comparing the output to a 900# (slow) sequential map+stack fallback. 901# 902# check_view: Test if the first returned output is a view of the first input 903# check_propagates_grad: Test if the operation propagates gradients. 904def _vmap_test( 905 self, 906 op, 907 inputs, 908 in_dims=0, 909 out_dims=0, 910 check_view=False, 911 check_propagates_grad=True, 912): 913 result = vmap(op, in_dims, out_dims)(*inputs) 914 reference_result = reference_vmap(op, inputs, in_dims, out_dims) 915 self.assertEqual(result, reference_result) 916 op_has_single_return = not isinstance(result, tuple) 917 918 if check_view: 919 result_as_tuple = (result,) if op_has_single_return else result 920 for output in result_as_tuple: 921 input0_base = inputs[0] if inputs[0]._base is None else inputs[0]._base 922 self.assertTrue( 923 output._base is input0_base, 924 msg="result was not a view of the first input!", 925 ) 926 927 if not check_propagates_grad: 928 return 929 # Assuming input[0] is a floating-point tensor. Check if the vmap 930 # operation propagates the requires_grad flag to the zeroth output. 931 # Some vmap operators are implemented in a way that assumes that 932 # they are composite with respect to autograd. If the operator ever is 933 # changed to not be composite with respect to autograd, then the 934 # following check should fail. 935 inputs_clone = list(inputs) 936 inputs_clone[0] = inputs[0].clone().requires_grad_() 937 result = vmap(op, in_dims, out_dims)(*inputs_clone) 938 result_as_tuple = (result,) if op_has_single_return else result 939 self.assertTrue(result[0].requires_grad) 940 941 942def should_allow_vmap_fallback_usage(fn): 943 return getattr(fn, "_allow_vmap_fallback_usage", False) 944 945 946def allowVmapFallbackUsage(fn): 947 fn._allow_vmap_fallback_usage = True 948 return fn 949 950 951# All tests of TestVmapBaseLegacy check that the slow vmap fallback is never invoked. 952# This is so that we can incrementally add batching rules for operators to 953# replace the slow vmap fallback path for said operators. To skip this check, 954# please use the allowVmapFallbackUsage decorator. 955# 956# NB: Don't add tests to TestVmapBaseLegacy directly, unless you want them to run 957# on every subclass of TestVmapBaseLegacy. Add them to e.g. TestVmapOperators. 958# 959# NB: TestVmapBaseLegacy is a nested class. This prevents test runners from picking 960# it up and running it. 961class Namespace: 962 class TestVmapBaseLegacy(TestCase): 963 def __init__(self, method_name="runTest"): 964 super().__init__(method_name) 965 966 test_method = getattr(self, method_name, None) 967 if test_method is None: 968 return 969 970 if not should_allow_vmap_fallback_usage(test_method): 971 setattr( 972 self, 973 method_name, 974 self._wrap_method_with_vmap_fallback_check(test_method), 975 ) 976 977 def _wrap_method_with_vmap_fallback_check(self, method): 978 msg = ( 979 "Expected the test to not invoke the vmap fallback path, i.e., " 980 "all of the operators being tested in this test should have batching " 981 "rules implemented. If you are intentionally testing something to " 982 "do with the fallback path, use allowVmapFallbackUsage. Otherwise, " 983 "please make sure that batching rules are implemented for the " 984 "operator(s) being tested." 985 ) 986 987 @functools.wraps(method) 988 def wrapper(self, *args, **kwargs): 989 with warnings.catch_warnings(record=True) as wa: 990 warnings.simplefilter("always") 991 with EnableVmapFallbackWarnings(): 992 method(*args, **kwargs) 993 for captured_warning in wa: 994 self.assertNotRegex( 995 str(captured_warning.message), FALLBACK_REGEX, msg 996 ) 997 998 return types.MethodType(wrapper, self) 999 1000 @allowVmapFallbackUsage 1001 def test_vmap_fallback_check_ok(self): 1002 # One day we'll implement a batching rule for torch.var_mean. 1003 # When that happens, please change the example to use an 1004 # operator that doesn't have a batching rule implemented. 1005 op_using_fallback = torch.var_mean 1006 vmap(op_using_fallback)(torch.rand(3)) 1007 1008 def test_vmap_fallback_check(self): 1009 @self._wrap_method_with_vmap_fallback_check 1010 def no_fallback(self): 1011 pass 1012 1013 # One day we'll implement a batching rule for torch.var_mean. 1014 # When that happens, please change the example to use an 1015 # operator that doesn't have a batching rule implemented. 1016 op_using_fallback = torch.var_mean 1017 1018 @self._wrap_method_with_vmap_fallback_check 1019 def uses_fallback(self): 1020 vmap(op_using_fallback)(torch.rand(3)) 1021 1022 no_fallback(self) 1023 1024 with self.assertRaises(AssertionError): 1025 uses_fallback(self) 1026 1027 1028class TestVmapOperatorsLegacy(Namespace.TestVmapBaseLegacy): 1029 def _vmap_test(self, *args, **kwargs): 1030 return _vmap_test(self, *args, **kwargs) 1031 1032 def _vmap_view_test(self, *args, **kwargs): 1033 self._vmap_test(*args, **kwargs, check_view=True) 1034 1035 def _test_unary(self, op, getter, device, *args, **kwargs): 1036 test = functools.partial(self._vmap_test, *args, **kwargs) 1037 B0, B1 = 7, 11 1038 1039 # Single vmap, various in_dims / out_dims 1040 test(op, [getter([B0, 3], device)]) 1041 test(op, [getter([2, 5, B0, 3], device)], in_dims=2) 1042 test(op, [getter([2, 5, B0, 3], device)], in_dims=2, out_dims=2) 1043 1044 # Doubly nested vmap 1045 test(vmap(op), [getter([B0, B1], device)]) 1046 test(vmap(op), [getter([B1, 2, 5, B0, 3], device)], in_dims=2) 1047 test( 1048 vmap(op, in_dims=2), 1049 [getter([2, 5, B0, B1, 3], device)], 1050 in_dims=2, 1051 out_dims=2, 1052 ) 1053 1054 def test_unary_pointwise_ops(self): 1055 cases = [ 1056 (torch.abs, TensorFactory.randn), 1057 (torch.acos, TensorFactory.rand), 1058 (torch.asin, TensorFactory.rand), 1059 (torch.atan, TensorFactory.rand), 1060 (torch.ceil, TensorFactory.randn), 1061 (torch.cos, TensorFactory.rand), 1062 (torch.cosh, TensorFactory.rand), 1063 (torch.digamma, TensorFactory.rand), 1064 (torch.exp, TensorFactory.randn), 1065 (torch.expm1, TensorFactory.randn), 1066 (torch.floor, TensorFactory.randn), 1067 (torch.frac, TensorFactory.randn), 1068 (torch.lgamma, TensorFactory.rand), 1069 (torch.log, TensorFactory.randp1), 1070 (torch.log10, TensorFactory.randp1), 1071 (torch.log1p, TensorFactory.randp1), 1072 (torch.log2, TensorFactory.randp1), 1073 (torch.neg, TensorFactory.randn), 1074 (torch.reciprocal, TensorFactory.randp1), 1075 (torch.relu, TensorFactory.randn), 1076 (torch.round, TensorFactory.randn), 1077 (torch.rsqrt, TensorFactory.randp1), 1078 (torch.sigmoid, TensorFactory.randn), 1079 (torch.sign, TensorFactory.randn), 1080 (torch.sin, TensorFactory.rand), 1081 (torch.sinh, TensorFactory.rand), 1082 (torch.sqrt, TensorFactory.rand), 1083 (torch.tan, TensorFactory.rand), 1084 (torch.tanh, TensorFactory.rand), 1085 (torch.trunc, TensorFactory.randn), 1086 ] 1087 for op, getter in cases: 1088 self._test_unary(op, getter, "cpu") 1089 1090 def test_clone(self): 1091 # Some basic tests 1092 self._test_unary(lambda x: x.clone(), TensorFactory.randn, "cpu") 1093 self._test_unary( 1094 lambda x: x.clone(memory_format=torch.preserve_format), 1095 TensorFactory.randn, 1096 "cpu", 1097 ) 1098 self._test_unary( 1099 lambda x: x.clone(memory_format=torch.contiguous_format), 1100 TensorFactory.randn, 1101 "cpu", 1102 ) 1103 1104 # Test that the per-examples are contiguous when using torch.contiguous_format 1105 def clone_contiguous(x): 1106 return x.clone(memory_format=torch.contiguous_format) 1107 1108 B0, B1 = 3, 5 1109 x = torch.randn(2, B0, 7) 1110 y = vmap(clone_contiguous, in_dims=1, out_dims=1)(x) 1111 self.assertTrue(y.movedim(1, 0).is_contiguous()) 1112 self.assertTrue(y[:, 0, :].is_contiguous()) 1113 1114 x = torch.randn(2, B0, 7, B1) 1115 y = vmap(vmap(clone_contiguous, in_dims=2), in_dims=1)(x) 1116 self.assertTrue(y.is_contiguous()) 1117 self.assertTrue(y[0][0].is_contiguous()) 1118 1119 msg = r"only supported with memory_format torch.preserve_format or torch.contiguous_format" 1120 with self.assertRaisesRegex(RuntimeError, msg): 1121 vmap(lambda x: x.clone(memory_format=torch.channels_last))(torch.randn(B0)) 1122 with self.assertRaisesRegex(RuntimeError, msg): 1123 vmap(lambda x: x.clone(memory_format=torch.channels_last_3d))( 1124 torch.randn(B0) 1125 ) 1126 1127 def test_binary_pointwise_ops(self): 1128 def get_number(getter): 1129 return getter([]).item() 1130 1131 def make_case(op, input_getter=TensorFactory.randn): 1132 return (op, input_getter) 1133 1134 cases = [ 1135 # Basic arithmetic 1136 make_case(torch.add), 1137 make_case(lambda x, y: x + y), 1138 make_case(torch.sub), 1139 make_case(lambda x, y: x - y), 1140 make_case(torch.mul), 1141 make_case(lambda x, y: x * y), 1142 make_case(torch.div, input_getter=TensorFactory.randp1), 1143 make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1), 1144 make_case(torch.pow, input_getter=TensorFactory.randp1), 1145 make_case(lambda x, y: x**y, input_getter=TensorFactory.randp1), 1146 ] 1147 test = self._vmap_test 1148 1149 for op, getter in cases: 1150 device = "cpu" 1151 B0, B1 = 7, 11 1152 1153 # Single vmap: op(Tensor, Tensor) 1154 test(op, (getter([B0, 3], device), getter([B0, 3], device))) 1155 test(op, (getter([B0], device), getter([B0, 2, 3], device))) 1156 test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1)) 1157 test( 1158 op, 1159 (getter([B0], device), getter([2, B0, 3], device)), 1160 in_dims=(0, 1), 1161 out_dims=1, 1162 ) 1163 test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None)) 1164 test( 1165 op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None) 1166 ) 1167 1168 # Nested vmap: op(Tensor, Tensor) 1169 test( 1170 vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device)) 1171 ) 1172 test( 1173 vmap(op, in_dims=(None, 0)), 1174 (getter([B0, 2, 3], device), getter([B1, 3], device)), 1175 in_dims=(0, None), 1176 ) 1177 1178 # Python number overload: op(Tensor, Number) (and vice-versa) 1179 number = get_number(getter) 1180 self._test_unary(lambda t: op(t, number), getter, device) 1181 number = get_number(getter) 1182 self._test_unary(lambda t: op(number, t), getter, device) 1183 1184 # Type promotion: op(Logical Scalar Tensor, Logical Scalar Tensor) 1185 test(op, (getter([B0], device), getter([B0], device, dtype=torch.double))) 1186 test(op, (getter([B0], device, dtype=torch.double), getter([B0], device))) 1187 test(op, (getter([B0], device), getter([B0], device))) 1188 1189 # Type promotion: op(Tensor, Logical Scalar Tensor) (and vice-versa) 1190 test(op, (getter([B0, 2], device), getter([B0], device, torch.double))) 1191 test(op, (getter([B0], device, torch.double), getter([B0, 2], device))) 1192 1193 if not torch.cuda.is_available(): 1194 continue 1195 1196 # TODO(rzou): fix the following 1197 # # Test cross-device scalars 1198 # number = get_number(getter) 1199 # self._test_unary(lambda t: op(t, number), getter, device='cuda') 1200 # self._test_unary(lambda t: op(number, t), getter, device='cuda') 1201 # self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda') 1202 1203 def test_as_strided(self): 1204 def _test(sizes, strides, offset, tensor, lambd): 1205 result = vmap(lambda t: t.as_strided(sizes, strides, offset))(tensor) 1206 expected = vmap(lambd)(tensor) 1207 self.assertTrue(result._base is expected._base) 1208 self.assertEqual(result, expected) 1209 1210 # single vmap test 1211 B0 = 5 1212 tensors = [ 1213 # contiguous 1214 torch.randn(B0, 2, 3), 1215 # non-contiguous 1216 torch.randn(B0, 3, 2).transpose(1, 2), 1217 # non-zero storage offset 1218 torch.randn(2, B0, 2, 3)[1], 1219 # non-contiguous strides, zero storage offset 1220 torch.randn(B0, 2, 4, 3, 7)[:, :, 0, :, 0], 1221 # non-contiguous strides, non-zero storage offset 1222 torch.randn(B0, 2, 4, 3, 7)[:, :, 2, :, 1], 1223 ] 1224 1225 for x in tensors: 1226 S0, S1 = x.stride()[1:] 1227 offset = x.storage_offset() 1228 1229 # Broadcast 1230 _test( 1231 [5, 5, 2, 3], [0, 0, S0, S1], offset, x, lambda x: x.expand(5, 5, 2, 3) 1232 ) 1233 # transpose 1234 _test([3, 2], [S1, S0], offset, x, lambda x: x.transpose(0, 1)) 1235 # select 1236 _test([2], [S0], offset + S1, x, lambda x: x[:, 1]) 1237 1238 # Nested vmap test 1239 B1 = 7 1240 x = torch.randn(B1, B0, 2, 3) 1241 S0, S1 = x.stride()[2:] 1242 result = vmap( 1243 vmap(lambda t: t.as_strided([5, 5, 2, 3], [0, 0, S0, S1])), in_dims=1 1244 )(x) 1245 expected = vmap(vmap(lambda t: t.expand(5, 5, 2, 3)), in_dims=1)(x) 1246 self.assertTrue(result._base is expected._base) 1247 self.assertEqual(result, expected) 1248 1249 # Check that mal-formatted size/strides doesn't crash 1250 with self.assertRaisesRegex( 1251 RuntimeError, "size and stride must have the same length" 1252 ): 1253 x = torch.randn(B0, 2, 3).transpose(0, 1) 1254 vmap(lambda x: x.as_strided([1, 1, 1], [1, 1]))(x) 1255 1256 # Sanity check #1: we require the batch dims to be at the front of the 1257 # tensor (in memory layout). 1258 msg = "batch dims being vmapped over are at the front of the tensor" 1259 with self.assertRaisesRegex(RuntimeError, msg): 1260 x = torch.randn(2, B0, 3).transpose(0, 1) 1261 vmap(lambda x: x.as_strided([2, 3], [B0 * 3, 1]))(x) 1262 with self.assertRaisesRegex(RuntimeError, msg): 1263 x = torch.randn(B0, 2, 3, B1).movedim(3, 1) 1264 vmap(vmap(lambda x: x.as_strided([2, 3], [B1 * 3, B1])))(x) 1265 1266 # All the Sanity check #2{a,b,c} cases check that 1267 # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) 1268 # doesn't index memory that is out of bounds of xs[i]. This condition 1269 # is important to the correctness of the as_strided batching rule 1270 # (see NOTE: [When will the as_strided_batching_rule fail?]) 1271 1272 # Sanity check #2a: The maximum indexable location of 1273 # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) 1274 # is less than or equal to the maximum indexable location of xs[i]. 1275 msg = "This is not supported inside of vmap" 1276 with self.assertRaisesRegex(RuntimeError, msg): 1277 x = torch.randn(B0, 3) 1278 vmap(lambda x: x.as_strided([3], [1], 1))(x) 1279 with self.assertRaisesRegex(RuntimeError, msg): 1280 x = torch.randn(B0, 3, 5) 1281 vmap(lambda x: x.as_strided([4, 4], [4, 1], 0))(x) 1282 with self.assertRaisesRegex(RuntimeError, msg): 1283 x = torch.randn(B0, B1, 3, 5) 1284 vmap(vmap(lambda x: x.as_strided([4, 4], [4, 1], 0)))(x) 1285 1286 # Sanity check #2b: The min indexable location of 1287 # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) 1288 # is greater than or equal to the min indexable location of xs[i]. 1289 with self.assertRaisesRegex(RuntimeError, msg): 1290 x = torch.randn(2, B0, 3)[1] 1291 vmap(lambda x: x.as_strided([3], [1], B0 * 3 - 1))(x) 1292 1293 # Sanity check #2c: 1294 # xs[i] is a zero-dim tensor, but 1295 # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) 1296 # is not 1297 with self.assertRaisesRegex(RuntimeError, msg): 1298 x = torch.randn(B0, 0, 3) 1299 vmap(lambda x: x.as_strided([3], [1]))(x) 1300 1301 def test_bmm(self): 1302 op = torch.bmm 1303 test = self._vmap_test 1304 B0, B1 = 7, 11 1305 1306 # shape mismatch 1307 msg = "Shape mismatch" 1308 with self.assertRaisesRegex(RuntimeError, msg): 1309 vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2)) 1310 with self.assertRaisesRegex(RuntimeError, msg): 1311 vmap(op, in_dims=(0, None))(torch.randn(B0, 3, 3, 2), torch.randn(2, 2)) 1312 with self.assertRaisesRegex(RuntimeError, msg): 1313 vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2)) 1314 1315 # left arg is vmapped 1316 test(op, (torch.rand(B0, 2, 3, 5), torch.rand(2, 5, 3)), in_dims=(0, None)) 1317 test( 1318 vmap(op, in_dims=(0, None)), 1319 (torch.rand(B1, B0, 2, 3, 5), torch.rand(2, 5, 3)), 1320 in_dims=(1, None), 1321 ) 1322 1323 # right arg is vmapped 1324 test(op, (torch.rand(2, 5, 3), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0)) 1325 test( 1326 vmap(op, in_dims=(None, 0)), 1327 (torch.rand(2, 5, 3), torch.rand(B1, B0, 2, 3, 5)), 1328 in_dims=(None, 1), 1329 ) 1330 1331 # both args are vmapped 1332 test(op, (torch.rand(B0, 2, 3, 5), torch.rand(B0, 2, 5, 3))) 1333 test( 1334 vmap(op), 1335 (torch.rand(B1, B0, 2, 3, 5), torch.rand(B0, B1, 2, 5, 3)), 1336 in_dims=(1, 0), 1337 ) 1338 test( 1339 vmap(op, in_dims=(0, None)), 1340 (torch.rand(B1, 2, 3, 5), torch.rand(B0, 2, 5, 3)), 1341 in_dims=(None, 0), 1342 ) 1343 1344 def test_cat(self): 1345 test = self._vmap_test 1346 B0, B1 = 5, 7 1347 1348 # Quick hack b/c vmap can't accept a list of tensors as an argument 1349 def get_op(dim): 1350 def op(*tensors): 1351 return torch.cat(tensors, dim=dim) 1352 1353 return op 1354 1355 test(get_op(0), (torch.rand(B0, 2), torch.rand(B0, 3))) 1356 test(get_op(0), (torch.rand(2), torch.rand(B0, 3)), in_dims=(None, 0)) 1357 test(get_op(0), (torch.rand(2, 17), torch.rand(3, 17, B0)), in_dims=(None, 2)) 1358 test(get_op(-1), (torch.rand(17, 2), torch.rand(17, 3, B0)), in_dims=(None, 2)) 1359 test( 1360 vmap(get_op(0), in_dims=(0, None)), 1361 (torch.rand(B1, 2), torch.rand(B0, 3)), 1362 in_dims=(None, 0), 1363 ) 1364 test( 1365 vmap(get_op(0), in_dims=(0, 0)), 1366 (torch.rand(B1, 2), torch.rand(B0, B1, 3)), 1367 in_dims=(None, 0), 1368 ) 1369 1370 def test_conj(self): 1371 op = torch.conj 1372 1373 def run_test(dtype): 1374 def get(shape): 1375 return torch.randn(shape, dtype=dtype) 1376 1377 B0, B1 = 7, 11 1378 test = self._vmap_test 1379 1380 # Single vmap, various in_dims / out_dims 1381 test(op, [get([B0, 3])]) 1382 test(op, [get([2, 5, B0, 3])], in_dims=2) 1383 test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2) 1384 1385 # Doubly nested vmap 1386 test(vmap(op), [get([B0, B1])]) 1387 test(vmap(op), [get([B1, 2, 5, B0, 3])], in_dims=2) 1388 test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], in_dims=2, out_dims=2) 1389 1390 # correctness tests 1391 run_test(torch.float) 1392 run_test(torch.cfloat) 1393 1394 # check that torch.conj on a non-complex tensor returns the same tensor 1395 real_tensor = torch.randn(3) 1396 result = vmap(op)(real_tensor) 1397 self.assertEqual(result.data_ptr(), real_tensor.data_ptr()) 1398 1399 def test_contiguous(self): 1400 op = Tensor.contiguous 1401 1402 self._test_unary(op, TensorFactory.randn, "cpu") 1403 1404 # check that contiguous returns the original tensor if the per-examples 1405 # are already contiguous 1406 B0 = 3 1407 x = torch.randn(B0, 2, 5, 7) 1408 x = x.movedim(0, 2) 1409 result = vmap(Tensor.contiguous, in_dims=2, out_dims=2)(x) 1410 self.assertTrue(result is x) 1411 1412 msg = "NYI: querying is_contiguous inside of vmap for memory_format" 1413 tensor = torch.randn(B0, 3) 1414 with self.assertRaisesRegex(RuntimeError, msg): 1415 vmap(functools.partial(op, memory_format=torch.channels_last))(tensor) 1416 with self.assertRaisesRegex(RuntimeError, msg): 1417 vmap(functools.partial(op, memory_format=torch.channels_last_3d))(tensor) 1418 1419 def test_stride(self): 1420 B0 = 3 1421 1422 x = torch.randn(B0, 2, 5, 7) 1423 1424 def foo(x): 1425 assert x.stride() == (7 * 5, 7, 1) 1426 return x 1427 1428 vmap(foo)(x) 1429 1430 x = torch.randn(2, B0, 5, 7).movedim(1, 0) 1431 1432 def bar(x): 1433 assert x.stride() == (7 * 5 * B0, 7, 1) 1434 return x 1435 1436 vmap(bar)(x) 1437 1438 def test_chunk(self): 1439 test = self._vmap_view_test 1440 op = torch.chunk 1441 B0, B1, B2 = 7, 11, 13 1442 1443 # tests for torch.split(self, split_size: int, dim) 1444 test(op, (torch.rand(B0, 2, 1024), 15, -1), in_dims=(0, None, None)) 1445 test(op, (torch.rand(2, B0, 1024), 9, 1), in_dims=(1, None, None)) 1446 test( 1447 vmap(op, in_dims=(0, None, None)), 1448 (torch.rand(B1, 1023, B0, 5), 4, 0), 1449 in_dims=(2, None, None), 1450 ) 1451 test( 1452 vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)), 1453 (torch.rand(B1, 2, B0, 64, B2),), 1454 in_dims=2, 1455 ) 1456 1457 def test_clamp(self): 1458 clamp_cases = ( 1459 (lambda t: t.clamp(min=-0.5), TensorFactory.randn), 1460 (lambda t: t.clamp(max=0.5), TensorFactory.randn), 1461 (lambda t: t.clamp(min=-0.5, max=0.5), TensorFactory.randn), 1462 (lambda t: t.clamp_min(min=-0.5), TensorFactory.randn), 1463 (lambda t: t.clamp_max(max=0.5), TensorFactory.randn), 1464 ) 1465 for op, getter in clamp_cases: 1466 self._test_unary(op, getter, "cpu") 1467 1468 def test_comparison_ops(self): 1469 test = functools.partial(self._vmap_test, check_propagates_grad=False) 1470 1471 getter = TensorFactory.randn 1472 B0, B1 = 7, 11 1473 1474 ops = ( 1475 torch.eq, 1476 lambda x, y: x == y, 1477 torch.gt, 1478 lambda x, y: x > y, 1479 torch.ge, 1480 lambda x, y: x >= y, 1481 torch.le, 1482 lambda x, y: x <= y, 1483 torch.lt, 1484 lambda x, y: x < y, 1485 torch.ne, 1486 lambda x, y: x != y, 1487 ) 1488 1489 for op in ops: 1490 # Single vmap: op(Tensor, Tensor) 1491 test(op, (getter([B0, 3]), getter([B0, 3]))) 1492 test(op, (getter([B0]), getter([B0, 2, 3]))) 1493 test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1)) 1494 test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1), out_dims=1) 1495 test(op, (getter([B0]), getter([2, 3])), in_dims=(0, None)) 1496 test(op, (getter([2, 3]), getter([B0, 3])), in_dims=(0, None)) 1497 1498 # Nested vmap: op(Tensor, Tensor) 1499 test(vmap(op), (getter([B0, B1, 2, 3]), getter([B0, B1, 3]))) 1500 test( 1501 vmap(op, in_dims=(None, 0)), 1502 (getter([B0, 2, 3]), getter([B1, 3])), 1503 in_dims=(0, None), 1504 ) 1505 1506 # test number as inputs 1507 number = getter([]).item() 1508 self._test_unary( 1509 lambda t: op(t, number), getter, "cpu", check_propagates_grad=False 1510 ) 1511 1512 def test_diagonal(self): 1513 tensor = torch.randn(3, 5, 7, 11, 13) 1514 test = self._vmap_view_test 1515 op = torch.diagonal 1516 test(op, (tensor, 1, 0, 1), in_dims=(0, None, None, None)) 1517 test(op, (tensor, 0, 2, -1), in_dims=(0, None, None, None)) 1518 test(op, (tensor, 2, 1, 2), in_dims=(1, None, None, None)) 1519 test(op, (tensor, 0, -2, -1), in_dims=(1, None, None, None), out_dims=1) 1520 test(vmap(lambda t: op(t, 0, 0, -1)), (tensor,), in_dims=1, out_dims=1) 1521 test( 1522 vmap(vmap(lambda t: op(t, 0, 0, 1), in_dims=1), in_dims=3), 1523 (tensor,), 1524 in_dims=1, 1525 out_dims=1, 1526 ) 1527 1528 def test_dot(self): 1529 op = torch.dot 1530 test = self._vmap_test 1531 B0, B1 = 7, 11 1532 1533 # shape mismatch 1534 msg = "Shape mismatch" 1535 with self.assertRaisesRegex(RuntimeError, msg): 1536 vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2)) 1537 with self.assertRaisesRegex(RuntimeError, msg): 1538 vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2)) 1539 with self.assertRaisesRegex(RuntimeError, msg): 1540 vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2)) 1541 1542 # left arg is vmapped 1543 test(op, (torch.rand(B0, 5), torch.rand(5)), in_dims=(0, None)) 1544 test( 1545 vmap(op, in_dims=(0, None)), 1546 (torch.rand(B1, B0, 5), torch.rand(5)), 1547 in_dims=(1, None), 1548 ) 1549 1550 # right arg is vmapped 1551 test(op, (torch.rand(5), torch.rand(B0, 5)), in_dims=(None, 0)) 1552 test( 1553 vmap(op, in_dims=(None, 0)), 1554 (torch.rand(5), torch.rand(B1, B0, 5)), 1555 in_dims=(None, 1), 1556 ) 1557 1558 # both args are vmapped 1559 test(op, (torch.rand(B0, 5), torch.rand(B0, 5))) 1560 test(vmap(op), (torch.rand(B1, B0, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0)) 1561 test( 1562 vmap(op, in_dims=(0, None)), 1563 (torch.rand(B1, 5), torch.rand(B0, 5)), 1564 in_dims=(None, 0), 1565 ) 1566 1567 def test_expand_as(self): 1568 op = torch.Tensor.expand_as 1569 test = self._vmap_view_test 1570 B0, B1, B2 = 7, 11, 13 1571 test(op, (torch.rand(B0, 1, 5), torch.rand(B0, 2, 3, 5))) 1572 test(op, (torch.rand(B0, 1, 5), torch.rand(2, 3, 5)), in_dims=(0, None)) 1573 test(op, (torch.rand(1, 5), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0)) 1574 test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B0, B1, 2, 3, 5))) 1575 test( 1576 vmap(op), 1577 (torch.rand(B0, B1, 1, 5), torch.rand(B1, B0, 2, 3, 5)), 1578 in_dims=(0, 1), 1579 ) 1580 test(vmap(op), (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None)) 1581 test(vmap(vmap(op)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5))) 1582 1583 def test_fill_and_zero_inplace(self): 1584 test = functools.partial(self._vmap_test, check_propagates_grad=False) 1585 B0, B1 = 7, 11 1586 ops = ( 1587 lambda t: t.fill_(0.1), 1588 lambda t: t.fill_(torch.tensor(0.2)), 1589 lambda t: t.zero_(), 1590 ) 1591 1592 for op in ops: 1593 # Single vmap, various in_dims / out_dims 1594 test(op, [TensorFactory.randn([B0, 3])]) 1595 test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2) 1596 test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2, out_dims=2) 1597 1598 # Doubly nested vmap 1599 test(vmap(op), [TensorFactory.randn([B0, B1])]) 1600 test(vmap(op), [TensorFactory.randn([B1, 2, 5, B0, 3])], in_dims=2) 1601 test( 1602 vmap(op, in_dims=2), 1603 [TensorFactory.randn([2, 5, B0, B1, 3])], 1604 in_dims=2, 1605 out_dims=2, 1606 ) 1607 1608 # test when value is a batched tensor for fill_ operator 1609 B0, B1 = 3, 5 1610 test(Tensor.fill_, [TensorFactory.randn([B0, B1]), TensorFactory.randn(B0)]) 1611 1612 with self.assertRaisesRegex( 1613 RuntimeError, r"output with shape .+ doesn't match the broadcast shape" 1614 ): 1615 # Runtime Error is thrown when the tensor being written to isn't being vmapped over 1616 vmap(Tensor.fill_, (None, 0))( 1617 TensorFactory.randn([B0, B1]), TensorFactory.randn([B0]) 1618 ) 1619 1620 def _test_complex_views(self, op, dtypes): 1621 test = self._vmap_view_test 1622 1623 def run_test(op, dtype): 1624 def get(shape): 1625 return torch.randn(shape, dtype=dtype) 1626 1627 B0, B1 = 7, 11 1628 1629 # Single vmap, various in_dims / out_dims 1630 test(op, [get([B0, 3])]) 1631 test(op, [get([3, B0])], in_dims=1) 1632 test(op, [get([2, 5, B0, 3])], in_dims=2) 1633 test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2) 1634 1635 # Doubly nested vmap 1636 test(vmap(op), [get([B0, B1])]) 1637 test(vmap(op), [get([B1, 2, 5, 3, B0])], in_dims=4) 1638 test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], in_dims=2, out_dims=2) 1639 1640 for dtype in dtypes: 1641 run_test(op, dtype) 1642 1643 def test_real(self): 1644 self._test_complex_views(torch.real, dtypes=[torch.cfloat, torch.cdouble]) 1645 1646 def test_imag(self): 1647 self._test_complex_views(torch.imag, dtypes=[torch.cfloat, torch.cdouble]) 1648 1649 def test_view_as_real(self): 1650 self._test_complex_views( 1651 torch.view_as_real, dtypes=[torch.cfloat, torch.cdouble] 1652 ) 1653 1654 def test_view_as_complex(self): 1655 def run_test(dtype): 1656 def get(shape): 1657 return torch.randn(shape, dtype=dtype) 1658 1659 op = torch.view_as_complex 1660 test = self._vmap_view_test 1661 B0, B1 = 7, 11 1662 1663 # Single vmap, various in_dims / out_dims 1664 test(op, [get([B0, 3, 2])]) 1665 test(op, [get([2, 5, B0, 3, 2])], in_dims=2) 1666 test(op, [get([2, 5, B0, 3, 2])], in_dims=2, out_dims=2) 1667 1668 # Doubly nested vmap 1669 test(vmap(op), [get([B0, B1, 2])]) 1670 test(vmap(op), [get([B1, 2, 5, B0, 3, 2])], in_dims=2) 1671 test( 1672 vmap(op, in_dims=2), [get([2, 5, B0, B1, 3, 2])], in_dims=2, out_dims=2 1673 ) 1674 1675 # Interesting case #1: Batch dim directly before dim of size 2 1676 test(op, [get([3, B0, 2])], in_dims=1) 1677 test(vmap(op, in_dims=1), [get([3, B1, B0, 2])], in_dims=2) 1678 1679 # Interesting case #2: Batch dim at end of tensor, success cases 1680 # view_as_complex requires that the dim with size 2 have stride 1 1681 # in order for the view to function propertly 1682 test(op, [get([B0, 2]).transpose(0, 1)], in_dims=1) 1683 test(vmap(op, in_dims=1), [get([B0, B1, 2]).movedim(1, 2)]) 1684 test(vmap(op, in_dims=2), [get([B0, 3, B1, 2]).movedim(2, 3)]) 1685 1686 # Interesting case #3: Batch dim at end of tensor, failure cases 1687 msg = "Tensor must have a last dimension with stride 1" 1688 with self.assertRaisesRegex(RuntimeError, msg): 1689 vmap(op, in_dims=1)(get([2, B0])) 1690 with self.assertRaisesRegex(RuntimeError, msg): 1691 vmap(vmap(op, in_dims=1), in_dims=1)(get([2, B0, B1])) 1692 1693 # Invalid input: no dimension of size 2 1694 msg = "Input tensor must have one or more dimensions" 1695 with self.assertRaisesRegex(RuntimeError, msg): 1696 vmap(op)(get([B0])) 1697 with self.assertRaisesRegex(RuntimeError, msg): 1698 vmap(vmap(op))(get([B0, B1])) 1699 1700 # Invalid input: Batch dim has size 2, but the logical last dim does 1701 # not have size 2 1702 msg = "Tensor must have a last dimension of size 2" 1703 with self.assertRaisesRegex(RuntimeError, msg): 1704 vmap(op, in_dims=1)(get([3, 2])) 1705 1706 for dtype in [torch.float, torch.double]: 1707 run_test(dtype) 1708 1709 def test_is_complex(self): 1710 ctensor = torch.randn(3, dtype=torch.cfloat) 1711 tensor = torch.randn(3) 1712 1713 def foo(x): 1714 if x.is_complex(): 1715 return torch.tensor(1) 1716 else: 1717 return torch.tensor(0) 1718 1719 self.assertEqual(vmap(foo)(ctensor), torch.tensor([1, 1, 1])) 1720 self.assertEqual(vmap(foo)(tensor), torch.tensor([0, 0, 0])) 1721 1722 def test_is_floating_point(self): 1723 float_tensor = torch.tensor([1.0, 2.0, 3.0]) 1724 long_tensor = torch.tensor([1, 2, 3]) 1725 1726 def foo(x): 1727 if x.is_floating_point(): 1728 return torch.tensor(1) 1729 else: 1730 return torch.tensor(0) 1731 1732 self.assertEqual(vmap(foo)(float_tensor), torch.tensor([1, 1, 1])) 1733 self.assertEqual(vmap(foo)(long_tensor), torch.tensor([0, 0, 0])) 1734 1735 def test_is_contiguous(self): 1736 def foo(x): 1737 if x.is_contiguous(): 1738 return torch.tensor(1.0) 1739 else: 1740 return torch.tensor(0.0) 1741 1742 B0, B1 = 3, 5 1743 1744 # Single batch dim 1745 contig = torch.randn(B0, 2, 7) 1746 self.assertEqual(vmap(foo)(contig), torch.ones(B0)) 1747 1748 noncontig = torch.randn(2, B0, 7) 1749 self.assertEqual(vmap(foo, in_dims=1)(noncontig), torch.zeros(B0)) 1750 1751 noncontig = torch.randn(2, B0, 7).movedim(1, 0) 1752 self.assertEqual(vmap(foo)(noncontig), torch.zeros(B0)) 1753 1754 noncontig = torch.randn(2, 7, B0) 1755 self.assertEqual(vmap(foo, in_dims=2)(noncontig), torch.zeros(B0)) 1756 1757 # Multiple batch dims 1758 contig = torch.randn(B0, B1, 3) 1759 self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1)) 1760 1761 contig = torch.randn(B1, B0, 3) 1762 self.assertEqual(vmap(vmap(foo), in_dims=1)(contig), torch.ones(B0, B1)) 1763 1764 contig = torch.randn(B1, B0, 3).movedim(0, 1) 1765 self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1)) 1766 1767 noncontig = torch.randn(B0, 3, B1) 1768 self.assertEqual(vmap(vmap(foo, in_dims=1))(noncontig), torch.zeros(B0, B1)) 1769 1770 # is_contiguous on empty tensor is True 1771 def bar(x): 1772 assert x.is_contiguous() 1773 return x 1774 1775 vmap(bar)(torch.randn(B0, 0, 3)) 1776 vmap(bar, in_dims=1)(torch.randn(0, B0, 3)) 1777 vmap(bar)(torch.randn(B0, 0, 3).mT) 1778 1779 # is_contiguous with other memory formats 1780 def baz(x, memory_format): 1781 x.is_contiguous(memory_format=memory_format) 1782 return x 1783 1784 msg = "NYI: querying is_contiguous inside of vmap for memory_format" 1785 tensor = torch.randn(B0, 2, 7, 3) 1786 with self.assertRaisesRegex(RuntimeError, msg): 1787 vmap(functools.partial(baz, memory_format=torch.channels_last))(tensor) 1788 with self.assertRaisesRegex(RuntimeError, msg): 1789 vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor) 1790 1791 def test_movedim(self): 1792 op = torch.movedim 1793 test = self._vmap_view_test 1794 B0, B1, B2 = 7, 11, 13 1795 1796 # movedim(tensor, int, int) variant 1797 test(op, (torch.rand(B0, 2, 5), 0, 1), in_dims=(0, None, None)) 1798 test(op, (torch.rand(2, B0, 5), 0, 1), in_dims=(1, None, None)) 1799 test( 1800 vmap(op, in_dims=(0, None, None)), 1801 (torch.rand(B1, 2, B0, 5), 0, 1), 1802 in_dims=(2, None, None), 1803 ) 1804 test( 1805 vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)), 1806 (torch.rand(B1, 2, B0, 5, B2), 0, 1), 1807 in_dims=(2, None, None), 1808 ) 1809 1810 # movedim(tensor, intlist, intlist) variant 1811 test(op, (torch.rand(B0, 2, 3, 5), [1, 0], [0, 2]), in_dims=(0, None, None)) 1812 test(op, (torch.rand(2, 3, B0, 5), [1, 0], [0, 2]), in_dims=(1, None, None)) 1813 test( 1814 vmap(op, in_dims=(0, None, None)), 1815 (torch.rand(B1, 2, B0, 5), [0, 1], [1, 0]), 1816 in_dims=(2, None, None), 1817 ) 1818 test( 1819 vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)), 1820 (torch.rand(B1, 2, B0, 5, B2), [0, 1], [1, 0]), 1821 in_dims=(2, None, None), 1822 ) 1823 1824 def test_mm(self): 1825 op = torch.mm 1826 test = self._vmap_test 1827 B0, B1 = 7, 11 1828 1829 # shape mismatch 1830 msg = "Shape mismatch" 1831 with self.assertRaisesRegex(RuntimeError, msg): 1832 vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2)) 1833 with self.assertRaisesRegex(RuntimeError, msg): 1834 vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2)) 1835 with self.assertRaisesRegex(RuntimeError, msg): 1836 vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2)) 1837 1838 # left arg is vmapped 1839 test(op, (torch.rand(B0, 2, 5), torch.rand(5, 2)), in_dims=(0, None)) 1840 test( 1841 vmap(op, in_dims=(0, None)), 1842 (torch.rand(B1, B0, 2, 5), torch.rand(5, 2)), 1843 in_dims=(1, None), 1844 ) 1845 1846 # right arg is vmapped 1847 test(op, (torch.rand(2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0)) 1848 test( 1849 vmap(op, in_dims=(None, 0)), 1850 (torch.rand(2, 5), torch.rand(B1, B0, 5, 2)), 1851 in_dims=(None, 1), 1852 ) 1853 1854 # both args are vmapped 1855 test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5, 2))) 1856 test( 1857 vmap(op), 1858 (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5, 2)), 1859 in_dims=(1, 0), 1860 ) 1861 test( 1862 vmap(op, in_dims=(0, None)), 1863 (torch.rand(B1, 2, 5), torch.rand(B0, 5, 2)), 1864 in_dims=(None, 0), 1865 ) 1866 1867 def test_mv(self): 1868 op = torch.mv 1869 test = self._vmap_test 1870 B0, B1 = 7, 11 1871 1872 # shape mismatch 1873 msg = "Shape mismatch" 1874 with self.assertRaisesRegex(RuntimeError, msg): 1875 vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2)) 1876 with self.assertRaisesRegex(RuntimeError, msg): 1877 vmap(op, in_dims=(0, None))(torch.randn(B0, 2, 2), torch.randn(2, 2)) 1878 with self.assertRaisesRegex(RuntimeError, msg): 1879 vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2)) 1880 1881 # left arg is vmapped 1882 test(op, (torch.rand(B0, 2, 5), torch.rand(5)), in_dims=(0, None)) 1883 test( 1884 vmap(op, in_dims=(0, None)), 1885 (torch.rand(B1, B0, 2, 5), torch.rand(5)), 1886 in_dims=(1, None), 1887 ) 1888 1889 # right arg is vmapped 1890 test(op, (torch.rand(2, 5), torch.rand(B0, 5)), in_dims=(None, 0)) 1891 test( 1892 vmap(op, in_dims=(None, 0)), 1893 (torch.rand(2, 5), torch.rand(B1, B0, 5)), 1894 in_dims=(None, 1), 1895 ) 1896 1897 # both args are vmapped 1898 test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5))) 1899 test( 1900 vmap(op), (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0) 1901 ) 1902 test( 1903 vmap(op, in_dims=(0, None)), 1904 (torch.rand(B1, 2, 5), torch.rand(B0, 5)), 1905 in_dims=(None, 0), 1906 ) 1907 1908 def test_narrow(self): 1909 op = torch.narrow 1910 test = self._vmap_view_test 1911 B0, B1, B2 = 7, 11, 13 1912 1913 test(op, (torch.rand(B0, 2, 5), -1, 1, 3), in_dims=(0, None, None, None)) 1914 test(op, (torch.rand(2, B0, 5), 1, 1, 3), in_dims=(1, None, None, None)) 1915 test( 1916 vmap(op, in_dims=(0, None, None, None)), 1917 (torch.rand(B1, 2, B0, 5), 1, 0, 0), 1918 in_dims=(2, None, None, None), 1919 ) 1920 test( 1921 vmap( 1922 vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None) 1923 ), 1924 (torch.rand(B1, 2, B0, 5, B2), -1, 2, 3), 1925 in_dims=(2, None, None, None), 1926 ) 1927 1928 def test_new_empty(self): 1929 # Empty is non-deterministic so we just check that the shape of the 1930 # output tensor is what we expect and that the vmap fallback isn't used. 1931 op = Tensor.new_empty 1932 1933 B0, B1 = 7, 11 1934 1935 result = vmap(lambda x: op(x, [2, 3]))(torch.randn(B0)) 1936 self.assertEqual(result.shape, [B0, 2, 3]) 1937 1938 result = vmap(lambda x: op(x, []))(torch.randn(B0)) 1939 self.assertEqual(result.shape, [B0]) 1940 1941 result = vmap(vmap(lambda x: op(x, [2, 3])))(torch.randn(B0, B1)) 1942 self.assertEqual(result.shape, [B0, B1, 2, 3]) 1943 1944 def test_new_empty_strided(self): 1945 # Empty is non-deterministic so we just check that the size and shape 1946 # of the output are what we expect and that the vmap fallback isn't used 1947 B0, B1 = 7, 11 1948 1949 def _test_single_vmap(size, stride, B0): 1950 x = torch.randn(B0) 1951 result = vmap(lambda x: x.new_empty_strided(size, stride))(x) 1952 S = torch.empty_strided(size, stride).storage().size() 1953 self.assertEqual(result.shape, [B0] + size) 1954 self.assertEqual(result.stride(), [S] + stride) 1955 1956 def _test_double_vmap(size, stride, B0, B1): 1957 x = torch.randn(B0, B1) 1958 result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)))(x) 1959 S = torch.empty_strided(size, stride).storage().size() 1960 self.assertEqual(result.shape, [B0, B1] + size) 1961 self.assertEqual(result.stride(), [B1 * S, S] + stride) 1962 1963 x = torch.randn(B1, B0) 1964 result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)), in_dims=1)( 1965 x 1966 ) 1967 S = x.new_empty_strided(size, stride).storage().size() 1968 self.assertEqual(result.shape, [B0, B1] + size) 1969 self.assertEqual(result.stride(), [B1 * S, S] + stride) 1970 1971 # contiguous case 1972 _test_single_vmap([2, 3, 5], [3 * 5, 5, 1], B0) 1973 _test_double_vmap([2, 3, 5], [3 * 5, 5, 1], B0, B1) 1974 1975 # expanded 1976 _test_single_vmap([2, 3, 5], [0, 5, 1], B0) 1977 _test_double_vmap([2, 3, 5], [0, 5, 1], B0, B1) 1978 1979 # some of these cases are pretty strange, just verifying that if 1980 # empty_strided allows them then BatchedTensor.new_empty_strided 1981 # can as well 1982 for shape in [[2, 3, 4], [0, 2, 0]]: 1983 for strides in [[12, 4, 1], [2, 4, 6], [0, 0, 0]]: 1984 _test_single_vmap(shape, strides, B0) 1985 _test_double_vmap(shape, strides, B0, B1) 1986 1987 def test_new_zeros(self): 1988 op = Tensor.new_zeros 1989 test = functools.partial(self._vmap_test, check_propagates_grad=False) 1990 B0, B1 = 7, 11 1991 1992 test(lambda x: op(x, 2, 3), (torch.rand(B0),)) 1993 test(lambda x: op(x, []), (torch.rand(B0),)) 1994 test(vmap(lambda x: op(x, 3, 5)), (torch.rand(B0, B1),)) 1995 1996 def test_select(self): 1997 op = torch.select 1998 test = self._vmap_view_test 1999 B0, B1, B2 = 7, 11, 13 2000 test(op, (torch.rand(B0, 2, 5), 0, 0), in_dims=(0, None, None)) 2001 test(op, (torch.rand(2, B0, 5), 1, 1), in_dims=(1, None, None)) 2002 test(vmap(lambda t: op(t, 1, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2) 2003 test( 2004 vmap(vmap(lambda t: op(t, 1, 1), in_dims=1)), 2005 (torch.rand(B1, 2, B0, B2, 5),), 2006 in_dims=2, 2007 ) 2008 2009 def test_stack(self): 2010 test = self._vmap_test 2011 B0, B1 = 5, 7 2012 2013 # Quick hack b/c vmap can't accept a list of tensors as an argument 2014 def get_op(dim): 2015 def op(*tensors): 2016 return torch.stack(tensors, dim=dim) 2017 2018 return op 2019 2020 test(get_op(0), (torch.rand(B0, 3), torch.rand(B0, 3))) 2021 test(get_op(0), (torch.rand(3), torch.rand(B0, 3)), in_dims=(None, 0)) 2022 test(get_op(0), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2)) 2023 test(get_op(-1), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2)) 2024 test( 2025 vmap(get_op(0), in_dims=(0, None)), 2026 (torch.rand(B1, 2), torch.rand(B0, 2)), 2027 in_dims=(None, 0), 2028 ) 2029 test( 2030 vmap(get_op(0), in_dims=(0, 0)), 2031 (torch.rand(B1, 2), torch.rand(B0, B1, 2)), 2032 in_dims=(None, 0), 2033 ) 2034 2035 def test_slice(self): 2036 test = self._vmap_view_test 2037 B0, B1, B2 = 7, 11, 13 2038 test(lambda t: t[0:1], (torch.rand(B0, 3, 5),)) 2039 test(lambda t: t[:, 1:3], (torch.rand(3, 5, B0),), in_dims=2) 2040 test( 2041 vmap(lambda t: t[:, 0:1], in_dims=2), (torch.rand(3, 5, B0, B1),), in_dims=2 2042 ) 2043 test( 2044 vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2), 2045 (torch.rand(3, 5, B0, B1, B2),), 2046 in_dims=2, 2047 ) 2048 2049 def test_squeeze(self): 2050 test = self._vmap_view_test 2051 op = torch.squeeze 2052 B0, B1 = 1, 11 2053 test(op, (torch.rand(B0),)) 2054 test(op, (torch.rand(B0, 3, 5),)) 2055 test(op, (torch.rand(1, B0, 5),), in_dims=1) 2056 test(op, (torch.rand(B0, 0, 1, 5, 1),)) 2057 test(op, (torch.rand(B0, 1, 1, 1, 1),)) 2058 test(vmap(op), (torch.rand(B0, B1, 1),)) 2059 test(vmap(op), (torch.rand(B1, 1, B0),), in_dims=2) 2060 2061 def test_sum_dim(self): 2062 test = self._vmap_test 2063 B0, B1 = 5, 7 2064 2065 # Single vmap, various in_dims / out_dims 2066 test(lambda x: x.sum(()), [torch.randn([B0])]) 2067 test(lambda x: x.sum(()), [torch.randn([B0, 2])]) 2068 test(lambda x: x.sum(0), [torch.randn([B0])]) 2069 test(lambda x: x.sum(-1), [torch.randn([B0])]) 2070 test(lambda x: x.sum(0), [torch.randn([B0, 3])]) 2071 test(lambda x: x.sum(-1), [torch.randn([2, 5, B0, 3])], in_dims=2) 2072 test(lambda x: x.sum(2), [torch.randn([2, 5, B0, 3])], in_dims=2, out_dims=2) 2073 2074 # Doubly nested vmap 2075 test(vmap(lambda x: x.sum(())), [torch.randn([B0, B1])]) 2076 test(vmap(lambda x: x.sum(0)), [torch.randn([B0, B1])]) 2077 test(vmap(lambda x: x.sum(-1)), [torch.randn([B0, B1])]) 2078 test(vmap(lambda x: x.sum(-2)), [torch.randn([B1, 2, 5, B0, 3])], in_dims=2) 2079 test( 2080 vmap(lambda x: x.sum(2), in_dims=2), 2081 [torch.randn([2, 5, B0, B1, 3])], 2082 in_dims=2, 2083 out_dims=2, 2084 ) 2085 2086 def test_reshape(self): 2087 test = self._vmap_test 2088 B0, B1, B2 = 7, 11, 13 2089 op = torch.reshape 2090 test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None), check_view=True) 2091 test( 2092 op, (torch.rand(2, B0, 5), [1, 1, 10]), in_dims=(1, None), check_view=False 2093 ) 2094 test( 2095 vmap(lambda t: t.reshape([-1])), 2096 (torch.rand(B0, B1, 2, 5),), 2097 check_view=True, 2098 ) 2099 test( 2100 vmap(vmap(lambda t: t.reshape([-1]), in_dims=2), in_dims=1), 2101 (torch.rand(3, B1, 2, B2, 5, B0),), 2102 in_dims=5, 2103 check_view=False, 2104 ) 2105 2106 def test_reshape_as(self): 2107 test = self._vmap_test 2108 B0, B1, B2 = 7, 11, 13 2109 op = torch.Tensor.reshape_as 2110 test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)), check_view=True) 2111 test( 2112 op, 2113 (torch.rand(2 * 5), torch.rand(B0, 2, 5)), 2114 in_dims=(None, 0), 2115 check_view=True, 2116 ) 2117 test( 2118 op, 2119 (torch.rand(B0, 2 * 5), torch.rand(2, 5)), 2120 in_dims=(0, None), 2121 check_view=True, 2122 ) 2123 2124 test( 2125 op, 2126 (torch.rand(2, B0, 5), torch.rand(1, 1, 10)), 2127 in_dims=(1, None), 2128 check_view=False, 2129 ) 2130 2131 test( 2132 vmap(op), 2133 (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)), 2134 check_view=True, 2135 ) 2136 test( 2137 vmap(vmap(op, in_dims=(2, None)), in_dims=(1, None)), 2138 (torch.rand(3, B1, 2, B2, 5, B0), torch.rand(B0, 3 * 2 * 5)), 2139 in_dims=(5, 0), 2140 check_view=False, 2141 ) 2142 2143 def test_result_type(self): 2144 def scalar_tensor_with_dtype(op): 2145 def wrapped(*args, **kwargs): 2146 dtype = op(*args, **kwargs) 2147 return torch.ones([], dtype=dtype) 2148 2149 return wrapped 2150 2151 test = self._vmap_test 2152 op = scalar_tensor_with_dtype(torch.result_type) 2153 2154 B0 = 2 2155 2156 test( 2157 op, 2158 (torch.randn(B0), torch.randn(B0, dtype=torch.float64)), 2159 check_propagates_grad=False, 2160 ) 2161 test( 2162 op, 2163 (torch.randn(B0), torch.randint(10, [B0], dtype=torch.int64)), 2164 check_propagates_grad=False, 2165 ) 2166 2167 test(lambda x: op(x, 1), (torch.randn(B0),), check_propagates_grad=False) 2168 test(lambda x: op(x, 1.6), (torch.randn(B0),), check_propagates_grad=False) 2169 2170 test( 2171 lambda x: op(x, torch.tensor(1)), 2172 (torch.randn(B0),), 2173 check_propagates_grad=False, 2174 ) 2175 test( 2176 lambda x: op(x, torch.tensor(1.6, dtype=torch.double)), 2177 (torch.randn(B0),), 2178 check_propagates_grad=False, 2179 ) 2180 2181 test( 2182 op, 2183 (torch.randn(B0, 2), torch.randn(B0, 2, dtype=torch.float64)), 2184 check_propagates_grad=False, 2185 ) 2186 test( 2187 op, 2188 (torch.randn(B0, 2), torch.randint(10, [B0, 2], dtype=torch.int64)), 2189 check_propagates_grad=False, 2190 ) 2191 2192 test(lambda x: op(x, 1), (torch.randn(B0, 2),), check_propagates_grad=False) 2193 test(lambda x: op(x, 1.6), (torch.randn(B0, 2),), check_propagates_grad=False) 2194 2195 test( 2196 lambda x: op(x, torch.tensor(1)), 2197 (torch.randn(B0, 2),), 2198 check_propagates_grad=False, 2199 ) 2200 test( 2201 lambda x: op(x, torch.tensor(1.6, dtype=torch.double)), 2202 (torch.randn(B0, 2),), 2203 check_propagates_grad=False, 2204 ) 2205 2206 test( 2207 op, 2208 (torch.randn(B0, 2), torch.randn(B0, dtype=torch.float64)), 2209 check_propagates_grad=False, 2210 ) 2211 test( 2212 op, 2213 (torch.randn(B0, 2), torch.randint(10, [B0], dtype=torch.int64)), 2214 check_propagates_grad=False, 2215 ) 2216 2217 @skipIfTorchDynamo("too slow") 2218 def test_tensor_split(self): 2219 test = self._vmap_view_test 2220 op = torch.tensor_split 2221 B0, B1, B2 = 7, 11, 13 2222 2223 # tests for torch.tensor_split(self, indices_or_sections: int, dim) 2224 test(op, (torch.rand(B0, 2, 1024), 5, -1), in_dims=(0, None, None)) 2225 test(op, (torch.rand(2, B0, 1024), 150, 1), in_dims=(1, None, None)) 2226 test( 2227 vmap(op, in_dims=(0, None, None)), 2228 (torch.rand(B1, 1023, B0, 5), 256, 0), 2229 in_dims=(2, None, None), 2230 ) 2231 test( 2232 vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)), 2233 (torch.rand(B1, 2, B0, 64, B2),), 2234 in_dims=2, 2235 ) 2236 2237 # tests for torch.tensor_split(self, indices_or_sections: List[int], dim) 2238 test( 2239 op, 2240 (torch.rand(B0, 2, 1024), [50, 100, 378, 890], -1), 2241 in_dims=(0, None, None), 2242 ) 2243 test( 2244 op, 2245 (torch.rand(2, B0, 1024), [50, 100, 212, 345, 0, 378, 890], 1), 2246 in_dims=(1, None, None), 2247 ) 2248 test( 2249 vmap(op, in_dims=(0, None, None)), 2250 (torch.rand(B1, 1023, B0, 5), [50, 100, 212, 345, 0, 378, 890], 0), 2251 in_dims=(2, None, None), 2252 ) 2253 test( 2254 vmap(vmap(lambda t: op(t, [4, 8, 9, 34, 29], 1), in_dims=2)), 2255 (torch.rand(B1, 2, B0, 64, B2),), 2256 in_dims=2, 2257 ) 2258 2259 def test_split(self): 2260 test = self._vmap_view_test 2261 op = torch.split 2262 B0, B1, B2 = 7, 11, 13 2263 2264 # tests for torch.split(self, split_size: int, dim) 2265 test(op, (torch.rand(B0, 2, 1024), 101, -1), in_dims=(0, None, None)) 2266 test(op, (torch.rand(2, B0, 1024), 130, 1), in_dims=(1, None, None)) 2267 test( 2268 vmap(op, in_dims=(0, None, None)), 2269 (torch.rand(B1, 1023, B0, 5), 256, 0), 2270 in_dims=(2, None, None), 2271 ) 2272 test( 2273 vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)), 2274 (torch.rand(B1, 2, B0, 64, B2),), 2275 in_dims=2, 2276 ) 2277 2278 # tests for torch.split(self, split_size: List[int], dim) 2279 test(op, (torch.rand(B0, 2, 1024), [1, 1020, 3], -1), in_dims=(0, None, None)) 2280 test( 2281 op, (torch.rand(2, B0, 1024), [100] * 10 + [24], 1), in_dims=(1, None, None) 2282 ) 2283 test( 2284 vmap(op, in_dims=(0, None, None)), 2285 (torch.rand(B1, 1023, B0, 5), [256] * 3 + [255], 0), 2286 in_dims=(2, None, None), 2287 ) 2288 test( 2289 vmap(vmap(lambda t: op(t, [4] * 8 + [8] * 4, 1), in_dims=2)), 2290 (torch.rand(B1, 2, B0, 64, B2),), 2291 in_dims=2, 2292 ) 2293 2294 def test_trace(self): 2295 op = torch.trace 2296 test = self._vmap_test 2297 B0, B1, B2 = 7, 11, 13 2298 2299 test(op, (torch.rand(B0, 2, 5),)) 2300 test(op, (torch.rand(2, B0, 5),), in_dims=1) 2301 test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2) 2302 test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2) 2303 2304 def test_transpose(self): 2305 op = torch.transpose 2306 test = self._vmap_view_test 2307 2308 B0, B1, B2 = 7, 11, 13 2309 test(lambda x: op(x, 0, 1), (torch.rand(B0, 2, 5),)) 2310 test(lambda x: op(x, -1, -2), (torch.rand(B0, 2, 5),)) 2311 test(lambda x: op(x, 3, 1), (torch.rand(B0, 2, 5, 4, 6),)) 2312 test(lambda x: op(x, 1, 0), (torch.rand(2, B0, 5),), in_dims=1) 2313 test(vmap(lambda x: op(x, 0, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2) 2314 test( 2315 vmap(vmap(lambda x: op(x, 0, 1), in_dims=2)), 2316 (torch.rand(B1, 2, B0, 5, B2),), 2317 in_dims=2, 2318 ) 2319 2320 # Special case: scalar tensor 2321 for dim1, dim2 in itertools.product([0, -1], [0, -1]): 2322 x = torch.rand(B0) 2323 result = vmap(lambda x: op(x, dim1, dim2))(x) 2324 self.assertTrue(result is x) 2325 2326 def test_t(self): 2327 op = torch.t 2328 test = self._vmap_view_test 2329 B0, B1, B2 = 7, 11, 13 2330 test(op, (torch.rand(B0, 2, 5),)) 2331 test(op, (torch.rand(2, B0, 5),), in_dims=1) 2332 test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2) 2333 test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2) 2334 2335 def test_T_numpy(self): 2336 def op(t): 2337 return t.T 2338 2339 test = self._vmap_view_test 2340 B0, B1, B2 = 7, 11, 13 2341 test(op, (torch.rand(B0, 2, 3, 5),)) 2342 test(op, (torch.rand(2, B0, 3, 5),), in_dims=1) 2343 test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2) 2344 test(vmap(op), (torch.rand(B1, 2, B0, 3, 5),), in_dims=2) 2345 test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 3, B2, 5),), in_dims=2) 2346 2347 def test_to(self): 2348 test = self._vmap_test 2349 B0, B1 = 7, 11 2350 2351 test(lambda t: t.to("cpu"), (torch.rand(B0),)) 2352 test(lambda t: t.to(torch.double), (torch.rand(B0),)) 2353 test( 2354 lambda t, o: t.to(o), (torch.rand(B0), torch.randn(B0, dtype=torch.float64)) 2355 ) 2356 test( 2357 lambda t, o: t.to(o), 2358 (torch.rand(B0), torch.randn(B0, dtype=torch.float64)), 2359 in_dims=(0, None), 2360 ) 2361 test(vmap(lambda t: t.to(torch.double)), (torch.rand(B0, B1, 3),)) 2362 2363 # also test some casting methods 2364 test(lambda t: t.double(), (torch.rand(B0),)) 2365 test(lambda t: t.float(), (torch.rand(B0),)) 2366 test(lambda t: t.int(), (torch.rand(B0),), check_propagates_grad=False) 2367 test(lambda t: t.long(), (torch.rand(B0),), check_propagates_grad=False) 2368 2369 def test_unfold(self): 2370 op = torch.Tensor.unfold 2371 test = self._vmap_view_test 2372 B0, B1, B2 = 3, 2, 5 2373 2374 test(op, (torch.rand(B0, 7, 11), 0, 2, 1), in_dims=(0, None, None, None)) 2375 test(op, (torch.rand(7, B0, 11), 1, 4, 2), in_dims=(1, None, None, None)) 2376 test( 2377 vmap(op, in_dims=(0, None, None, None)), 2378 (torch.rand(B1, 7, B0, 11), 1, 5, 1), 2379 in_dims=(2, None, None, None), 2380 ) 2381 test( 2382 vmap( 2383 vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None) 2384 ), 2385 (torch.rand(B1, 7, B0, 11, B2), -1, 2, 4), 2386 in_dims=(2, None, None, None), 2387 ) 2388 2389 def test_unbind(self): 2390 test = self._vmap_view_test 2391 op = torch.unbind 2392 B0, B1, B2 = 7, 11, 13 2393 2394 test(op, (torch.rand(B0, 2, 1024), -1), in_dims=(0, None)) 2395 test(op, (torch.rand(B0, 2, 0),)) 2396 test(op, (torch.rand(2, B0, 7), 0), in_dims=(1, None)) 2397 test( 2398 vmap(op, in_dims=(0, None)), 2399 (torch.rand(B1, 1023, B0, 5), 1), 2400 in_dims=(2, None), 2401 ) 2402 test( 2403 vmap(vmap(lambda t: op(t, dim=1), in_dims=2)), 2404 (torch.rand(B1, 2, B0, 32, B2),), 2405 in_dims=2, 2406 ) 2407 2408 def test_view(self): 2409 test = self._vmap_view_test 2410 B0, B1, B2 = 7, 11, 13 2411 op = torch.Tensor.view 2412 2413 # We should error out if the view would produce an incorrect result 2414 with self.assertRaises(RuntimeError): 2415 vmap(op, in_dims=(1, None))(torch.rand(2, B0, 5), [10]) 2416 2417 test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None)) 2418 test(op, (torch.rand(B0, 4, 5), [1, 2, 1, 10]), in_dims=(0, None)) 2419 test(vmap(lambda t: t.view([-1])), (torch.rand(B0, B1, 2, 5, 3),)) 2420 test( 2421 vmap(vmap(lambda t: t.reshape([-1])), in_dims=1), 2422 (torch.rand(B2, B0, B1, 3, 2, 5),), 2423 in_dims=1, 2424 ) 2425 2426 def test_view_as(self): 2427 test = self._vmap_view_test 2428 B0, B1, B2 = 7, 11, 13 2429 op = torch.Tensor.view_as 2430 2431 # We should error out if the view would produce an incorrect result 2432 with self.assertRaises(RuntimeError): 2433 vmap(op, in_dims=(1, 0))(torch.rand(2, B0, 5), torch.rand(B0, 10)) 2434 2435 test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5))) 2436 test(op, (torch.rand(2 * 5), torch.rand(B0, 2, 5)), in_dims=(None, 0)) 2437 test(op, (torch.rand(B0, 2 * 5), torch.rand(2, 5)), in_dims=(0, None)) 2438 2439 test(op, (torch.rand(B0, 4, 5), torch.rand(2, 1, 1, 10)), in_dims=(0, None)) 2440 2441 test(vmap(op), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10))) 2442 test( 2443 vmap(vmap(op, in_dims=(0, None)), in_dims=(0, None)), 2444 (torch.rand(B1, B2, B0, 3, 2, 5), torch.rand(B0, 3 * 2 * 5)), 2445 in_dims=(2, 0), 2446 ) 2447 2448 def test_no_random_op_support(self): 2449 B0 = 2 2450 2451 captured = torch.rand(3) 2452 2453 random_ops = [ 2454 # out-of-place on BatchedTensor 2455 (torch.bernoulli, (torch.rand(B0, 1),)), 2456 (lambda t: torch.bernoulli(t, p=0.5), (torch.rand(B0, 1),)), 2457 (lambda t: torch.multinomial(t, 2), (torch.rand(B0, 3),)), 2458 (torch.normal, (torch.randn(B0, 1), torch.randn(B0, 1))), 2459 (lambda t: torch.normal(t, 1.0), (torch.randn(B0, 1),)), 2460 (lambda t: torch.normal(0.0, t), (torch.randn(B0, 1),)), 2461 (torch.poisson, (torch.rand(B0, 1),)), 2462 (torch.rand_like, (torch.rand(B0, 1),)), 2463 (torch.randn_like, (torch.rand(B0, 1),)), 2464 (lambda t: torch.randint_like(t, 2), (torch.rand(B0, 1),)), 2465 (lambda t: torch.randint_like(t, 0, 2), (torch.rand(B0, 1),)), 2466 # out-of-place on captured tensor 2467 (lambda t: torch.bernoulli(captured), (torch.rand(B0),)), 2468 (lambda t: torch.bernoulli(captured, p=0.5), (torch.rand(B0),)), 2469 (lambda t: torch.multinomial(captured, 2), (torch.rand(B0),)), 2470 (lambda t: torch.normal(captured, captured), (torch.randn(B0),)), 2471 (lambda t: torch.normal(captured, 1.0), (torch.randn(B0),)), 2472 (lambda t: torch.normal(0.0, captured), (torch.randn(B0),)), 2473 (lambda t: torch.poisson(captured), (torch.rand(B0),)), 2474 (lambda t: torch.rand_like(captured), (torch.rand(B0),)), 2475 (lambda t: torch.randn_like(captured), (torch.rand(B0),)), 2476 (lambda t: torch.randint_like(captured, 2), (torch.rand(B0),)), 2477 (lambda t: torch.randint_like(captured, 0, 2), (torch.rand(B0),)), 2478 # in-place on BatchedTensor 2479 (lambda t: t.bernoulli_(), (torch.randn(B0, 1),)), 2480 (lambda t: t.cauchy_(), (torch.randn(B0, 1),)), 2481 (lambda t: t.exponential_(), (torch.randn(B0, 1),)), 2482 (lambda t: t.geometric_(0.5), (torch.randn(B0, 1),)), 2483 (lambda t: t.log_normal_(), (torch.randn(B0, 1),)), 2484 (lambda t: t.normal_(), (torch.randn(B0, 1),)), 2485 (lambda t: t.random_(), (torch.randn(B0, 1),)), 2486 (lambda t: t.random_(0, 2), (torch.randn(B0, 1),)), 2487 (lambda t: t.random_(2), (torch.randn(B0, 1),)), 2488 (lambda t: t.uniform_(), (torch.randn(B0, 1),)), 2489 # in-place on captured tensor 2490 (lambda t: captured.bernoulli_(), (torch.randn(B0),)), 2491 (lambda t: captured.cauchy_(), (torch.randn(B0),)), 2492 (lambda t: captured.exponential_(), (torch.randn(B0),)), 2493 (lambda t: captured.geometric_(0.5), (torch.randn(B0),)), 2494 (lambda t: captured.log_normal_(), (torch.randn(B0),)), 2495 (lambda t: captured.normal_(), (torch.randn(B0),)), 2496 (lambda t: captured.random_(), (torch.randn(B0),)), 2497 (lambda t: captured.random_(0, 2), (torch.randn(B0),)), 2498 (lambda t: captured.random_(2), (torch.randn(B0),)), 2499 (lambda t: captured.uniform_(), (torch.randn(B0),)), 2500 # factory functions 2501 (lambda t: torch.rand(1), (torch.randn(B0),)), 2502 (lambda t: torch.randn(1), (torch.randn(B0),)), 2503 (lambda t: torch.randint(5, [1]), (torch.randn(B0),)), 2504 (lambda t: torch.randperm(5), (torch.randn(B0),)), 2505 ] 2506 for op, args in random_ops: 2507 with self.assertRaisesRegex( 2508 RuntimeError, "vmap: We do not yet support calling random operations" 2509 ): 2510 vmap(op)(*args) 2511 2512 2513def construct_v(output, batch_size): 2514 return torch.randn( 2515 batch_size, *output.shape, dtype=output.dtype, device=output.device 2516 ) 2517 2518 2519def as_tuple(x): 2520 if isinstance(x, tuple): 2521 return x 2522 elif isinstance(x, list): 2523 return tuple(x) 2524 else: 2525 return (x,) 2526 2527 2528def differentiable(args): 2529 return tuple( 2530 arg 2531 for arg in as_tuple(args) 2532 if isinstance(arg, torch.Tensor) and arg.requires_grad 2533 ) 2534 2535 2536def _get_rand_no_zeros(*args, **kwargs): 2537 requires_grad = kwargs.get("requires_grad", False) 2538 kwargs_without_requires_grad = kwargs.copy() 2539 kwargs_without_requires_grad["requires_grad"] = False 2540 result = torch.rand(*args, **kwargs_without_requires_grad) 2541 return result.clamp_min_(0.1).requires_grad_(requires_grad) 2542 2543 2544class TestVmapBatchedGradientLegacy(Namespace.TestVmapBaseLegacy): 2545 def _vmap_test(self, *args, **kwargs): 2546 return _vmap_test(self, *args, **kwargs) 2547 2548 # Tests batched gradient computation of outputs = op(*args, **kwargs) 2549 # by comparing it to a sequential map+stack fallback. 2550 # 2551 # output_process_fn: a function that maps the outputs to the part 2552 # that should be differentiated. 2553 # batch_size: the batch dim size for the batched grad 2554 def _batched_grad_test( 2555 self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3 2556 ): 2557 if kwargs is None: 2558 kwargs = {} 2559 outputs = op(*args, **kwargs) 2560 outputs = differentiable(output_process_fn(outputs)) 2561 batched_vectors = tuple(construct_v(out, batch_size) for out in outputs) 2562 2563 def vector_jacobian_product(*vectors): 2564 return torch.autograd.grad( 2565 outputs, differentiable(args), vectors, retain_graph=True 2566 ) 2567 2568 self._vmap_test( 2569 vector_jacobian_product, batched_vectors, check_propagates_grad=False 2570 ) 2571 2572 # Tests batched second grad computation of outputs = op(*args, **kwargs). 2573 # by comparing it to a sequential map+stack fallback. 2574 # 2575 # output_process_fn: a function that maps the outputs to the part 2576 # that should be differentiated. 2577 # batch_size: the batch dim size for the batched grad 2578 # 2579 # NB: we only test computing batched gradients in the second gradient 2580 # computation. One specific use case that does this is computing the hessian 2581 # matrix of a scalar-valued function; this is useful in Bayesian Logistic 2582 # Regression. 2583 # It might be useful to have a test that computes batched first gradients and 2584 # then uses those to compute batched second gradients in the future. 2585 def _batched_grad_grad_test( 2586 self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3 2587 ): 2588 if kwargs is None: 2589 kwargs = {} 2590 outputs = op(*args, **kwargs) 2591 outputs = differentiable(output_process_fn(outputs)) 2592 ones = tuple(torch.ones_like(out) for out in outputs) 2593 # Same thing as summing together all of the outputs and calling .backward() 2594 first_grads = torch.autograd.grad( 2595 outputs, differentiable(args), ones, create_graph=True 2596 ) 2597 first_grads = differentiable(first_grads) 2598 self.assertNotEqual( 2599 len(first_grads), 0, "None of the first grads depend on the input!" 2600 ) 2601 2602 batched_vectors = tuple(construct_v(grad, batch_size) for grad in first_grads) 2603 2604 def vector_hessian_product(*vectors): 2605 outputs = torch.autograd.grad( 2606 first_grads, 2607 differentiable(args), 2608 vectors, 2609 retain_graph=True, 2610 allow_unused=True, 2611 ) 2612 outputs = tuple(out for out in outputs if out is not None) 2613 assert len(outputs) > 0 2614 return outputs 2615 2616 self._vmap_test( 2617 vector_hessian_product, batched_vectors, check_propagates_grad=False 2618 ) 2619 2620 def _test_arithmetic(self, op, device, test_grad_grad=True): 2621 x = torch.randn(2, 3, requires_grad=True, device=device) 2622 y = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) 2623 scalar = 3.14 2624 self._batched_grad_test(op, (x, y)) 2625 self._batched_grad_test(op, (scalar, y)) 2626 self._batched_grad_test(op, (x, scalar)) 2627 2628 if test_grad_grad: 2629 self._batched_grad_grad_test(op, (x, y)) 2630 2631 def test_add(self, device): 2632 self._test_arithmetic(torch.add, device, test_grad_grad=False) 2633 self._test_arithmetic(lambda x, y: x + y, device, test_grad_grad=False) 2634 2635 def test_sub(self, device): 2636 self._test_arithmetic(torch.sub, device, test_grad_grad=False) 2637 self._test_arithmetic(lambda x, y: x - y, device, test_grad_grad=False) 2638 2639 def test_mul(self, device): 2640 self._test_arithmetic(torch.mul, device) 2641 self._test_arithmetic(lambda x, y: x * y, device) 2642 2643 def test_div(self, device): 2644 self._test_arithmetic(torch.div, device) 2645 self._test_arithmetic(lambda x, y: x / y, device) 2646 2647 @allowVmapFallbackUsage 2648 def test_binary_cross_entropy(self, device): 2649 x = torch.sigmoid(torch.randn(3, 2, device=device, requires_grad=True)) 2650 target = torch.rand(3, 2, device=device) 2651 2652 op = functools.partial(F.binary_cross_entropy, target=target) 2653 2654 self._batched_grad_test(op, (x,), {}) 2655 self._batched_grad_grad_test(op, (x,), {}) 2656 2657 def test_expand(self, device): 2658 x = torch.randn(2, 3, device=device, requires_grad=True) 2659 2660 def op(x): 2661 return x.expand(5, 5, 2, 3) 2662 2663 self._batched_grad_test(op, (x,)) 2664 2665 @allowVmapFallbackUsage 2666 def test_index(self, device): 2667 x = torch.randn(2, 3, requires_grad=True, device=device) 2668 index = torch.tensor([[0, 0], [1, 1]], device=device) 2669 2670 def op(x): 2671 y = x * x 2672 return y[index] 2673 2674 self._batched_grad_test(op, (x,)) 2675 self._batched_grad_grad_test(op, (x,)) 2676 2677 def test_lgamma(self, device): 2678 x = torch.randn(2, 3, requires_grad=True, device=device) 2679 self._batched_grad_test(Tensor.lgamma, (x,)) 2680 self._batched_grad_grad_test(Tensor.lgamma, (x,)) 2681 2682 def test_log(self, device): 2683 x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) 2684 self._batched_grad_test(torch.log, (x,)) 2685 self._batched_grad_grad_test(torch.log, (x,)) 2686 2687 def test_logsumexp(self, device): 2688 x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) 2689 2690 def op(x): 2691 return torch.logsumexp(x, -1) 2692 2693 self._batched_grad_test(op, (x,)) 2694 self._batched_grad_grad_test(op, (x,)) 2695 2696 def test_log1p(self, device): 2697 x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) 2698 self._batched_grad_test(torch.log1p, (x,)) 2699 self._batched_grad_grad_test(torch.log1p, (x,)) 2700 2701 @allowVmapFallbackUsage 2702 def test_max(self, device): 2703 x = torch.randn(2, 3, requires_grad=True, device=device) 2704 self._batched_grad_test(torch.max, (x,)) 2705 2706 @allowVmapFallbackUsage 2707 def test_median(self, device): 2708 x = torch.randn(2, 3, requires_grad=True, device=device) 2709 self._batched_grad_test(torch.median, (x,)) 2710 2711 @allowVmapFallbackUsage 2712 def test_min(self, device): 2713 x = torch.randn(2, 3, requires_grad=True, device=device) 2714 self._batched_grad_test(torch.min, (x,)) 2715 2716 def test_permute(self, device): 2717 x = torch.randn(2, 3, 5, requires_grad=True, device=device) 2718 2719 def op(x): 2720 return x.permute(2, 0, 1) 2721 2722 self._batched_grad_test(op, (x,)) 2723 2724 def test_reshape(self, device): 2725 x = torch.randn(2, 3, 5, requires_grad=True, device=device) 2726 2727 def op(x): 2728 return x.reshape([2 * 3, 5]) 2729 2730 self._batched_grad_test(op, (x,)) 2731 2732 def test_sigmoid(self, device): 2733 x = torch.randn(2, 3, requires_grad=True, device=device) 2734 self._batched_grad_test(Tensor.sigmoid, (x,)) 2735 self._batched_grad_grad_test(Tensor.sigmoid, (x,)) 2736 2737 def test_stack(self, device): 2738 x = torch.randn(2, 3, device=device, requires_grad=True) 2739 y = torch.randn(2, 3, device=device, requires_grad=True) 2740 2741 def op(x, y): 2742 return torch.stack([x, y]) 2743 2744 self._batched_grad_test(op, (x, y)) 2745 2746 def test_select(self, device): 2747 x = torch.randn(2, 3, device=device, requires_grad=True) 2748 self._batched_grad_test(lambda x: x[1], (x,)) 2749 self._batched_grad_test(lambda x: x.select(1, 2), (x,)) 2750 self._batched_grad_test(lambda x: x.select(-1, 0), (x,)) 2751 2752 def test_slice(self, device): 2753 x = torch.randn(2, 3, 5, device=device, requires_grad=True) 2754 self._batched_grad_test(lambda x: x[0:1], (x,)) 2755 self._batched_grad_test(lambda x: x[:, 1:3], (x,)) 2756 self._batched_grad_test(lambda x: x[..., 1:3], (x,)) 2757 2758 def test_trace(self, device): 2759 x = torch.randn(2, 3, device=device, requires_grad=True) 2760 self._batched_grad_test(Tensor.trace, (x,)) 2761 2762 def test_threshold(self, device): 2763 x = torch.randn(2, 3, device=device, requires_grad=True) 2764 self._batched_grad_test(lambda x: F.threshold(x, 0.5, 0.0), (x,)) 2765 2766 @allowVmapFallbackUsage 2767 def test_inplace_on_view(self, device): 2768 leaf = torch.randn(4, 5, requires_grad=True) 2769 2770 def func(leaf): 2771 # Make sure the function is non-trivially twice differentiable 2772 base = leaf * leaf 2773 view = base[0] 2774 view.cos_() 2775 return view 2776 2777 self._batched_grad_test(func, (leaf,), {}) 2778 self._batched_grad_grad_test(func, (leaf,), {}) 2779 2780 @allowVmapFallbackUsage 2781 def test_inplace_manyview(self, device): 2782 leaf = torch.randn(4, 4, 5, requires_grad=True) 2783 2784 def func(leaf): 2785 # Make sure the function is non-trivially twice differentiable 2786 base = leaf * leaf 2787 view = base.transpose(0, 2) 2788 view = view[1] 2789 view = view.diagonal() 2790 view = view[::2] 2791 view.cos_() 2792 return view 2793 2794 self._batched_grad_test(func, (leaf,), {}) 2795 self._batched_grad_grad_test(func, (leaf,), {}) 2796 2797 def test_diagonal(self, device): 2798 x = torch.randn(4, 5, device=device, requires_grad=True) 2799 self._batched_grad_test(lambda x: x.diagonal(1, 0, 1), (x,)) 2800 2801 x = torch.randn(3, 4, 5, device=device, requires_grad=True) 2802 self._batched_grad_test(lambda x: x.diagonal(0, -1, -2), (x,)) 2803 2804 @allowVmapFallbackUsage 2805 def test_unrelated_output(self, device): 2806 B0 = 3 2807 x = torch.randn([], requires_grad=True) 2808 y = torch.randn([], requires_grad=True) 2809 gy = torch.randn(B0, requires_grad=True) 2810 2811 def vjp(v): 2812 (res,) = torch.autograd.grad(y, x, v, allow_unused=True) 2813 return torch.zeros_like(x) if res is None else res 2814 2815 result = vmap(vjp)(gy) 2816 self.assertEqual(result, torch.zeros(B0, *x.shape, device=device)) 2817 2818 @allowVmapFallbackUsage 2819 def test_unrelated_output_multiple_grad(self, device): 2820 B0 = 3 2821 x = torch.randn([], requires_grad=True) 2822 y = torch.randn([], requires_grad=True) 2823 gy = torch.randn(B0, requires_grad=True) 2824 2825 def vjp(v): 2826 (res,) = torch.autograd.grad(y, x, v, allow_unused=True) 2827 return torch.zeros_like(x) if res is None else res 2828 2829 _ = vjp(gy[0]) 2830 result = vmap(vjp)(gy) 2831 self.assertEqual(result, torch.zeros(B0, *x.shape, device=device)) 2832 2833 2834instantiate_device_type_tests(TestVmapBatchedGradientLegacy, globals(), None) 2835 2836if __name__ == "__main__": 2837 run_tests() 2838