1# Owner(s): ["module: functorch"]
2
3# Copyright (c) Facebook, Inc. and its affiliates.
4# All rights reserved.
5#
6# This source code is licensed under the BSD-style license found in the
7# LICENSE file in the root directory of this source tree.
8
9import contextlib
10import functools
11import itertools
12import os
13import random
14import types
15import unittest
16import warnings
17from collections import namedtuple
18from typing import OrderedDict
19from unittest.case import skipIf
20
21from common_utils import (
22    check_vmap_fallback,
23    compute_quantities_for_vmap_test,
24    decorate,
25    DisableVmapFallback,
26    generate_vmap_inputs,
27    get_fallback_and_vmap_exhaustive,
28    is_batch_norm_training,
29    is_valid_inplace_sample_input,
30    opsToleranceOverride,
31    skip,
32    skipOps,
33    tol1,
34    xfail,
35)
36from functorch_additional_op_db import additional_op_db
37
38import functorch
39import torch
40import torch.nn.functional as F
41from functorch import grad, grad_and_value, jacfwd, jvp, vjp, vmap
42from functorch.experimental import chunk_vmap
43from torch import Tensor
44from torch._C._functorch import reshape_dim_into, reshape_dim_outof
45from torch._functorch.make_functional import functional_init_with_buffers
46from torch._functorch.vmap import restore_vmap
47from torch.nn.attention import sdpa_kernel, SDPBackend
48from torch.testing._internal.autograd_function_db import autograd_function_db
49from torch.testing._internal.common_cuda import (
50    PLATFORM_SUPPORTS_CUDNN_ATTENTION,
51    PLATFORM_SUPPORTS_FLASH_ATTENTION,
52    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
53    with_tf32_off,
54)
55from torch.testing._internal.common_device_type import (
56    instantiate_device_type_tests,
57    onlyCUDA,
58    OpDTypes,
59    ops,
60    tol,
61    toleranceOverride,
62)
63from torch.testing._internal.common_methods_invocations import op_db
64from torch.testing._internal.common_utils import (
65    instantiate_parametrized_tests,
66    IS_WINDOWS,
67    markDynamoStrictTest,
68    parametrize,
69    run_tests,
70    skipIfTorchDynamo,
71    subtest,
72    TEST_WITH_TORCHDYNAMO,
73    TestCase,
74    unMarkDynamoStrictTest,
75    xfailIfTorchDynamo,
76)
77from torch.testing._internal.custom_op_db import custom_op_db
78from torch.utils import _pytree as pytree
79
80
81def get_platform_specific_sdpa():
82    ret = [SDPBackend.MATH]
83    if PLATFORM_SUPPORTS_FLASH_ATTENTION:
84        ret.append(SDPBackend.FLASH_ATTENTION)
85    if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
86        ret.append(SDPBackend.EFFICIENT_ATTENTION)
87    if PLATFORM_SUPPORTS_CUDNN_ATTENTION:
88        ret.append(SDPBackend.CUDNN_ATTENTION)
89    return ret
90
91
92PLATFORM_SPECIFIC_SDPA = get_platform_specific_sdpa()
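
# Illustrative note (an assumption about how the list above is used, not part of
# the original tests): SDPA tests typically parametrize over PLATFORM_SPECIFIC_SDPA
# and pin a backend with the sdpa_kernel context manager, e.g.
#
#     with sdpa_kernel([SDPBackend.MATH]):
#         out = F.scaled_dot_product_attention(q, k, v)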
93
94FALLBACK_REGEX = "There is a performance drop"
95
96
97class EnableVmapFallbackWarnings:
98    def __enter__(self):
99        self.prev_state = torch._C._debug_only_are_vmap_fallback_warnings_enabled()
100        torch._C._debug_only_display_vmap_fallback_warnings(True)
101
102    def __exit__(self, *ignored):
103        torch._C._debug_only_display_vmap_fallback_warnings(self.prev_state)
104
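# Usage sketch (illustrative only): the context manager above enables the
# debug-only fallback warnings, so code that hits the slow vmap fallback emits
# a warning matching FALLBACK_REGEX, e.g.
#
#     with warnings.catch_warnings(record=True) as wa:
#         with EnableVmapFallbackWarnings():
#             vmap(torch.var_mean)(torch.rand(3))  # var_mean may take the fallback
#         # any fallback warning recorded in `wa` matches FALLBACK_REGEX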
105
106@markDynamoStrictTest
107class TestVmapAPI(TestCase):
108    def test_non_tensor_output_raises(self):
109        with self.assertRaisesRegex(ValueError, "got type <class 'float'>"):
110            vmap(lambda x: 3.14)(torch.ones(3))
111
112        def multiple_outputs(x):
113            return x, 3
114
115        with self.assertRaisesRegex(ValueError, "got type <class 'int'>"):
116            vmap(multiple_outputs)(torch.ones(3))
117
118    def test_different_map_dim_size_raises(self):
119        x = torch.randn(2)
120        y = torch.randn(3)
121        expected_msg = (
122            "Expected all tensors to have the same size in the mapped dimension"
123        )
124        with self.assertRaisesRegex(ValueError, expected_msg):
125            vmap(torch.mul)(x, y)
126        with self.assertRaisesRegex(ValueError, expected_msg):
127            vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y))
128        with self.assertRaisesRegex(ValueError, expected_msg):
129            vmap(lambda z: z["x"] + z["y"], in_dims=({"x": 0, "y": 0},))(
130                {"x": x, "y": y}
131            )
132
133    def test_func_with_no_inputs(self):
134        expected_msg = "got no inputs"
135
136        def foo():
137            return torch.randn(3)
138
139        def bar(x):
140            return torch.randn(3)
141
142        with self.assertRaisesRegex(ValueError, expected_msg):
143            vmap(foo)()
144
145        with self.assertRaisesRegex(ValueError, expected_msg):
146            vmap(bar)()
147
148    def test_func_with_no_tensors(self):
149        def foo(x):
150            return torch.randn(3)
151
152        with self.assertRaisesRegex(ValueError, "at least one Tensor"):
153            vmap(foo, (None,))(1)
154
155    def test_constant_function(self):
156        output = vmap(lambda x: torch.tensor(3.14))(torch.ones(3))
157        self.assertEqual(output, torch.tensor([3.14, 3.14, 3.14]))
158
159    def test_single_input(self):
160        x = torch.randn(2, 3)
161
162        def square(x):
163            return x * x
164
165        output = vmap(square)(x)
166        self.assertEqual(output, x * x)
167
168    def test_multiple_inputs(self):
169        x = torch.randn(2, 3)
170        y = torch.randn(2, 3)
171        output = vmap(torch.mul)(x, y)
172        self.assertEqual(output, x * y)
173
174    def test_multiple_outputs(self):
175        def foo(x):
176            return x * x, x * x * x
177
178        x = torch.randn(3)
179        outputs = vmap(foo)(x)
180        self.assertEqual(outputs[0], x * x)
181        self.assertEqual(outputs[1], x * x * x)
182
183    def test_multiple_outputs2(self):
184        # This is the same thing as
185        # def returns_tuple_of_tensors(x):
186        #     return x, x
187        def returns_tuple_of_tensors(x):
188            return (x, x)
189
190        def returns_list_of_two_tensors(x):
191            return [x, x]
192
193        def returns_list_of_one_tensor(x):
194            return [x]
195
196        x = torch.randn(3)
197
198        # should not throw
199        vmap(returns_tuple_of_tensors)(x)
200        vmap(returns_list_of_two_tensors)(x)
201        vmap(returns_list_of_one_tensor)(x)
202
203    def test_nested_with_same_map_dim(self):
204        x = torch.randn(2, 3, 5)
205        y = torch.randn(2, 3, 5)
206        output = vmap(vmap(torch.mul))(x, y)
207        self.assertEqual(output, x * y)
208
209        output = vmap(vmap(vmap(torch.mul)))(x, y)
210        self.assertEqual(output, x * y)
211
212    def test_nested_with_diag_embed(self):
213        # diag_embed requires special testing because it is registered with conditional functionalization.
214        x = torch.randn(3, 3, 5)
215        output = vmap(vmap(torch.diag_embed))(x)
216        self.assertEqual(output, torch.diag_embed(x))
217
218    def test_nested_with_different_map_dim(self):
219        x = torch.randn(2, 3)
220        y = torch.randn(5, 3)
221        output = vmap(lambda x: vmap(lambda y: x * y)(y))(x)
222        self.assertEqual(output.shape, (2, 5, 3))
223        self.assertEqual(output, x.view(2, 1, 3) * y)
224
225        z = torch.randn(7, 3)
226        output = vmap(lambda x: vmap(lambda y: vmap(lambda z: x * y * z)(z))(y))(x)
227        self.assertEqual(output.shape, (2, 5, 7, 3))
228        self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z)
229
230    def test_noop_in_inner_vmap(self):
231        x = torch.randn(3)
232        y = torch.randn(5)
233        output = vmap(lambda x: vmap(lambda y: x)(y))(x)
234        self.assertEqual(output, x.view(3, 1).expand(3, 5))
235
236    def test_checkpoint(self):
237        A = torch.randn((3, 8, 8), dtype=torch.float64, requires_grad=True)
238
239        def get_grad(checkpoint):
240            A.grad = None
241
242            def get_loss(A):
243                ortho_A, _ = torch.func.vmap(torch.linalg.qr)(A)
244                return torch.sum(ortho_A)
245
246            if checkpoint:
247                loss = torch.utils.checkpoint.checkpoint(
248                    get_loss, A, use_reentrant=False
249                )
250            else:
251                loss = get_loss(A)
252            loss.backward()
253            return A.grad
254
255        expected = get_grad(checkpoint=False)
256        result = get_grad(checkpoint=True)
257        self.assertEqual(result, expected)
258
259    def test_unsupported_op_err_msg(self):
260        # Unsupported view op
261        tensor = torch.randn(2, 3)
262        msg = (
263            r"Batching rule not implemented for aten::.+; the "
264            r"fallback path doesn't work on out= or view ops"
265        )
266        # TODO: find a view op
267        # with self.assertRaisesRegex(RuntimeError, msg):
268        #     vmap(torch.ravel)(tensor)
269
270        def out_op(x, y):
271            return torch.abs(x, out=y)
272
273        with self.assertRaisesRegex(RuntimeError, msg):
274            vmap(out_op)(tensor, tensor)
275
276        # Non-Tensor returns are not supported. This is a limitation of vmap;
277        # functions that don't return Tensors must be special-cased.
278        with self.assertRaisesRegex(RuntimeError, "Batching rule not implemented"):
279            vmap(torch.equal)(tensor, tensor)
280
281    def test_nonzero_out_dims(self):
282        # Basic test
283        tensor = torch.randn(2, 3)
284        result = vmap(lambda x: x, out_dims=1)(tensor)
285        self.assertEqual(result, tensor.permute(1, 0))
286        self.assertEqual(result.data_ptr(), tensor.data_ptr())
287
288        # Test that the batch dimension gets permuted to dim 2
289        tensor = torch.randn(2, 3, 5, 7)
290        result = vmap(lambda x: x, out_dims=2)(tensor)
291        self.assertEqual(result, tensor.permute(1, 2, 0, 3))
292        self.assertEqual(result.data_ptr(), tensor.data_ptr())
293
294        # negative out_dim
295        tensor = torch.randn(2, 3, 5, 7)
296        result = vmap(lambda x: x, out_dims=-1)(tensor)
297        self.assertEqual(result, tensor.permute(1, 2, 3, 0))
298        self.assertEqual(result.data_ptr(), tensor.data_ptr())
299
300        # check that out_dims works on ALL outputs
301        tensor = torch.randn(2, 3, 5, 7)
302        other = torch.randn(2, 3, 5, 7)
303        result = vmap(lambda x, y: (x, y), out_dims=2)(tensor, other)
304        self.assertEqual(
305            result, (tensor.permute(1, 2, 0, 3), other.permute(1, 2, 0, 3))
306        )
307
308        # use out_dims with the maximum vmap-able tensor dims (64 dims)
309        ndims = 64
310        shape = [2] + [1] * (ndims - 1)
311        expected_shape = [1, 1, 2] + [1] * (ndims - 3)
312        tensor = torch.randn(shape)
313        result = vmap(lambda x: x, out_dims=2)(tensor)
314        self.assertEqual(result.shape, expected_shape)
315
316        # test something that is not the identity function
317        def foo(x, y):
318            return x, x * y, x * y * y
319
320        x = torch.randn(2, 3, 5)
321        y = torch.randn(2, 3, 5)
322        result = vmap(foo, out_dims=1)(x, y)
323        self.assertEqual(
324            result,
325            (
326                x.permute(1, 0, 2),
327                (x * y).permute(1, 0, 2),
328                (x * y * y).permute(1, 0, 2),
329            ),
330        )
331
332    def test_multiple_out_dims(self):
333        def foo(x):
334            return x, x
335
336        def bar(x, y):
337            return x, x, x, x * y
338
339        x = torch.randn(2, 3, 5)
340        y = torch.randn(2, 3, 5)
341        result = vmap(foo, out_dims=(0, 1))(x)
342        self.assertEqual(result, (x, x.permute(1, 0, 2)))
343
344        result = vmap(bar, out_dims=(-1, 0, 1, 2))(x, y)
345        expected = (
346            x.permute(1, 2, 0),
347            x,
348            x.permute(1, 0, 2),
349            (x * y).permute(1, 2, 0),
350        )
351        self.assertEqual(result, expected)
352
353    def test_nested_out_dims(self):
354        y = torch.randn(2, 3, 5, 7)
355
356        # Inner vmap has non-zero out_dim
357        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y)
358        self.assertEqual(result.shape, (2, 5, 3, 7))
359        self.assertEqual(result, y.permute(0, 2, 1, 3))
360
361        # all vmaps have non-zero out_dim
362        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y), out_dims=1)(y)
363        self.assertEqual(result.shape, (5, 2, 3, 7))
364        self.assertEqual(result, y.permute(2, 0, 1, 3))
365
366        # throwing in some negative out_dims
367        result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y), out_dims=-1)(y)
368        self.assertEqual(result.shape, (5, 7, 3, 2))
369        self.assertEqual(result, y.permute(2, 3, 1, 0))
370
371        # testing fn that isn't the identity
372        x = torch.randn(2, 3)
373        y = torch.randn(5, 3)
374        result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x), out_dims=-1)(y)
375        self.assertEqual(result.shape, (3, 2, 5))
376        self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0))
377
378    def test_out_dims_edge_case(self):
379        def foo(x):
380            return x
381
382        # Test that we accept out_dims=(1,) for a function with one output.
383        tensor = torch.randn(2, 3)
384        expected = vmap(foo, out_dims=1)(tensor)
385        result = vmap(foo, out_dims=(1,))(tensor)
386        self.assertEqual(result, expected)
387
388    def test_out_dims_none_tuple(self):
389        def foo(x):
390            return x, "hello world"
391
392        tensor = torch.randn(2, 3)
393        result = vmap(foo, out_dims=(0, None))(tensor)
394        self.assertEqual(result[1], "hello world")
395        self.assertEqual(result[0], tensor)
396
397        def foo(x):
398            x.add_(1)
399            return None, "hello world"
400
401        result = vmap(foo, out_dims=(None, None))(tensor)
402        self.assertEqual(result, (None, "hello world"))
403
404    def test_out_dims_none(self):
405        def foo(x):
406            return x
407
408        tensor = torch.randn(2, 3)
409        with self.assertRaisesRegex(
410            ValueError, "can not return a BatchedTensor when out_dim is None"
411        ):
412            vmap(foo, out_dims=None)(tensor)
413
414        def foo(x):
415            x.add_(1)
416            return "hello world"
417
418        result = vmap(foo, out_dims=None)(tensor)
419        self.assertEqual(result, "hello world")
420
421    def test_out_dims_normal_tensor(self):
422        def foo(x):
423            return torch.arange(3)
424
425        tensor = torch.randn(2, 3)
426        result = vmap(foo)(tensor)
427        self.assertEqual(result.shape, [2, 3])
428
429        result = vmap(foo, out_dims=None)(tensor)
430        self.assertEqual(result, torch.arange(3))
431
432    def test_pytree_returns(self):
433        x = torch.randn(2, 3)
434
435        def f(x):
436            y = x.sin()
437            return y, (y, y), [y, (y, y)]
438
439        y0, (y1, y2), (y3, (y4, y5)) = vmap(f)(x)
440        self.assertEqual(y0, x.sin())
441        self.assertEqual(y0, y1)
442        self.assertEqual(y2, y1)
443        self.assertEqual(y2, y3)
444        self.assertEqual(y4, y3)
445        self.assertEqual(y5, y4)
446
447    def test_pytree_odict_returns(self):
448        x = torch.randn(2, 3)
449
450        def f(t):
451            y = t.sin()
452            return OrderedDict([("sin", y), ("cos", t.cos())])
453
454        out = vmap(f)(x)
455        assert isinstance(out, OrderedDict)
456        expected = f(x)
457        self.assertEqual(out["sin"], expected["sin"])
458        self.assertEqual(out["cos"], expected["cos"])
459
460    def test_pytree_returns_outdims(self):
461        x = torch.randn(2, 3)
462
463        def f(x):
464            y = x.sin()
465            return y, (y, y)
466
467        y0, (y1, y2) = vmap(f, out_dims=(0, (0, 1)))(x)
468        self.assertEqual(y0, x.sin())
469        self.assertEqual(y1, x.sin())
470        self.assertEqual(y2, x.sin().t())
471
472    def test_pytree_returns_broadcast_simple(self):
473        x = torch.randn(2, 3)
474
475        def f(x):
476            y = x.sin()
477            return y, (y, y)
478
479        y0, (y1, y2) = vmap(f, out_dims=1)(x)
480        self.assertEqual(y0, x.sin().t())
481        self.assertEqual(y1, y0)
482        self.assertEqual(y2, y0)
483
484    def test_pytree_returns_broadcast_nested(self):
485        x = torch.randn(2, 3)
486
487        def f(x):
488            y = x.sin()
489            return y, (y, y)
490
491        y0, (y1, y2) = vmap(f, out_dims=(0, 1))(x)
492        self.assertEqual(y0, x.sin())
493        self.assertEqual(y1, y0.t())
494        self.assertEqual(y2, y0.t())
495
496    def test_out_dims_must_be_int_or_collection_of_int_err_msg(self):
497        msg = "must be an int, None or a python collection of ints"
498        tensor = torch.randn(2, 3)
499        with self.assertRaisesRegex(ValueError, msg):
500            vmap(lambda x: x, out_dims="lol")(tensor)
501        with self.assertRaisesRegex(ValueError, msg):
502            vmap(lambda x: x, out_dims=("lol",))(tensor)
503
504    def test_out_dims_and_num_outputs_mismatch_err_msg(self):
505        msg = "not compatible"
506        x = torch.randn(2, 3, 5)
507
508        # Too many out_dims
509        with self.assertRaisesRegex(ValueError, msg):
510            vmap(lambda x: x, out_dims=(0, 0))(x)
511        with self.assertRaisesRegex(ValueError, msg):
512            vmap(lambda x: (x, x, x), out_dims=(0, 0, 0, 0))(x)
513
514        # Too few out_dims
515        with self.assertRaisesRegex(ValueError, msg):
516            vmap(lambda x: (x, x), out_dims=(0,))(x)
517        with self.assertRaisesRegex(ValueError, msg):
518            vmap(lambda x: (x, x, x), out_dims=(0, 0))(x)
519
520    def test_out_dim_out_of_bounds_err_msg(self):
521        # TODO(rzou): This error message isn't that great. It comes straight
522        # from maybe_wrap_dim. Consider catching the error in C++ and re-raising
523        # it with more context in the future.
524        msg = "Dimension out of range"
525        x = torch.randn(2, 3, 5)
526        with self.assertRaisesRegex(IndexError, msg):
527            vmap(lambda x: x, out_dims=3)(x)
528        with self.assertRaisesRegex(IndexError, msg):
529            vmap(lambda x: x, out_dims=-4)(x)
530
531    def test_non_zero_in_dims(self):
532        tensor = torch.randn(2, 3, 5)
533
534        # Implicit out_dims = 0; vmap will move the batch dim to the front.
535        output = vmap(lambda x: x, (1,))(tensor)
536        self.assertEqual(output, tensor.permute(1, 0, 2))
537        self.assertEqual(output.data_ptr(), tensor.data_ptr())
538
539        x = torch.randn(2, 3)
540        y = torch.randn(3, 2)
541        output = vmap(torch.mul, (0, 1))(x, y)
542        self.assertEqual(output, x * y.t())
543        output = vmap(torch.mul, (1, 0))(x, y)
544        self.assertEqual(output, x.t() * y)
545
546    def test_none_in_dims(self):
547        x = torch.randn(2, 3)
548        y = torch.randn(2, 3)
549
550        # None in_dim for a Tensor means we don't map over it
551        output = vmap(torch.mul, (0, None))(x, y)
552        self.assertEqual(output.shape, (2, 2, 3))
553        self.assertEqual(output, x.view(2, 1, 3) * y)
554
555        # None in_dim for non-tensor arguments
556        output = vmap(torch.mul, (0, None))(x, 2)
557        self.assertEqual(output, x * 2)
558
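    # Note (illustrative sketch, not one of the original tests): in_dims=None is
    # the usual way to share an unbatched argument across the batch, e.g.
    #
    #     weight = torch.randn(3, 3)
    #     xs = torch.randn(5, 3)
    #     vmap(torch.mv, in_dims=(None, 0))(weight, xs)  # same as xs @ weight.t()
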
559    def test_nested_non_default_in_dims(self):
560        x = torch.rand(5, 2, 3)
561        y = torch.rand(3, 5, 2)
562        result = vmap(vmap(vmap(torch.mul), (1, 0)), (1, 2))(x, y)
563        self.assertEqual(result, x.permute(1, 2, 0) * y.permute(2, 0, 1))
564
565    def test_nested_negative_in_dims(self):
566        x = torch.randn(2, 3)
567        y = torch.randn(2, 3)
568        output = vmap(torch.mul, (-1, -1))(x, y)
569        self.assertEqual(output.shape, (3, 2))
570        self.assertEqual(output, (x * y).permute(1, 0))
571
572    def test_non_default_in_dims_out_dims(self):
573        x = torch.randn(2, 3, 5)
574
575        # Same in_dim as out_dim, vmap over identity
576        result = vmap(lambda x: x, in_dims=1, out_dims=1)(x)
577        self.assertEqual(result, x)
578        self.assertEqual(result.data_ptr(), x.data_ptr())
579
580        # Different in_dim from out_dim, vmap over identity
581        result = vmap(lambda x: x, in_dims=2, out_dims=1)(x)
582        self.assertEqual(result.shape, (2, 5, 3))
583        self.assertEqual(result, x.transpose(1, 2))
584        self.assertEqual(result.data_ptr(), x.data_ptr())
585
586        def foo(x):
587            return x * 2
588
589        # Same in_dim as out_dim, vmap over operation
590        result = vmap(foo, in_dims=1, out_dims=1)(x)
591        self.assertEqual(result, x * 2)
592
593        # Different in_dim from out_dim, vmap over operation
594        result = vmap(foo, in_dims=2, out_dims=1)(x)
595        self.assertEqual(result.shape, (2, 5, 3))
596        self.assertEqual(result, (x * 2).transpose(1, 2))
597
598        # Basic nested test.
599        result = vmap(vmap(foo, 1, 1), 1, 1)(x)
600        self.assertEqual(result, x * 2)
601
602    def test_item_throws(self):
603        def f(x):
604            return x.item()
605
606        with self.assertRaisesRegex(RuntimeError, r"item\(\) on a Tensor"):
607            vmap(f)(torch.randn(3))
608
609    def test_data_dependent_control_flow_throws(self):
610        def f(x):
611            if x:
612                return x
613            return 0
614
615        with self.assertRaisesRegex(RuntimeError, r"data-dependent control flow"):
616            vmap(f)(torch.randn(3))
617
618    def test_accepts_nested_inputs(self):
619        x = torch.randn(2, 3)
620        y = torch.randn(2, 3)
621
622        # Single layer of nesting
623        out = vmap(lambda z: z[0] + z[1])((x, y))
624        self.assertEqual(out, x + y)
625        out = vmap(lambda z: z[0] + z[1], in_dims=(0,))((x, y))
626        self.assertEqual(out, x + y)
627        out = vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y))
628        self.assertEqual(out, x + y)
629
630        out = vmap(lambda z: z[0] + z[1])([x, y])
631        self.assertEqual(out, x + y)
632        out = vmap(lambda z: z[0] + z[1], in_dims=(0,))([x, y])
633        self.assertEqual(out, x + y)
634        out = vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, y])
635        self.assertEqual(out, x + y)
636
637        out = vmap(lambda z: z["x"] + z["y"])({"x": x, "y": y})
638        self.assertEqual(out, x + y)
639        out = vmap(lambda z: z["x"] + z["y"], in_dims=(0,))({"x": x, "y": y})
640        self.assertEqual(out, x + y)
641        out = vmap(lambda z: z["x"] + z["y"], in_dims=({"x": 0, "y": 0},))(
642            {"x": x, "y": y}
643        )
644        self.assertEqual(out, x + y)
645
646        # Multiple layers of nesting
647        out_fn = vmap(lambda z: z["x"][0] + z["x"][1][0] + z["y"][0] + z["y"][1])
648        out = out_fn({"x": [x, (x,)], "y": [y, y]})
649        self.assertEqual(out, x + x + y + y)
650
651    def test_in_dims_wrong_type_err_msg(self):
652        x = torch.randn(3)
653        y = torch.randn(3)
654        msg = r"expected `in_dims` to be int or a \(potentially nested\) tuple"
655        with self.assertRaisesRegex(ValueError, msg):
656            vmap(torch.mul, [0, 0])(x, y)
657        with self.assertRaisesRegex(ValueError, msg):
658            vmap(torch.mul, set({0}))(x, y)
659        with self.assertRaisesRegex(ValueError, msg):
660            vmap(torch.mul, "lol")(x, y)
661        with self.assertRaisesRegex(ValueError, msg):
662            vmap(lambda z: z[0] + z[1], in_dims=[0, 0])([x, y])
663        # The following should not throw
664        vmap(torch.mul, (0, 0))(x, y)
665
666    def test_not_enough_in_dims_err_msg(self):
667        x = torch.randn(3)
668        y = torch.randn(3)
669        msg = r"in_dims is not compatible with the structure of `inputs`"
670
671        with self.assertRaisesRegex(ValueError, msg):
672            vmap(torch.mul, (0,))(x, y)
673        with self.assertRaisesRegex(ValueError, msg):
674            vmap(torch.mul, (0, 0, 0))(x, y)
675        with self.assertRaisesRegex(ValueError, msg):
676            vmap(lambda z: z[0] + z[1], in_dims=([0],))([x, y])
677        with self.assertRaisesRegex(ValueError, msg):
678            vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))([x, y])
679        # The following should not throw
680        vmap(torch.mul, (0, 0))(x, y)
681
682    def test_integer_in_dim_but_not_tensor_input_err_msg(self):
683        def foo(xy):
684            return xy[0] * xy[1]
685
686        def bar(x, yz):
687            return x * yz[0] * yz[1]
688
689        x = torch.randn(2, 3)
690
691        # the following are errors in jax (and will always be errors)
692        msg = "Got in_dim=0 for an input but the input is of type"
693        with self.assertRaisesRegex(ValueError, msg):
694            vmap(torch.sum)(x, 0)
695        with self.assertRaisesRegex(ValueError, msg):
696            vmap(torch.sum, (0, 0))(x, 0)
697        with self.assertRaisesRegex(ValueError, msg):
698            vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, 1])
699        # The following should not throw
700        vmap(torch.sum, (0, None))(x, 0)
701
702    def test_in_dim_not_in_tensor_err_msg(self):
703        def foo(x):
704            return x * x
705
706        x = torch.randn(2, 3)
707        y = torch.randn(2, 3)
708
709        msg = r"Got in_dim=-?\w for some input, but that input is a Tensor of dimensionality \w"
710        with self.assertRaisesRegex(ValueError, msg):
711            vmap(foo)(torch.randn([]))
712        with self.assertRaisesRegex(ValueError, msg):
713            vmap(foo, in_dims=(0,))(torch.randn([]))
714        with self.assertRaisesRegex(ValueError, msg):
715            vmap(foo, in_dims=(-3,))(x)
716        with self.assertRaisesRegex(ValueError, msg):
717            vmap(foo, in_dims=(2,))(y)
718        with self.assertRaisesRegex(ValueError, msg):
719            vmap(lambda z: z[0] + z[1], in_dims=([3, 0],))([x, y])
720        # the following should not throw
721        vmap(foo, in_dims=(0,))(torch.randn(2, 3))
722        vmap(foo, in_dims=(1,))(torch.randn(2, 3))
723
724    def test_fallback_does_not_warn_by_default(self):
725        op = torch._test_functorch_fallback
726        x = torch.randn(11)
727        y = torch.randn(11)
728        with warnings.catch_warnings(record=True) as wa:
729            torch.vmap(op)(x, y)
730            # The single warning here is the "vmap is experimental"
731            # warning, not a warning from the vmap fallback path.
732            self.assertEqual(len(wa), 1)
733
734    @unittest.expectedFailure
735    def test_fallback_warns_when_warnings_are_enabled(self):
736        # NB: torch._test_functorch_fallback is a dummy op (see
737        # prim_native_functions.cpp) with no batching rule, so it should exercise
738        # the fallback path. If it ever gains a batching rule, switch this test
739        op = torch._test_functorch_fallback
740        x = torch.randn(11)
741        y = torch.randn(11)
742        with warnings.catch_warnings(record=True) as wa:
743            with EnableVmapFallbackWarnings():
744                torch.vmap(op)(x, y)
745            self.assertEqual(len(wa), 2)
746            self.assertRegex(str(wa[-1].message), FALLBACK_REGEX)
747
748    def _assert_uses_vmap_fallback(self, vmap_args, inputs):
749        return
750        # with warnings.catch_warnings(record=True) as wa:
751        #     with EnableVmapFallbackWarnings():
752        #         result = vmap(*vmap_args)(*inputs)
753        #     self.assertEqual(len(wa), 2)
754        #     self.assertRegex(str(wa[-1].message), FALLBACK_REGEX)
755
756    def test_fallback_zero_dim(self):
757        op = torch._test_functorch_fallback
758        x = torch.randn(11)
759        y = torch.randn(11)
760        self._assert_uses_vmap_fallback((op,), (x, y))
761
762        B0, B1 = 0, 3
763        x = torch.randn(B0, 11)
764        y = torch.randn(11)
765
766        msg = "The fallback path does not support vmap over dims of size 0"
767
768        with self.assertRaisesRegex(RuntimeError, msg):
769            vmap(op, (0, None))(x, y)
770        with self.assertRaisesRegex(RuntimeError, msg):
771            vmap(op, (None, 0))(y, x)
772        with self.assertRaisesRegex(RuntimeError, msg):
773            vmap(op)(x, x)
774
775        x = torch.randn(B0, B1, 11)
776        y = torch.randn(B1, 11)
777        with self.assertRaisesRegex(RuntimeError, msg):
778            vmap(op, (0, None))(x, y)
779        with self.assertRaisesRegex(RuntimeError, msg):
780            vmap(op, (None, 0))(y, x)
781        with self.assertRaisesRegex(RuntimeError, msg):
782            vmap(op)(x, x)
783
784    def test_fallback_warning(self):
785        # We use a dummy function _test_functorch_fallback
786        # defined in prim_native_functions.cpp for this
787        op = torch._test_functorch_fallback
788
789        x = torch.randn(5, 7, 11)
790        y = torch.randn(5, 7, 11)
791
792        self._assert_uses_vmap_fallback((op,), (x, y))
793
794        x = torch.randn(7, 11, 5)
795        y = torch.randn(5, 7, 11)
796        result = vmap(op, (2, 0))(x, y)
797        self.assertEqual(result, op(x.permute(2, 0, 1), y))
798
799        # nested vmap
800        x = torch.randn(7, 11, 5)
801        y = torch.randn(5, 7, 11)
802        result = vmap(vmap(op), (2, 0))(x, y)
803        self.assertEqual(result, op(x.permute(2, 0, 1), y))
804
805        # big batch size (total 10000)
806        x = torch.randn(100, 10, 10, 5)
807        y = torch.randn(100, 10, 10)
808        result = vmap(vmap(vmap(op)))(x, y)
809        self.assertEqual(result, op(x, y.view(100, 10, 10, 1)))
810
811    # TODO: No clue what is wrong here.
812    @unittest.skip
813    def test_fallback_masked_fill(self):
814        # NB: Despite the test name, this exercises the fallback path via index_add.
815        # One day we will implement a batching rule for index_add; if/when we do,
816        # this test should be switched to another operator to avoid bitrot.
817        def run_test(batch_size):
818            B0 = batch_size
819            x = torch.randn(B0, 7, 11, 13)
820            dim = 0
821            index = torch.tensor([0, 4, 2])
822            values = torch.randn(B0, 3, 13)
823
824            self._assert_uses_vmap_fallback(
825                (torch.index_add, (0, None, None, 0)), (x, dim, index, values)
826            )
827
828            result = vmap(torch.index_add, (0, None, None, 0))(x, dim, index, values)
829            expected = torch.index_add(x, dim + 1, index, values.view(B0, 3, 1, 13))
830            self.assertEqual(result, expected)
831
832        run_test(batch_size=5)
833        run_test(batch_size=1237)
834
835    def test_fallback_multiple_returns(self):
836        # NB: One day we will implement a batching rule for torch.var_mean
837        # If/when we do, this test should be replaced to test the fallback
838        # path on another operator to avoid bitrot.
839        B0, B1, B2 = 2, 3, 1237
840        tensor = torch.randn(B0, 10)
841
842        self._assert_uses_vmap_fallback((torch.var_mean,), (tensor,))
843
844        # fallback correctness on torch.var_mean
845        result = vmap(torch.var_mean)(tensor)
846        expected = torch.var_mean(tensor, dim=1)
847        self.assertEqual(result, expected)
848
849        # nested vmap
850        tensor = torch.randn(B0, B1, 10)
851        result = vmap(vmap(torch.var_mean))(tensor)
852        expected = torch.var_mean(tensor, dim=2)
853        self.assertEqual(result, expected)
854
855        # big batch size, nested vmap
856        tensor = torch.randn(B0, B1, B2, 10)
857        result = vmap(vmap(vmap(torch.var_mean)))(tensor)
858        expected = torch.var_mean(tensor, dim=3)
859        self.assertEqual(result, expected)
860
861    def test_inplace_fallback_unary(self):
862        # Test the in-place fallback on an in-place method that takes no
863        # additional Tensor arguments. This is the simplest case of the fallback.
864        # NB: One day we will implement a batching rule for acos_.
865        # If/when we do, this test should be replaced to test the fallback
866        # path on another operator to avoid bitrot.
867        op = Tensor.acos_
868        B0, B1, B2 = 2, 3, 10000
869
870        x = torch.randn(B0, 5)
871        self._assert_uses_vmap_fallback((op,), (x,))
872
873        # Single vmap
874        x_orig = torch.rand(B0, 5)
875        x = x_orig.clone()
876        result = vmap(op)(x)
877        self.assertTrue(result is x)
878        self.assertEqual(result, x_orig.acos())
879
880        # Single vmap + different out_dim produces a view(!)
881        x_orig = torch.rand(B0, 5)
882        x = x_orig.clone()
883        result = vmap(op, out_dims=(1,))(x)
884        self.assertTrue(result._base is x)
885        self.assertEqual(result, x_orig.t().acos())
886
887        # Nested vmap
888        x_orig = torch.randn(B0, B1, 5)
889        x = x_orig.clone()
890        result = vmap(vmap(op))(x)
891        self.assertTrue(result is x)
892        self.assertEqual(result, x_orig.acos())
893
894        # Nested vmap, large batch size
895        x_orig = torch.randn(B0, B1, B2, 5)
896        x = x_orig.clone()
897        result = vmap(vmap(vmap(op)))(x)
898        self.assertTrue(result is x)
899        self.assertEqual(result, x_orig.acos())
900
901    def test_inplace_fallback_nary_same_levels(self):
902        # NB: One day we will implement a batching rule for atan2_
903        # If/when we do, this test should be replaced to test the fallback
904        # path on another operator to avoid bitrot.
905        op = Tensor.atan2_
906        outplace_op = torch.atan2
907
908        x = torch.randn(5, 7, 11)
909        y = torch.randn(5, 7, 11)
910        self._assert_uses_vmap_fallback((op,), (x, y))
911
912        # Single vmap
913        B0 = 5
914        x_orig = torch.randn(7, 11, B0)
915        x = x_orig.clone()
916        y = torch.randn(B0, 7, 11)
917        vmap(op, (2, 0))(x, y)
918        self.assertEqual(x, outplace_op(x_orig, y.movedim(0, 2)))
919
920        # Nested vmap
921        B0, B1 = 5, 7
922        x_orig = torch.randn(B1, 11, B0)
923        x = x_orig.clone()
924        y = torch.randn(B0, B1, 11)
925        vmap(vmap(op), (2, 0))(x, y)
926        self.assertEqual(x, outplace_op(x_orig, y.movedim([0, 1], [2, 0])))
927
928        # big batch size (total 10000)
929        B0, B1, B2 = 100, 10, 10
930        x_orig = torch.randn(B0, B1, B2, 5)
931        x = x_orig.clone()
932        y = torch.randn(B0, B1, B2)
933        vmap(vmap(vmap(op)))(x, y)
934        self.assertEqual(x, outplace_op(x_orig, y.view(B0, B1, B2, 1)))
935
936    # Fallback isInplaceVmapCompatible check is broken
937    @unittest.expectedFailure
938    def test_inplace_fallback_nary_different_levels(self):
939        # NB: One day we will implement a batching rule for atan2_
940        # If/when we do, this test should be replaced to test the fallback
941        # path on another operator to avoid bitrot.
942        op = Tensor.atan2_
943        outplace_op = torch.atan2
944        B0, B1 = 2, 3
945
946        x = torch.rand(B0, 7)
947        y = torch.rand(7)
948        self._assert_uses_vmap_fallback((op, (0, None)), (x, y))
949
950        # op(left, right): All of the levels in right are found in left
951        x_orig = torch.rand(B0, 7)
952        x = x_orig.clone()
953        y = torch.rand(7)
954        vmap(op, in_dims=(0, None))(x, y)
955        self.assertEqual(x, outplace_op(x_orig, y))
956
957        x_orig = torch.rand(B0, B1, 7)
958        x = x_orig.clone()
959        y = torch.rand(B0, 7)
960        vmap(vmap(op, in_dims=(0, None)))(x, y)
961        self.assertEqual(x, outplace_op(x_orig, y.view(B0, 1, 7)))
962
963        # op(left, right): Some of the levels in right are not found in left
964        msg = r"vmap: aten::atan2_\(self, \*extra_args\) is not possible"
965        x = torch.rand(7)
966        y = torch.rand(B0, 7)
967        with self.assertRaisesRegex(RuntimeError, msg):
968            vmap(op, in_dims=(None, 0))(x, y)
969
970        x = torch.rand(B1, 7)
971        y = torch.rand(B0, 7)
972        with self.assertRaisesRegex(RuntimeError, msg):
973            vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 0))(x, y)
974
975        x = torch.rand(B1, 7)
976        y = torch.rand(7, B0)
977        with self.assertRaisesRegex(RuntimeError, msg):
978            vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 1))(x, y)
979
980        x = torch.rand(B0, 7)
981        y = torch.rand(B0, B1, 7)
982        with self.assertRaisesRegex(RuntimeError, msg):
983            vmap(vmap(op, in_dims=(None, 0)))(x, y)
984
985    def test_backward_unsupported_interaction(self):
986        x = torch.randn(3, requires_grad=True)
987        y = torch.randn(5)
988        grad = torch.randn_like(x)
989        err_msg = r"backward\(\) called inside a functorch transform"
990
991        def backward_on_vmapped_tensor(x):
992            x.sum().backward()
993
994        # FIXME
995        return self.skipTest(
996            "error: element 0 of tensors does not require grad and does not have a grad_fn"
997        )
998        with self.assertRaisesRegex(RuntimeError, err_msg):
999            vmap(backward_on_vmapped_tensor)(x)
1000
1001        def backward_with_vmapped_grad(x, grad):
1002            x.backward(grad)
1003
1004        with self.assertRaisesRegex(RuntimeError, err_msg):
1005            vmap(backward_with_vmapped_grad)(x, grad)
1006
1007        def completely_unrelated_backward(y):
1008            x.sum().backward()
1009            return y
1010
1011        with self.assertRaisesRegex(RuntimeError, err_msg):
1012            vmap(completely_unrelated_backward)(y)
1013
1014    @unittest.expectedFailure
1015    def test_grad_unsupported_interaction(self):
1016        input_tensor = torch.randn(3, requires_grad=True)
1017        err_msg = "autograd.grad.* called inside torch.vmap"
1018
1019        captured = torch.randn(3, requires_grad=True)
1020
1021        def output_to_grad_is_vmapped(input_tensor):
1022            output = (captured * input_tensor).sum()
1023            return torch.autograd.grad([output], [captured])[0]
1024
1025        with self.assertRaisesRegex(RuntimeError, err_msg):
1026            vmap(output_to_grad_is_vmapped)(input_tensor)
1027
1028        output = (input_tensor**2).sum()
1029
1030        def input_to_grad_is_vmapped(input_tensor):
1031            return torch.autograd.grad([output], [input_tensor])[0]
1032
1033        with self.assertRaisesRegex(RuntimeError, err_msg):
1034            vmap(input_to_grad_is_vmapped)(input_tensor)
1035
1036    def test_batched_gradient_basic(self):
1037        N = 3
1038        x = torch.randn(N, requires_grad=True)
1039        y = torch.randn(N)
1040
1041        def vjp_mul(v):
1042            return torch.autograd.grad([x * y], [x], grad_outputs=[v])[0]
1043
1044        batched_v = torch.eye(N)
1045        jacobian = vmap(vjp_mul)(batched_v)
1046        self.assertEqual(jacobian, torch.diagflat(y))
1047
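    # Note (sketch, not an additional test): the test above is the standard
    # "vmap over a VJP" recipe for building a Jacobian row by row; torch.func
    # packages the same pattern as jacrev, e.g.
    # torch.func.jacrev(lambda x: x * y)(x) also equals torch.diagflat(y) here.
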
1048    def test_functools_partial(self):
1049        x = torch.randn(3)
1050        y = torch.randn(2, 3)
1051        result = vmap(functools.partial(torch.mul, x))(y)
1052        self.assertEqual(result, x * y)
1053
1054    def test_nn_module(self):
1055        tensor = torch.randn(2, 3)
1056        model = torch.nn.Linear(3, 3, bias=False)
1057        result = vmap(model)(tensor)
1058        self.assertEqual(result, model(tensor))
1059
1060    def test_fallback_with_undefined_grad(self):
1061        B0 = 7
1062        x = torch.randn(2, 3, 4, 5, requires_grad=True)
1063        weight = torch.randn(3, 3, 1, 1)
1064        v = torch.randn(B0, 2, 3, 4, 5)
1065
1066        def get_vjp(v):
1067            result = torch.nn.functional.conv2d(x, weight)
1068            (grad_x,) = torch.autograd.grad(result, x, v)
1069            return grad_x
1070
1071        # Runs vmap(get_vjp)(v), which should not error out.
1072        # The backward formula for convolution returns an undefined
1073        # Tensor for grad_bias because the original bias does not exist.
1074        #
1075        # In the future we'll probably add a batching rule for convolution
1076        # backward. When this happens, we should modify this test to use a
1077        # different op (and/or create and use a dummy operator) to avoid bitrot.
1078        self._assert_uses_vmap_fallback([get_vjp], [v])
1079
1080    def test_reshape_dim_into(self):
1081        x = torch.randn(2, 3, 5, 7)
1082
1083        y = reshape_dim_into(0, 0, x)
1084        self.assertEqual(y, x.reshape(6, 5, 7))
1085
1086        y = reshape_dim_into(0, 1, x)
1087        self.assertEqual(y, x.movedim(0, 1).reshape(3, 2 * 5, 7))
1088
1089        y = reshape_dim_into(0, 2, x)
1090        self.assertEqual(y, x.movedim(0, 2).reshape(3, 5, 2 * 7))
1091
1092        y = reshape_dim_into(1, 2, x)
1093        self.assertEqual(y, x.movedim(1, 2).reshape(2, 5, 3 * 7))
1094
1095        y = reshape_dim_into(0, -2, x)
1096        self.assertEqual(y, x.movedim(0, 1).reshape(3, 2 * 5, 7))
1097
1098        y = reshape_dim_into(0, -1, x)
1099        self.assertEqual(y, x.movedim(0, 2).reshape(3, 5, 2 * 7))
1100
1101        y = reshape_dim_into(-4, -1, x)
1102        self.assertEqual(y, x.movedim(0, 2).reshape(3, 5, 2 * 7))
1103
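    # Sketch of the semantics exercised above (illustrative, not an extra test):
    # reshape_dim_into(src, dst, x) moves dim `src` to position `dst` and merges
    # it with the dimension that follows it there, e.g.
    #
    #     x = torch.randn(2, 3, 5)
    #     reshape_dim_into(0, 1, x).shape  # (3, 10): like x.movedim(0, 1).reshape(3, 2 * 5)
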
1104    def test_reshape_dim_outof(self):
1105        x = torch.randn(12, 12, 12).permute(2, 1, 0)
1106
1107        y = reshape_dim_outof(0, 2, x)
1108        self.assertEqual(y, x.reshape(2, 6, 12, 12))
1109
1110        y = reshape_dim_outof(1, 4, x)
1111        self.assertEqual(y, x.reshape(12, 4, 3, 12))
1112
1113        y = reshape_dim_outof(2, 6, x)
1114        self.assertEqual(y, x.reshape(12, 12, 6, 2))
1115
1116        y = reshape_dim_outof(-1, 6, x)
1117        self.assertEqual(y, x.reshape(12, 12, 6, 2))
1118
1119        # Case: `0` sized dim.
1120        x = torch.randn(12, 12, 0)
1121        y = reshape_dim_outof(-1, 6, x)
1122        self.assertEqual(y.shape, torch.Size((12, 12, 6, 0)))
1123
1124    def test_batch_rule_does_not_need_to_handle_no_batched_input(self):
1125        def f(x, y):
1126            res = torch.dot(y, torch.ones(2))
1127            return x + res
1128
1129        x = torch.randn(7, 5)
1130        y = torch.randn(3, 2)
1131        out = vmap(vmap(f, in_dims=(0, None)), in_dims=(None, 0))(x, y)
1132        expected = torch.mv(y, torch.ones(2)).view(3, 1, 1) + x
1133        self.assertEqual(out, expected)
1134
1135    def test_decomposition_under_python_dispatcher(self):
1136        # This test will raise an error if the vmap fallback gets invoked.
1137        # Here we test that decomps registered to FuncTorchBatchedDecomposition
1138        # are respected by the Python Dispatcher.
1139        t = torch.ones(3, 3) * 5
1140        with DisableVmapFallback():
1141            with torch._dispatch.python.enable_python_dispatcher():
1142                o = torch.vmap(torch.square)(t)
1143        self.assertEqual(o, torch.square(t))
1144
1145    def _test_vmap_autocast(self, device):
1146        if torch.device(device).type == "cpu":
1147            amp_dtype = torch.bfloat16
1148        else:
1149            amp_dtype = torch.float16
1150
1151        a_float32 = torch.rand(4, 2, 3, device=device)
1152        b_float32 = torch.rand(4, 3, 2, device=device)
1153        c_float32 = torch.rand(4, 2, 2, device=device)
1154        d_float32 = torch.rand(4, 3, 2, device=device)
1155
1156        # Case 1, autocast inside vmapped function
1157        def func1(x, y, z, w):
1158            with torch.autocast(dtype=amp_dtype, device_type=device):
1159                e_float16 = torch.matmul(x, y)
1160                assert e_float16.dtype == amp_dtype, e_float16.dtype
1161                f_float16 = torch.matmul(z, e_float16)
1162                assert f_float16.dtype == amp_dtype, f_float16.dtype
1163            return torch.matmul(w, f_float16.float())
1164
1165        expected = func1(a_float32, b_float32, c_float32, d_float32)
1166        out = vmap(func1)(a_float32, b_float32, c_float32, d_float32)
1167        assert expected.allclose(out)
1168
1169        # Case 2, autocast decorator inside vmapped function
1170        @torch.autocast(dtype=amp_dtype, device_type=device)
1171        def func2(x, y, z, w):
1172            e_float16 = torch.matmul(x, y)
1173            assert e_float16.dtype == amp_dtype, e_float16.dtype
1174            f_float16 = torch.matmul(z, e_float16)
1175            assert f_float16.dtype == amp_dtype, f_float16.dtype
1176            return torch.matmul(w, f_float16)
1177
1178        expected = func2(a_float32, b_float32, c_float32, d_float32)
1179        out = vmap(func2)(a_float32, b_float32, c_float32, d_float32)
1180        assert expected.allclose(out)
1181
1182        # Case 3, autocast is outside vmapped function
1183        def func3(x, y, z, w):
1184            e_float16 = torch.matmul(x, y)
1185            assert e_float16.dtype == amp_dtype, e_float16.dtype
1186            f_float16 = torch.matmul(z, e_float16)
1187            assert f_float16.dtype == amp_dtype, f_float16.dtype
1188            return torch.matmul(w, f_float16)
1189
1190        with torch.autocast(dtype=amp_dtype, device_type=device):
1191            expected = func3(a_float32, b_float32, c_float32, d_float32)
1192            out = vmap(func3)(a_float32, b_float32, c_float32, d_float32)
1193
1194        assert expected.allclose(out)
1195
1196    @unittest.skip("Somehow, vmap and autocast do not work on CPU")
1197    def test_vmap_autocast_cpu(self):
1198        self._test_vmap_autocast("cpu")
1199
1200    @skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
1201    def test_vmap_autocast_cuda(self):
1202        self._test_vmap_autocast("cuda")
1203
1204    def test_restore_vmap_pytree_input_output(self):
1205        def f(x, y):
1206            output0 = x[0] + x[1]
1207            output1 = y
1208            return {"a": output0, "b": output1}
1209
1210        B = 2
1211        x0 = torch.randn(B, 3)
1212        x1 = torch.randn(B)
1213        y = torch.randn(4, B)
1214
1215        out, out_dims = restore_vmap(f, ((0, 0), 1), B, "error")((x0, x1), y)
1216        expected = vmap(f, in_dims=((0, 0), 1), out_dims={"a": 0, "b": 1})((x0, x1), y)
1217        self.assertEqual(out, expected)
1218        self.assertEqual(out_dims, {"a": 0, "b": 1})
1219
1220    def test_restore_vmap_no_vmapped_inputs(self):
1221        def f(x, y, z):
1222            return x, y * z, z
1223
1224        B = 2
1225        # Mix of tensor and non-tensor inputs
1226        x = torch.randn(3)
1227        y = torch.randn(4)
1228        z = 5
1229        out, out_dims = restore_vmap(f, (None, None, None), B, "error")(x, y, z)
1230        self.assertEqual(out, f(x, y, z))
1231        self.assertEqual(out_dims, (None, None, None))
1232
1233    def test_restore_vmap_unexpanded_outputs(self):
1234        def f(x, y):
1235            # Mix of tensor and non-tensor outputs
1236            return 3 * y, y.sum(), None
1237
1238        B = 2
1239        x = torch.randn(B, 3)
1240        y = torch.randn(4)
1241        out, out_dims = restore_vmap(f, (0, None), B, "error")(x, y)
1242        self.assertEqual(out, f(None, y))
1243        self.assertEqual(out_dims, (None, None, None))
1244
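    # Summary sketch of the restore_vmap contract exercised above (illustrative):
    # restore_vmap(f, in_dims, batch_size, randomness) behaves like vmap but also
    # reports where the batch dim landed in each output, and it does not expand
    # outputs that were never batched, e.g.
    #
    #     out, out_dims = restore_vmap(torch.sin, (0,), 2, "error")(torch.randn(2, 3))
    #     # out has shape (2, 3); out_dims reports dim 0
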
1245    def test_data_attribute(self):
1246        def foo(x):
1247            y = x.data
1248            return x
1249
1250        with self.assertRaisesRegex(
1251            RuntimeError, "accessing `data` under vmap transform"
1252        ):
1253            torch.func.vmap(foo)(torch.randn(3, 3))
1254
1255        def foo(x):
1256            x.data = torch.ones(3, 3)
1257            return x
1258
1259        with self.assertRaisesRegex(
1260            RuntimeError, "mutating directly with `.data` under vmap"
1261        ):
1262            torch.func.vmap(foo)(torch.randn(3, 3))
1263
1264
1265def slice_inputs(inputs, bdims, i):
1266    result = []
1267    for inp, bdim in zip(inputs, bdims):
1268        if bdim is None:
1269            result.append(inp)
1270        else:
1271            result.append(inp.select(bdim, i))
1272    return tuple(result)
1273
1274
1275def reference_vmap(op, inputs, in_dims=0, out_dims=0, return_nt=False):
1276    if isinstance(in_dims, int):
1277        in_dims = (in_dims,) * len(inputs)
1278    bdim_sizes = [inp.size(dim) for inp, dim in zip(inputs, in_dims) if dim is not None]
1279    assert all(bdim_size == bdim_sizes[0] for bdim_size in bdim_sizes)
1280    bdim_size = bdim_sizes[0]
1281    results = tuple(op(*slice_inputs(inputs, in_dims, i)) for i in range(bdim_size))
1282
1283    assert len(results) > 0
1284    op_has_single_return = not isinstance(results[0], tuple)
1285    if op_has_single_return:
1286        assert all(isinstance(result, torch.Tensor) for result in results)
1287        if isinstance(out_dims, int):
1288            out_dims = (out_dims,)
1289        if return_nt:
1290            return torch.nested.nested_tensor(list(results))
1291        else:
1292            return torch.stack(results, dim=out_dims[0])
1293
1294    assert all(isinstance(result, tuple) for result in results)
1295    num_returns = len(results[0])
1296    assert all(len(result) == num_returns for result in results)
1297    if isinstance(out_dims, int):
1298        out_dims = (out_dims,) * num_returns
1299    if return_nt:
1300        return tuple(
1301            torch.nested.nested_tensor(list(result_shards))
1302            for result_shards in zip(*results)
1303        )
1304    else:
1305        return tuple(
1306            torch.stack(result_shards, out_dim)
1307            for result_shards, out_dim in zip(zip(*results), out_dims)
1308        )
1309
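# Illustrative sketch of what reference_vmap above computes: it slices each
# batched input along its in_dim, applies `op` per sample, then stacks the
# per-sample results along out_dims (or builds a nested tensor), e.g.
#
#     xs, ys = torch.randn(5, 3), torch.randn(3, 5)
#     ref = reference_vmap(torch.mul, (xs, ys), in_dims=(0, 1))
#     # ref[i] == xs[i] * ys[:, i] for each i in range(5)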
1310
1311class TensorFactory:
1312    @staticmethod
1313    def rand(size, device="cpu", dtype=torch.float):
1314        return torch.rand(size, device=device, dtype=dtype)
1315
1316    @staticmethod
1317    def randn(size, device="cpu", dtype=torch.float):
1318        return torch.randn(size, device=device, dtype=dtype)
1319
1320    @staticmethod
1321    def randp1(size, device="cpu", dtype=torch.float):
1322        return torch.rand(size, device=device, dtype=dtype) + 1
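
# Note (sketch): randp1 draws from [1, 2) so that ops with restricted domains
# (e.g. torch.log, torch.reciprocal, torch.rsqrt in the parametrized tests below)
# always receive strictly positive inputs:
#
#     assert TensorFactory.randp1([3]).min() >= 1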
1323
1324
1325# Tests vmap(op, in_dims, out_dims)(*inputs) by comparing the output to a
1326# (slow) sequential map+stack fallback.
1327#
1328# check_view: Test that each returned output is a view of the first input
1329# check_propagates_grad: Test if the operation propagates gradients.
1330
1331
1332def _vmap_test(
1333    self,
1334    op,
1335    inputs,
1336    in_dims=0,
1337    out_dims=0,
1338    check_view=False,
1339    check_propagates_grad=True,
1340):
1341    result = vmap(op, in_dims, out_dims)(*inputs)
1342    are_nested = [t.is_nested for t in pytree.tree_leaves(result)]
1343    reference_result = reference_vmap(
1344        op, inputs, in_dims, out_dims, return_nt=any(are_nested)
1345    )
1346    self.assertEqual(result, reference_result)
1347    op_has_single_return = not isinstance(result, tuple)
1348
1349    if check_view:
1350        result_as_tuple = (result,) if op_has_single_return else result
1351        for output in result_as_tuple:
1352            input0_base = inputs[0] if inputs[0]._base is None else inputs[0]._base
1353            self.assertTrue(
1354                output._base is input0_base,
1355                msg="result was not a view of the first input!",
1356            )
1357
1358    if not check_propagates_grad:
1359        return
1360    # Assuming inputs[0] is a floating-point tensor. Check if the vmap
1361    # operation propagates the requires_grad flag to the zeroth output.
1362    # Some vmap operators are implemented in a way that assumes that
1363    # they are composite with respect to autograd. If the operator ever is
1364    # changed to not be composite with respect to autograd, then the
1365    # following check should fail.
1366    inputs_clone = list(inputs)
1367    inputs_clone[0] = inputs[0].clone().requires_grad_()
1368    result = vmap(op, in_dims, out_dims)(*inputs_clone)
1369    result_as_tuple = (result,) if op_has_single_return else result
1370    self.assertTrue(result_as_tuple[0].requires_grad)
1371
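# Example of how _vmap_test is typically invoked from a test case (sketch only;
# the real call sites are in TestVmapOperators below):
#
#     _vmap_test(self, torch.mul, (torch.randn(2, 3), torch.randn(2, 3)),
#                in_dims=(0, 0), out_dims=0)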
1372
1373def should_allow_vmap_fallback_usage(fn):
1374    return getattr(fn, "_allow_vmap_fallback_usage", False)
1375
1376
1377def allowVmapFallbackUsage(fn):
1378    fn._allow_vmap_fallback_usage = True
1379    return fn
1380
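# Usage sketch for the decorator above: opt a test out of the "no fallback"
# check by decorating it, mirroring test_vmap_fallback_check_ok below:
#
#     @allowVmapFallbackUsage
#     def test_uses_fallback_on_purpose(self):
#         vmap(torch.var_mean)(torch.rand(3))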
1381
1382# All tests of TestVmapBase check that the slow vmap fallback is never invoked.
1383# This is so that we can incrementally add batching rules for operators to
1384# replace the slow vmap fallback path for said operators. To skip this check,
1385# please use the allowVmapFallbackUsage decorator.
1386#
1387# NB: Don't add tests to TestVmapBase directly, unless you want them to run
1388# on every subclass of TestVmapBase. Add them to e.g. TestVmapOperators.
1389#
1390# NB: TestVmapBase is a nested class. This prevents test runners from picking
1391# it up and running it.
1392
1393
1394class Namespace:
1395    class TestVmapBase(TestCase):
1396        def __init__(self, method_name="runTest"):
1397            super().__init__(method_name)
1398
1399            test_method = getattr(self, method_name, None)
1400            if test_method is None:
1401                return
1402
1403            if not should_allow_vmap_fallback_usage(test_method):
1404                setattr(
1405                    self,
1406                    method_name,
1407                    self._wrap_method_with_vmap_fallback_check(test_method),
1408                )
1409
1410        def _wrap_method_with_vmap_fallback_check(self, method):
1411            # msg = (
1412            #     'Expected the test to not invoke the vmap fallback path, i.e., '
1413            #     'all of the operators being tested in this test should have batching '
1414            #     'rules implemented. If you are intentionally testing something to '
1415            #     'do with the fallback path, use allowVmapFallbackUsage. Otherwise, '
1416            #     'please make sure that batching rules are implemented for the '
1417            #     'operator(s) being tested.'
1418            # )
1419
1420            @functools.wraps(method)
1421            def wrapper(self, *args, **kwargs):
1422                with warnings.catch_warnings(record=True):
1423                    warnings.simplefilter("always")
1424                    with EnableVmapFallbackWarnings():
1425                        method(*args, **kwargs)
1426                    # for captured_warning in wa:
1427                    #     self.assertNotRegex(str(captured_warning.message), FALLBACK_REGEX, msg)
1428
1429            return types.MethodType(wrapper, self)
1430
1431        @allowVmapFallbackUsage
1432        def test_vmap_fallback_check_ok(self):
1433            # One day we'll implement a batching rule for torch.var_mean.
1434            # When that happens, please change the example to use an
1435            # operator that doesn't have a batching rule implemented.
1436            op_using_fallback = torch.var_mean
1437            vmap(op_using_fallback)(torch.rand(3))
1438
1439        @unittest.expectedFailure
1440        def test_vmap_fallback_check(self):
1441            @self._wrap_method_with_vmap_fallback_check
1442            def no_fallback(self):
1443                pass
1444
1445            # One day we'll implement a batching rule for torch.var_mean.
1446            # When that happens, please change the example to use an
1447            # operator that doesn't have a batching rule implemented.
1448            op_using_fallback = torch.var_mean
1449
1450            @self._wrap_method_with_vmap_fallback_check
1451            def uses_fallback(self):
1452                vmap(op_using_fallback)(torch.rand(3))
1453
1454            no_fallback(self)
1455
1456            with self.assertRaises(AssertionError):
1457                uses_fallback(self)
1458
1459
1460def _make_case(op, input_getter=TensorFactory.randn):
1461    return (op, input_getter)
1462
1463
1464@markDynamoStrictTest
1465class TestVmapOperators(Namespace.TestVmapBase):
1466    def _vmap_test(self, *args, **kwargs):
1467        return _vmap_test(self, *args, **kwargs)
1468
1469    def _vmap_view_test(self, *args, **kwargs):
1470        self._vmap_test(*args, **kwargs, check_view=True)
1471
1472    def _test_unary(self, op, getter, device, *args, **kwargs):
1473        test = functools.partial(self._vmap_test, *args, **kwargs)
1474        B0, B1 = 7, 11
1475
1476        # Single vmap, various in_dims / out_dims
1477        test(op, [getter([B0, 3], device)])
1478        test(op, [getter([2, 5, B0, 3], device)], in_dims=2)
1479        test(op, [getter([2, 5, B0, 3], device)], in_dims=2, out_dims=2)
1480
1481        # Doubly nested vmap
1482        test(vmap(op), [getter([B0, B1], device)])
1483        test(vmap(op), [getter([B1, 2, 5, B0, 3], device)], in_dims=2)
1484        test(
1485            vmap(op, in_dims=2),
1486            [getter([2, 5, B0, B1, 3], device)],
1487            in_dims=2,
1488            out_dims=2,
1489        )
1490
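    # A rough mental model for the in_dims / out_dims combinations that _test_unary
    # exercises (illustrative only; the real checking lives in _vmap_test): for a
    # shape-preserving unary op and a batch dim at position d,
    #   vmap(op, in_dims=d, out_dims=d)(x)
    # should match applying op to each slice of x along dim d, roughly
    #   op(x.movedim(d, 0)).movedim(0, d).
    # The shapes used below (e.g. [2, 5, B0, 3]) are arbitrary; they just put the
    # batch dim in a non-trivial position.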
1491    @parametrize(
1492        "case",
1493        [
1494            (torch.abs, TensorFactory.randn),
1495            (torch.acos, TensorFactory.rand),
1496            (torch.asin, TensorFactory.rand),
1497            (torch.atan, TensorFactory.rand),
1498            (torch.ceil, TensorFactory.randn),
1499            (torch.cos, TensorFactory.rand),
1500            (torch.cosh, TensorFactory.rand),
1501            (torch.digamma, TensorFactory.rand),
1502            (torch.exp, TensorFactory.randn),
1503            (torch.expm1, TensorFactory.randn),
1504            (torch.floor, TensorFactory.randn),
1505            (torch.frac, TensorFactory.randn),
1506            (torch.lgamma, TensorFactory.rand),
1507            (torch.log, TensorFactory.randp1),
1508            (torch.log10, TensorFactory.randp1),
1509            (torch.log1p, TensorFactory.randp1),
1510            (torch.log2, TensorFactory.randp1),
1511            (torch.neg, TensorFactory.randn),
1512            (torch.reciprocal, TensorFactory.randp1),
1513            (torch.relu, TensorFactory.randn),
1514            (torch.round, TensorFactory.randn),
1515            (torch.rsqrt, TensorFactory.randp1),
1516            (torch.sigmoid, TensorFactory.randn),
1517            (torch.sign, TensorFactory.randn),
1518            (torch.sin, TensorFactory.rand),
1519            (torch.sinh, TensorFactory.rand),
1520            (torch.sqrt, TensorFactory.rand),
1521            (torch.tan, TensorFactory.rand),
1522            (torch.tanh, TensorFactory.rand),
1523            (torch.trunc, TensorFactory.randn),
1524        ],
1525        name_fn=lambda x: x[0].__name__,
1526    )
1527    def test_unary_pointwise(self, case):
1528        op, getter = case
1529        self._test_unary(op, getter, "cpu")
1530
1531        # test in-place
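        # (the in-place variant is looked up by name, e.g. torch.abs -> Tensor.abs_)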
1532        method = getattr(Tensor, f"{op.__name__}_")
1533        self._test_unary(method, getter, "cpu", check_propagates_grad=False)
1534
1535    def test_clone(self):
1536        # Some basic tests
1537        self._test_unary(lambda x: x.clone(), TensorFactory.randn, "cpu")
1538        self._test_unary(
1539            lambda x: x.clone(memory_format=torch.preserve_format),
1540            TensorFactory.randn,
1541            "cpu",
1542        )
1543        self._test_unary(
1544            lambda x: x.clone(memory_format=torch.contiguous_format),
1545            TensorFactory.randn,
1546            "cpu",
1547        )
1548
1549        # Test that the per-example results are contiguous when using torch.contiguous_format
1550        def clone_contiguous(x):
1551            return x.clone(memory_format=torch.contiguous_format)
1552
1553        B0, B1 = 3, 5
1554        x = torch.randn(2, B0, 7)
1555        y = vmap(clone_contiguous, in_dims=1, out_dims=1)(x)
1556        self.assertTrue(y.movedim(1, 0).is_contiguous())
1557        self.assertTrue(y[:, 0, :].is_contiguous())
1558
1559        x = torch.randn(2, B0, 7, B1)
1560        y = vmap(vmap(clone_contiguous, in_dims=2), in_dims=1)(x)
1561        self.assertTrue(y.is_contiguous())
1562        self.assertTrue(y[0][0].is_contiguous())
1563
1564        msg = r"only supported with memory_format torch.preserve_format or torch.contiguous_format"
1565        with self.assertRaisesRegex(RuntimeError, msg):
1566            vmap(lambda x: x.clone(memory_format=torch.channels_last))(torch.randn(B0))
1567        with self.assertRaisesRegex(RuntimeError, msg):
1568            vmap(lambda x: x.clone(memory_format=torch.channels_last_3d))(
1569                torch.randn(B0)
1570            )
1571
1572    def test_weird_matmul_case(self):
1573        # Check that this doesn't crash.
1574        # https://github.com/pytorch/functorch/issues/417
1575        x = torch.randn(5, 2, 2, 2)
1576        y = torch.randn(5, 7, 2)
1577
1578        vmap(vmap(torch.matmul, in_dims=(None, 0)))(x, y)
1579
1580    @parametrize(
1581        "case",
1582        (
1583            (torch.clamp_min_, TensorFactory.randn),
1584            (torch.clamp_max_, TensorFactory.randn),
1585        ),
1586        name_fn=lambda x: x[0].__name__,
1587    )
1588    def test_clamp_inplace_variant(self, case):
1589        test = self._vmap_test
1590
1591        def get_number(getter):
1592            return getter([]).item()
1593
1594        op, getter = case
1595        device = "cpu"
1596        B0, B1 = 7, 11
1597
1598        # Single vmap: op(Tensor, Tensor)
1599        test(
1600            op,
1601            (getter([B0, 3], device), getter([B0, 3], device)),
1602            check_propagates_grad=False,
1603        )
1604        test(
1605            op,
1606            (getter([B0], device), getter([B0], device)),
1607            check_propagates_grad=False,
1608        )
1609        test(
1610            op,
1611            (getter([2, B0, 3], device), getter([2, B0, 3], device)),
1612            in_dims=(1, 1),
1613            check_propagates_grad=False,
1614        )
1615        test(
1616            op,
1617            (getter([B0, 2, 3], device), getter([2, B0, 3], device)),
1618            in_dims=(0, 1),
1619            out_dims=1,
1620            check_propagates_grad=False,
1621        )
1622        test(
1623            op,
1624            (getter([B0, 2, 3], device), getter([1, 1], device)),
1625            in_dims=(0, None),
1626            check_propagates_grad=False,
1627        )
1628        test(
1629            op,
1630            (getter([B0, 3], device), getter([B0, 3], device)),
1631            in_dims=(0, 0),
1632            check_propagates_grad=False,
1633        )
1634
1635        # Nested vmap: op(Tensor, Tensor)
1636        test(
1637            vmap(op),
1638            (getter([B0, B1, 2, 3], device), getter([B0, B1, 1, 3], device)),
1639            check_propagates_grad=False,
1640        )
1641
1642        # Python number overload: op(Tensor, Number)
1643        number = get_number(getter)
1644        self._test_unary(
1645            lambda t: op(t, number), getter, device, check_propagates_grad=False
1646        )
1647
1648    @parametrize(
1649        "case",
1650        [
1651            subtest(_make_case(torch.clamp_min), name="clamp_min"),
1652            subtest(_make_case(torch.clamp_max), name="clamp_max"),
1653        ],
1654    )
1655    def test_clamp_variant(self, case):
1656        test = self._vmap_test
1657
1658        def get_number(getter):
1659            return getter([]).item()
1660
1661        op, getter = case
1662        device = "cpu"
1663        B0, B1 = 7, 11
1664
1665        # Single vmap: op(Tensor, Tensor)
1666        test(op, (getter([B0, 3], device), getter([B0, 3], device)))
1667        test(op, (getter([B0], device), getter([B0, 2, 3], device)))
1668        test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1))
1669        test(
1670            op,
1671            (getter([B0], device), getter([2, B0, 3], device)),
1672            in_dims=(0, 1),
1673            out_dims=1,
1674        )
1675        test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
1676        test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(None, 0))
1677
1678        # Nested vmap: op(Tensor, Tensor)
1679        test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device)))
1680        test(
1681            vmap(op, in_dims=(None, 0)),
1682            (getter([B0, 2, 3], device), getter([B1, 3], device)),
1683            in_dims=(0, None),
1684        )
1685
1686        # Python number overload: op(Tensor, Number)
1687        number = get_number(getter)
1688        self._test_unary(lambda t: op(t, number), getter, device)
1689
1690    def test_copy_(self):
1691        x = torch.randn(3)
1692        y = torch.randn(3)
1693        vmap(Tensor.copy_)(x, y)
1694        self.assertEqual(x, y)
1695
1696        x = torch.randn(3)
1697        y = torch.randn(3, 2)
1698        vmap(Tensor.copy_, in_dims=(1, None))(y, x)
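        # each column y[:, i] gets copy_(x), so afterwards every column of y should
        # equal x, i.e. y == x.expand(2, 3).t()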
1699        self.assertEqual(y, x.expand(2, 3).t())
1700
1701        x = torch.randn(3)
1702        y = torch.randn(2, 3)
1703        with self.assertRaisesRegex(RuntimeError, "inplace"):
1704            vmap(Tensor.copy_, in_dims=(None, 0))(x, y)
1705
1706    def test_silu_backward(self):
1707        test = self._vmap_test
1708        device = "cpu"
1709        getter = TensorFactory.randp1
1710        B0 = 7
1711        op = torch.ops.aten.silu_backward
1712
1713        # Single vmap: op(Tensor, Tensor)
1714        test(op, (getter([B0, 3], device), getter([B0, 3], device)))
1715        test(op, (getter([], device), getter([B0], device)), in_dims=(None, 0))
1716        test(op, (getter([2, B0], device), getter([2], device)), in_dims=(1, None))
1717
1718    @skipIf(
1719        TEST_WITH_TORCHDYNAMO
1720        and os.getenv("BUILD_ENVIRONMENT", "") == "linux-focal-py3.8-clang10",
1721        "Segfaults with dynamo on focal, see https://github.com/pytorch/pytorch/issues/107173",
1722    )
1723    @parametrize(
1724        "case",
1725        [
1726            subtest(_make_case(torch.add), name="add"),
1727            subtest(_make_case(lambda x, y: x + y), name="add_dunder"),
1728            subtest(_make_case(torch.sub), name="sub"),
1729            subtest(_make_case(lambda x, y: x - y), name="sub_dunder"),
1730            subtest(_make_case(torch.mul), name="mul"),
1731            subtest(_make_case(lambda x, y: x * y), name="mul_dunder"),
1732            subtest(
1733                _make_case(torch.div, input_getter=TensorFactory.randp1), name="div"
1734            ),
1735            subtest(
1736                _make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1),
1737                name="div_dunder",
1738            ),
1739            subtest(
1740                _make_case(torch.pow, input_getter=TensorFactory.randp1), name="pow"
1741            ),
1742            subtest(
1743                _make_case(lambda x, y: x**y, input_getter=TensorFactory.randp1),
1744                name="pow_dunder",
1745            ),
1746        ],
1747    )
1748    def test_arithmetic(self, case):
1749        test = self._vmap_test
1750
1751        def get_number(getter):
1752            return getter([]).item()
1753
1754        op, getter = case
1755        device = "cpu"
1756        B0, B1 = 7, 11
1757
1758        # Single vmap: op(Tensor, Tensor)
1759        test(op, (getter([B0, 3], device), getter([B0, 3], device)))
1760        test(op, (getter([B0], device), getter([B0, 2, 3], device)))
1761        test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1))
1762        test(
1763            op,
1764            (getter([B0], device), getter([2, B0, 3], device)),
1765            in_dims=(0, 1),
1766            out_dims=1,
1767        )
1768        test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
1769        test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None))
1770
1771        # Nested vmap: op(Tensor, Tensor)
1772        test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device)))
1773        test(
1774            vmap(op, in_dims=(None, 0)),
1775            (getter([B0, 2, 3], device), getter([B1, 3], device)),
1776            in_dims=(0, None),
1777        )
1778
1779        # Python number overload: op(Tensor, Number) (and vice-versa)
1780        number = get_number(getter)
1781        self._test_unary(lambda t: op(t, number), getter, device)
1782        number = get_number(getter)
1783        self._test_unary(lambda t: op(number, t), getter, device)
1784
1785        # Type promotion: op(Logical Scalar Tensor, Logical Scalar Tensor)
1786        test(op, (getter([B0], device), getter([B0], device, dtype=torch.double)))
1787        test(op, (getter([B0], device, dtype=torch.double), getter([B0], device)))
1788        test(op, (getter([B0], device), getter([B0], device)))
1789
1790        # Type promotion: op(Tensor, Logical Scalar Tensor) (and vice-versa)
1791        test(op, (getter([B0, 2], device), getter([B0], device, torch.double)))
1792        test(op, (getter([B0], device, torch.double), getter([B0, 2], device)))
1793
1794        if not torch.cuda.is_available():
1795            return
1796
1797        # TODO(rzou): fix the following
1798        # # Test cross-device scalars
1799        # number = get_number(getter)
1800        # self._test_unary(lambda t: op(t, number), getter, device='cuda')
1801        # self._test_unary(lambda t: op(number, t), getter, device='cuda')
1802        # self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda')
1803
1804    def test_as_strided(self):
1805        def _test(sizes, strides, offset, tensor, lambd):
1806            # bdim at dim 0 test
1807            result = vmap(lambda t: t.as_strided(sizes, strides, offset))(tensor)
1808            expected = vmap(lambd)(tensor)
1809            self.assertTrue(result._base is expected._base)
1810            self.assertEqual(result, expected)
1811
1812            # bdim at dim -1 test
1813            tensor = tensor.movedim(0, -1)
1814            result = vmap(lambda t: t.as_strided(sizes, strides, offset), -1)(tensor)
1815            expected = vmap(lambd, -1)(tensor)
1816            self.assertTrue(result._base is expected._base)
1817            self.assertEqual(result, expected)
1818
1819        # single vmap test
1820        B0 = 5
1821        # Each tensor below has shape [B0, 2, 3]; the expressions are just to get
1822        # tensors with different strides and storage offsets that have that shape
1823        tensors = [
1824            # contiguous
1825            torch.randn(B0, 2, 3),
1826            # non-contiguous
1827            torch.randn(B0, 3, 2).transpose(1, 2),
1828            torch.randn(3, 2, B0).movedim(-1, 0).transpose(1, 2),
1829            # non-zero storage offset
1830            torch.randn(2, B0, 2, 3)[1],
1831            torch.randn(2, 2, B0, 3)[1].movedim(1, 0),
1832            # non-contiguous strides, zero storage offset
1833            torch.randn(B0, 2, 4, 3, 7)[:, :, 0, :, 0],
1834            torch.randn(2, 4, B0, 3, 7).movedim(2, 0)[:, :, 0, :, 0],
1835            # non-contiguous strides, non-zero storage offset
1836            torch.randn(B0, 2, 4, 3, 7)[:, :, 2, :, 1],
1837            torch.randn(2, 4, 3, 7, B0).movedim(-1, 0)[:, :, 2, :, 1],
1838        ]
1839
1840        for x in tensors:
1841            S0, S1 = x.stride()[1:]
1842            offset = x.storage_offset()
1843
1844            # Broadcast
1845            _test(
1846                [5, 5, 2, 3], [0, 0, S0, S1], offset, x, lambda x: x.expand(5, 5, 2, 3)
1847            )
1848            # transpose
1849            _test([3, 2], [S1, S0], offset, x, lambda x: x.transpose(0, 1))
1850            # select
1851            _test([2], [S0], offset + S1, x, lambda x: x[:, 1])
1852            # diagonal
1853            _test([2], [S0 + S1], offset, x, lambda x: x.diagonal())
1854            # strided slice
1855            _test([2], [S1 * 2], offset, x, lambda x: x[0, ::2])
1856
1857        # Nested vmap test
1858        B1 = 7
1859        x = torch.randn(B1, B0, 2, 3)
1860        S0, S1 = x.stride()[2:]
1861        result = vmap(
1862            vmap(lambda t: t.as_strided([5, 5, 2, 3], [0, 0, S0, S1])), in_dims=1
1863        )(x)
1864        expected = vmap(vmap(lambda t: t.expand(5, 5, 2, 3)), in_dims=1)(x)
1865        self.assertTrue(result._base is expected._base)
1866        self.assertEqual(result, expected)
1867
1868        # Check that malformed size/strides don't crash
1869        with self.assertRaisesRegex(
1870            RuntimeError, "size and stride must have the same length"
1871        ):
1872            x = torch.randn(B0, 2, 3).transpose(0, 1)
1873            vmap(lambda x: x.as_strided([1, 1, 1], [1, 1]))(x)
1874
1875        # All the Sanity check #1{a,b,c} cases check that
1876        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
1877        # doesn't index memory that is out of bounds of xs[i]. This condition
1878        # is important to the correctness of the as_strided batching rule
1879        # (see NOTE: [When will the as_strided_batching_rule fail?])
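        # A concrete (illustrative) instance of the condition: for x = torch.randn(B0, 3),
        # the per-example slice xs[i] = x[i] covers storage elements [3*i, 3*i + 3).
        # xs[i].as_strided([3], [1], 1) would reach storage element 3*i + 3, which lies
        # outside xs[i], so the batching rule errors out (exercised in check #1a below).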
1880
1881        # Sanity check #1a: The maximum indexable location of
1882        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
1883        # is less than or equal to the maximum indexable location of xs[i].
1884        msg = "This is not supported inside of vmap"
1885        with self.assertRaisesRegex(RuntimeError, msg):
1886            x = torch.randn(B0, 3)
1887            vmap(lambda x: x.as_strided([3], [1], 1))(x)
1888        with self.assertRaisesRegex(RuntimeError, msg):
1889            x = torch.randn(B0, 3, 5)
1890            vmap(lambda x: x.as_strided([4, 4], [4, 1], 0))(x)
1891        with self.assertRaisesRegex(RuntimeError, msg):
1892            x = torch.randn(B0, B1, 3, 5)
1893            vmap(vmap(lambda x: x.as_strided([4, 4], [4, 1], 0)))(x)
1894
1895        # Sanity check #1b: The min indexable location of
1896        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
1897        # is greater than or equal to the min indexable location of xs[i].
1898        with self.assertRaisesRegex(RuntimeError, msg):
1899            x = torch.randn(2, B0, 3)[1]
1900            vmap(lambda x: x.as_strided([3], [1], B0 * 3 - 1))(x)
1901
1902        # Sanity check #1c:
1903        # xs[i] is a zero-dim tensor, but
1904        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
1905        # is not
1906        with self.assertRaisesRegex(RuntimeError, msg):
1907            x = torch.randn(B0, 0, 3)
1908            vmap(lambda x: x.as_strided([3], [1]))(x)
1909
1910    def test_nll_loss(self):
1911        test = self._vmap_test
1912        op = F.nll_loss
1913        B = 3
1914
1915        y = torch.randn(B, 2, 5)
1916        t = torch.randint(0, 5, (B, 2))
1917        test(op, (y, t))
1918        test(functools.partial(op, reduction="sum"), (y, t))
1919        test(functools.partial(op, reduction="none"), (y, t))
1920
1921        y = torch.randn(B, 2, 5)
1922        t = torch.randint(0, 5, (2,))
1923        test(op, (y, t), in_dims=(0, None))
1924        test(functools.partial(op, reduction="sum"), (y, t), in_dims=(0, None))
1925        test(functools.partial(op, reduction="none"), (y, t), in_dims=(0, None))
1926
1927    def test_adaptive_avg_pool2d(self):
1928        test = self._vmap_test
1929        op = functools.partial(F.adaptive_avg_pool2d, output_size=(3, 3))
1930
1931        x = torch.randn(3, 5, 7, 9, 11)
1932        test(op, (x,))
1933        test(op, (x,), in_dims=(1,))
1934        test(op, (x,), in_dims=(4,))
1935
1936    def test_bmm(self):
1937        op = torch.bmm
1938        test = self._vmap_test
1939        B0, B1 = 7, 11
1940
1941        # shape mismatch
1942        msg = ""
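        # (an empty pattern matches any error message; we only check that these raise)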
1943        with self.assertRaisesRegex(RuntimeError, msg):
1944            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
1945        with self.assertRaisesRegex(RuntimeError, msg):
1946            vmap(op, in_dims=(0, None))(torch.randn(B0, 3, 3, 2), torch.randn(2, 2))
1947        with self.assertRaisesRegex(RuntimeError, msg):
1948            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2))
1949
1950        # left arg is vmapped
1951        test(op, (torch.rand(B0, 2, 3, 5), torch.rand(2, 5, 3)), in_dims=(0, None))
1952        test(
1953            vmap(op, in_dims=(0, None)),
1954            (torch.rand(B1, B0, 2, 3, 5), torch.rand(2, 5, 3)),
1955            in_dims=(1, None),
1956        )
1957
1958        # right arg is vmapped
1959        test(op, (torch.rand(2, 5, 3), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))
1960        test(
1961            vmap(op, in_dims=(None, 0)),
1962            (torch.rand(2, 5, 3), torch.rand(B1, B0, 2, 3, 5)),
1963            in_dims=(None, 1),
1964        )
1965
1966        # both args are vmapped
1967        test(op, (torch.rand(B0, 2, 3, 5), torch.rand(B0, 2, 5, 3)))
1968        test(
1969            vmap(op),
1970            (torch.rand(B1, B0, 2, 3, 5), torch.rand(B0, B1, 2, 5, 3)),
1971            in_dims=(1, 0),
1972        )
1973        test(
1974            vmap(op, in_dims=(0, None)),
1975            (torch.rand(B1, 2, 3, 5), torch.rand(B0, 2, 5, 3)),
1976            in_dims=(None, 0),
1977        )
1978
1979    def test_cat(self):
1980        test = self._vmap_test
1981        B0, B1 = 5, 7
1982
1983        # Quick hack b/c vmap can't accept a list of tensors as an argument
1984        def get_op(dim):
1985            def op(*tensors):
1986                return torch.cat(tensors, dim=dim)
1987
1988            return op
1989
1990        test(get_op(0), (torch.rand(B0, 2), torch.rand(B0, 3)))
1991        test(get_op(0), (torch.rand(B0, 0), torch.rand(B0, 0)))
1992        test(get_op(0), (torch.rand(2), torch.rand(B0, 0)), in_dims=(None, 0))
1993        test(
1994            get_op(1),
1995            (torch.rand(2, 5), torch.rand(B0, 0), torch.rand(2, 3)),
1996            in_dims=(None, 0, None),
1997        )
1998        test(get_op(1), (torch.rand(B0, 2, 3), torch.rand(B0, 0)))
1999        test(get_op(1), (torch.rand(B0, 2, 3, 4), torch.rand(0)), in_dims=(0, None))
2000        test(
2001            get_op(0),
2002            (torch.rand(0), torch.rand(B0, 2), torch.rand(B0, 0)),
2003            in_dims=(None, 0, 0),
2004        )
2005        test(get_op(0), (torch.rand(2), torch.rand(B0, 3)), in_dims=(None, 0))
2006        test(get_op(0), (torch.rand(2, 17), torch.rand(3, 17, B0)), in_dims=(None, 2))
2007        test(get_op(-1), (torch.rand(17, 2), torch.rand(17, 3, B0)), in_dims=(None, 2))
2008        test(
2009            vmap(get_op(0), in_dims=(0, None)),
2010            (torch.rand(B1, 2), torch.rand(B0, 3)),
2011            in_dims=(None, 0),
2012        )
2013        test(
2014            vmap(get_op(0), in_dims=(0, 0)),
2015            (torch.rand(B1, 2), torch.rand(B0, B1, 3)),
2016            in_dims=(None, 0),
2017        )
2018
2019    def test_unsafe_view(self):
2020        # Unsafe view isn't exposed, so we get at it via
2021        # vmap(grad(matmul))
2022        test = functools.partial(self._vmap_test, check_propagates_grad=False)
2023        B = 2
2024        x = torch.randn(B, 2, 3, 3)
2025        y = torch.randn(B, 3, 3)
2026
2027        def baz(x, y):
2028            return (x @ y).sum()
2029
2030        test(functorch.grad(baz), (x, y))
2031
2032    def test_conj(self):
2033        op = torch.conj
2034
2035        def run_test(dtype):
2036            def get(shape):
2037                return torch.randn(shape, dtype=dtype)
2038
2039            B0, B1 = 7, 11
2040            test = self._vmap_test
2041
2042            # Single vmap, various in_dims / out_dims
2043            test(op, [get([B0, 3])])
2044            test(op, [get([2, 5, B0, 3])], in_dims=2)
2045            test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2)
2046
2047            # Doubly nested vmap
2048            test(vmap(op), [get([B0, B1])])
2049            test(vmap(op), [get([B1, 2, 5, B0, 3])], in_dims=2)
2050            test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], in_dims=2, out_dims=2)
2051
2052        # correctness tests
2053        run_test(torch.float)
2054        run_test(torch.cfloat)
2055
2056        # check that torch.conj on a non-complex tensor returns the same tensor
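        # (conj is a no-op for real dtypes, so the vmapped result should share the
        # input's storage; hence the data_ptr comparison)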
2057        real_tensor = torch.randn(3)
2058        result = vmap(op)(real_tensor)
2059        self.assertEqual(result.data_ptr(), real_tensor.data_ptr())
2060
2061    def test_contiguous(self):
2062        op = Tensor.contiguous
2063
2064        self._test_unary(op, TensorFactory.randn, "cpu")
2065
2066        # check that contiguous returns the original tensor if the per-example
2067        # tensors are already contiguous
2068        B0 = 3
2069        x = torch.randn(B0, 2, 5, 7)
2070        x = x.movedim(0, 2)
2071        result = vmap(Tensor.contiguous, in_dims=2, out_dims=2)(x)
2072        self.assertTrue(result is x)
2073
2074        msg = "NYI: querying is_contiguous inside of vmap for memory_format"
2075        tensor = torch.randn(B0, 3)
2076        with self.assertRaisesRegex(RuntimeError, msg):
2077            vmap(functools.partial(op, memory_format=torch.channels_last))(tensor)
2078        with self.assertRaisesRegex(RuntimeError, msg):
2079            vmap(functools.partial(op, memory_format=torch.channels_last_3d))(tensor)
2080
2081    def test_stride(self):
2082        B0 = 3
2083
2084        x = torch.randn(B0, 2, 5, 7)
2085
2086        def foo(x):
2087            assert x.stride() == (7 * 5, 7, 1)
2088            return x
2089
2090        vmap(foo)(x)
2091
2092        x = torch.randn(2, B0, 5, 7).movedim(1, 0)
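        # x has shape (B0, 2, 5, 7) but is non-contiguous: its strides are
        # (5 * 7, B0 * 5 * 7, 7, 1). vmap removes the batch dim (dim 0), so each
        # per-example slice should see strides (B0 * 5 * 7, 7, 1), i.e.
        # (7 * 5 * B0, 7, 1), which is what bar asserts below.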
2093
2094        def bar(x):
2095            assert x.stride() == (7 * 5 * B0, 7, 1)
2096            return x
2097
2098        vmap(bar)(x)
2099
2100    def test_chunk(self):
2101        test = self._vmap_view_test
2102        op = torch.chunk
2103        B0, B1, B2 = 7, 11, 13
2104
2105        # tests for torch.chunk(self, chunks: int, dim)
2106        test(op, (torch.rand(B0, 2, 1024), 15, -1), in_dims=(0, None, None))
2107        test(op, (torch.rand(2, B0, 1024), 9, 1), in_dims=(1, None, None))
2108        test(
2109            vmap(op, in_dims=(0, None, None)),
2110            (torch.rand(B1, 1023, B0, 5), 4, 0),
2111            in_dims=(2, None, None),
2112        )
2113        test(
2114            vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
2115            (torch.rand(B1, 2, B0, 64, B2),),
2116            in_dims=2,
2117        )
2118
2119    def test_clamp(self):
2120        clamp_cases = (
2121            (lambda t: t.clamp(min=-0.5), TensorFactory.randn),
2122            (lambda t: t.clamp(max=0.5), TensorFactory.randn),
2123            (lambda t: t.clamp(min=-0.5, max=0.5), TensorFactory.randn),
2124            (lambda t: t.clamp_min(min=-0.5), TensorFactory.randn),
2125            (lambda t: t.clamp_max(max=0.5), TensorFactory.randn),
2126        )
2127        for op, getter in clamp_cases:
2128            self._test_unary(op, getter, "cpu")
2129
2130    def test_comparison_ops(self):
2131        test = functools.partial(self._vmap_test, check_propagates_grad=False)
2132
2133        getter = TensorFactory.randn
2134        B0, B1 = 7, 11
2135
2136        ops = (
2137            torch.eq,
2138            lambda x, y: x == y,
2139            torch.gt,
2140            lambda x, y: x > y,
2141            torch.ge,
2142            lambda x, y: x >= y,
2143            torch.le,
2144            lambda x, y: x <= y,
2145            torch.lt,
2146            lambda x, y: x < y,
2147            torch.ne,
2148            lambda x, y: x != y,
2149        )
2150
2151        for op in ops:
2152            # Single vmap: op(Tensor, Tensor)
2153            test(op, (getter([B0, 3]), getter([B0, 3])))
2154            test(op, (getter([B0]), getter([B0, 2, 3])))
2155            test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1))
2156            test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1), out_dims=1)
2157            test(op, (getter([B0]), getter([2, 3])), in_dims=(0, None))
2158            test(op, (getter([2, 3]), getter([B0, 3])), in_dims=(0, None))
2159
2160            # Nested vmap: op(Tensor, Tensor)
2161            test(vmap(op), (getter([B0, B1, 2, 3]), getter([B0, B1, 3])))
2162            test(
2163                vmap(op, in_dims=(None, 0)),
2164                (getter([B0, 2, 3]), getter([B1, 3])),
2165                in_dims=(0, None),
2166            )
2167
2168            # test number as inputs
2169            number = getter([]).item()
2170            self._test_unary(
2171                lambda t: op(t, number), getter, "cpu", check_propagates_grad=False
2172            )
2173
2174    def test_cross_batch_size_three(self):
2175        # Test the corner case where batch_size is 3 and cross's dim argument is not specified.
2176        # According to the cross API, dim will be assigned to the first dim of size 3;
2177        # in this test we ensure that the dim it finds is not the batch dim.
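        # For example, for per-example inputs of shape (2, 3) under a batch of size
        # B0 = 3, cross should pick the last dim (size 3) of the logical per-example
        # tensors, not the physical batch dim, which also happens to have size 3.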
2178        op = torch.cross
2179        test = self._vmap_test
2180        B0 = B1 = 3
2181        test(op, (torch.rand(B0, 2, 3), torch.rand(B0, 2, 3)))
2182        test(
2183            vmap(op, in_dims=(0, None)),
2184            (torch.rand(B0, B1, 2, 3), torch.rand(B0, B1, 2, 3)),
2185            in_dims=(None, 1),
2186        )
2187
2188    def test_diagonal(self):
2189        tensor = torch.randn(3, 5, 7, 11, 13)
2190        test = self._vmap_view_test
2191        op = torch.diagonal
2192        test(op, (tensor, 1, 0, 1), in_dims=(0, None, None, None))
2193        test(op, (tensor, 0, 2, -1), in_dims=(0, None, None, None))
2194        test(op, (tensor, 2, 1, 2), in_dims=(1, None, None, None))
2195        test(op, (tensor, 0, -2, -1), in_dims=(1, None, None, None), out_dims=1)
2196        test(vmap(lambda t: op(t, 0, 0, -1)), (tensor,), in_dims=1, out_dims=1)
2197        test(
2198            vmap(vmap(lambda t: op(t, 0, 0, 1), in_dims=1), in_dims=3),
2199            (tensor,),
2200            in_dims=1,
2201            out_dims=1,
2202        )
2203
2204    def test_dot(self):
2205        op = torch.dot
2206        test = self._vmap_test
2207        B0, B1 = 7, 11
2208
2209        # shape mismatch
2210        msg = ""
2211        with self.assertRaisesRegex(RuntimeError, msg):
2212            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
2213        with self.assertRaisesRegex(RuntimeError, msg):
2214            vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
2215        with self.assertRaisesRegex(RuntimeError, msg):
2216            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2))
2217
2218        # left arg is vmapped
2219        test(op, (torch.rand(B0, 5), torch.rand(5)), in_dims=(0, None))
2220        test(
2221            vmap(op, in_dims=(0, None)),
2222            (torch.rand(B1, B0, 5), torch.rand(5)),
2223            in_dims=(1, None),
2224        )
2225
2226        # right arg is vmapped
2227        test(op, (torch.rand(5), torch.rand(B0, 5)), in_dims=(None, 0))
2228        test(
2229            vmap(op, in_dims=(None, 0)),
2230            (torch.rand(5), torch.rand(B1, B0, 5)),
2231            in_dims=(None, 1),
2232        )
2233
2234        # both args are vmapped
2235        test(op, (torch.rand(B0, 5), torch.rand(B0, 5)))
2236        test(vmap(op), (torch.rand(B1, B0, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0))
2237        test(
2238            vmap(op, in_dims=(0, None)),
2239            (torch.rand(B1, 5), torch.rand(B0, 5)),
2240            in_dims=(None, 0),
2241        )
2242
2243    def test_expand_as(self):
2244        op = torch.Tensor.expand_as
2245        test = self._vmap_view_test
2246        B0, B1, B2 = 7, 11, 13
2247        test(op, (torch.rand(B0, 1, 5), torch.rand(B0, 2, 3, 5)))
2248        test(op, (torch.rand(B0, 1, 5), torch.rand(2, 3, 5)), in_dims=(0, None))
2249        test(op, (torch.rand(1, 5), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))
2250        test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B0, B1, 2, 3, 5)))
2251        test(
2252            vmap(op),
2253            (torch.rand(B0, B1, 1, 5), torch.rand(B1, B0, 2, 3, 5)),
2254            in_dims=(0, 1),
2255        )
2256        test(vmap(op), (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None))
2257        test(vmap(vmap(op)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5)))
2258
2259    def test_fill_and_zero_inplace(self):
2260        test = functools.partial(self._vmap_test, check_propagates_grad=False)
2261        B0, B1 = 7, 11
2262        ops = (
2263            lambda t: t.fill_(0.1),
2264            lambda t: t.fill_(torch.tensor(0.2)),
2265            lambda t: t.zero_(),
2266        )
2267
2268        for op in ops:
2269            # Single vmap, various in_dims / out_dims
2270            test(op, [TensorFactory.randn([B0, 3])])
2271            test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2)
2272            test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2, out_dims=2)
2273
2274            # Doubly nested vmap
2275            test(vmap(op), [TensorFactory.randn([B0, B1])])
2276            test(vmap(op), [TensorFactory.randn([B1, 2, 5, B0, 3])], in_dims=2)
2277            test(
2278                vmap(op, in_dims=2),
2279                [TensorFactory.randn([2, 5, B0, B1, 3])],
2280                in_dims=2,
2281                out_dims=2,
2282            )
2283
2284        # test when value is a batched tensor for fill_ operator
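        # here each per-example call is t_i.fill_(v_i) with a 0-dim value v_i, so row i
        # of the result should be filled with the i-th entry of the value tensor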
2285        B0, B1 = 3, 5
2286        test(Tensor.fill_, [TensorFactory.randn([B0, B1]), TensorFactory.randn(B0)])
2287
2288        with self.assertRaisesRegex(RuntimeError, ""):
2289            # A RuntimeError is thrown when the tensor being written to isn't being vmapped over
2290            vmap(Tensor.fill_, (None, 0))(
2291                TensorFactory.randn([B0, B1]), TensorFactory.randn([B0])
2292            )
2293
2294    def _test_complex_views(self, op, dtypes):
2295        test = self._vmap_view_test
2296
2297        def run_test(op, dtype):
2298            def get(shape):
2299                return torch.randn(shape, dtype=dtype)
2300
2301            B0, B1 = 7, 11
2302
2303            # Single vmap, various in_dims / out_dims
2304            test(op, [get([B0, 3])])
2305            test(op, [get([3, B0])], in_dims=1)
2306            test(op, [get([2, 5, B0, 3])], in_dims=2)
2307            test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2)
2308
2309            # Doubly nested vmap
2310            test(vmap(op), [get([B0, B1])])
2311            test(vmap(op), [get([B1, 2, 5, 3, B0])], in_dims=4)
2312            test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], in_dims=2, out_dims=2)
2313
2314        for dtype in dtypes:
2315            run_test(op, dtype)
2316
2317    def test_real(self):
2318        self._test_complex_views(torch.real, dtypes=[torch.cfloat, torch.cdouble])
2319
2320    def test_imag(self):
2321        self._test_complex_views(torch.imag, dtypes=[torch.cfloat, torch.cdouble])
2322
2323    def test_view_as_real(self):
2324        self._test_complex_views(
2325            torch.view_as_real, dtypes=[torch.cfloat, torch.cdouble]
2326        )
2327
2328    def test_view_as_complex(self):
2329        def run_test(dtype):
2330            def get(shape):
2331                return torch.randn(shape, dtype=dtype)
2332
2333            op = torch.view_as_complex
2334            test = self._vmap_view_test
2335            B0, B1 = 7, 11
2336
2337            # Single vmap, various in_dims / out_dims
2338            test(op, [get([B0, 3, 2])])
2339            test(op, [get([2, 5, B0, 3, 2])], in_dims=2)
2340            test(op, [get([2, 5, B0, 3, 2])], in_dims=2, out_dims=2)
2341
2342            # Doubly nested vmap
2343            test(vmap(op), [get([B0, B1, 2])])
2344            test(vmap(op), [get([B1, 2, 5, B0, 3, 2])], in_dims=2)
2345            test(
2346                vmap(op, in_dims=2), [get([2, 5, B0, B1, 3, 2])], in_dims=2, out_dims=2
2347            )
2348
2349            # Interesting case #1: Batch dim directly before dim of size 2
2350            test(op, [get([3, B0, 2])], in_dims=1)
2351            test(vmap(op, in_dims=1), [get([3, B1, B0, 2])], in_dims=2)
2352
2353            # Interesting case #2: Batch dim at end of tensor, success cases
2354            # view_as_complex requires that the dim with size 2 have stride 1
2355            # in order for the view to function properly
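            # (e.g. get([B0, 2]).transpose(0, 1) has shape (2, B0) with the size-2 dim
            # contiguous in memory, so each per-example view still sees a last dim of
            # size 2 with stride 1)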
2356            test(op, [get([B0, 2]).transpose(0, 1)], in_dims=1)
2357            test(vmap(op, in_dims=1), [get([B0, B1, 2]).movedim(1, 2)])
2358            test(vmap(op, in_dims=2), [get([B0, 3, B1, 2]).movedim(2, 3)])
2359
2360            # Interesting case #3: Batch dim at end of tensor, failure cases
2361            msg = "Tensor must have a last dimension with stride 1"
2362            with self.assertRaisesRegex(RuntimeError, msg):
2363                vmap(op, in_dims=1)(get([2, B0]))
2364            with self.assertRaisesRegex(RuntimeError, msg):
2365                vmap(vmap(op, in_dims=1), in_dims=1)(get([2, B0, B1]))
2366
2367            # Invalid input: no dimension of size 2
2368            msg = "Input tensor must have one or more dimensions"
2369            with self.assertRaisesRegex(RuntimeError, msg):
2370                vmap(op)(get([B0]))
2371            with self.assertRaisesRegex(RuntimeError, msg):
2372                vmap(vmap(op))(get([B0, B1]))
2373
2374            # Invalid input: Batch dim has size 2, but the logical last dim does
2375            # not have size 2
2376            msg = "Tensor must have a last dimension of size 2"
2377            with self.assertRaisesRegex(RuntimeError, msg):
2378                vmap(op, in_dims=1)(get([3, 2]))
2379
2380        for dtype in [torch.float, torch.double]:
2381            run_test(dtype)
2382
2383    def test_is_complex(self):
2384        ctensor = torch.randn(3, dtype=torch.cfloat)
2385        tensor = torch.randn(3)
2386
2387        def foo(x):
2388            if x.is_complex():
2389                return torch.tensor(1)
2390            else:
2391                return torch.tensor(0)
2392
2393        self.assertEqual(vmap(foo)(ctensor), torch.tensor([1, 1, 1]))
2394        self.assertEqual(vmap(foo)(tensor), torch.tensor([0, 0, 0]))
2395
2396    def test_is_floating_point(self):
2397        float_tensor = torch.tensor([1.0, 2.0, 3.0])
2398        long_tensor = torch.tensor([1, 2, 3])
2399
2400        def foo(x):
2401            if x.is_floating_point():
2402                return torch.tensor(1)
2403            else:
2404                return torch.tensor(0)
2405
2406        self.assertEqual(vmap(foo)(float_tensor), torch.tensor([1, 1, 1]))
2407        self.assertEqual(vmap(foo)(long_tensor), torch.tensor([0, 0, 0]))
2408
2409    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
2410    def test_is_contiguous(self):
2411        def foo(x):
2412            if x.is_contiguous():
2413                return torch.tensor(1.0)
2414            else:
2415                return torch.tensor(0.0)
2416
2417        B0, B1 = 3, 5
2418
2419        # Single batch dim
2420        contig = torch.randn(B0, 2, 7)
2421        self.assertEqual(vmap(foo)(contig), torch.ones(B0))
2422
2423        noncontig = torch.randn(2, B0, 7)
2424        self.assertEqual(vmap(foo, in_dims=1)(noncontig), torch.zeros(B0))
2425
2426        noncontig = torch.randn(2, B0, 7).movedim(1, 0)
2427        self.assertEqual(vmap(foo)(noncontig), torch.zeros(B0))
2428
2429        noncontig = torch.randn(2, 7, B0)
2430        self.assertEqual(vmap(foo, in_dims=2)(noncontig), torch.zeros(B0))
2431
2432        # Multiple batch dims
2433        contig = torch.randn(B0, B1, 3)
2434        self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1))
2435
2436        contig = torch.randn(B1, B0, 3)
2437        self.assertEqual(vmap(vmap(foo), in_dims=1)(contig), torch.ones(B0, B1))
2438
2439        contig = torch.randn(B1, B0, 3).movedim(0, 1)
2440        self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1))
2441
2442        noncontig = torch.randn(B0, 3, B1)
2443        self.assertEqual(vmap(vmap(foo, in_dims=1))(noncontig), torch.zeros(B0, B1))
2444
2445        # is_contiguous on empty tensor is True
2446        def bar(x):
2447            assert x.is_contiguous()
2448            return x
2449
2450        vmap(bar)(torch.randn(B0, 0, 3))
2451        vmap(bar, in_dims=1)(torch.randn(0, B0, 3))
2452        vmap(bar)(torch.randn(B0, 0, 3).transpose(-1, -2))
2453
2454        # is_contiguous with other memory formats
2455        def baz(x, memory_format):
2456            x.is_contiguous(memory_format=memory_format)
2457            return x
2458
2459        msg = "NYI: querying is_contiguous inside of vmap for memory_format"
2460        tensor = torch.randn(B0, 2, 7, 3)
2461        with self.assertRaisesRegex(RuntimeError, msg):
2462            vmap(functools.partial(baz, memory_format=torch.channels_last))(tensor)
2463        with self.assertRaisesRegex(RuntimeError, msg):
2464            vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor)
2465
2466        for mf in (torch.channels_last, torch.channels_last_3d):
2467
2468            @torch.compile(backend="eager", fullgraph=True)
2469            def f(x):
2470                if x.is_contiguous(memory_format=mf):
2471                    return x.sin()
2472                return x.cos()
2473
2474            with self.assertRaisesRegex(RuntimeError, msg):
2475                vmap(f)(torch.randn(3, 3))
2476
2477    def test_unsqueeze(self):
2478        op = torch.unsqueeze
2479        test = self._vmap_view_test
2480        B0, B1 = 7, 11
2481
2482        # unsqueeze dim 0
2483        test(op, (torch.rand(B0, 2, 5), 0), in_dims=(0, None))
2484        test(op, (torch.rand(2, B0, 5), 0), in_dims=(1, None))
2485
2486        # unsqueeze last dim (positive)
2487        test(op, (torch.rand(B0, 2, 5), 2), in_dims=(0, None))
2488        test(op, (torch.rand(2, B0, 5), 2), in_dims=(1, None))
2489
2490        # unsqueeze last dim (negative)
2491        test(op, (torch.rand(B0, 2, 5), -1), in_dims=(0, None))
2492        test(op, (torch.rand(2, B0, 5), -1), in_dims=(1, None))
2493
2494        # nested vmaps
2495        def unsqueeze_0(x):
2496            return torch.unsqueeze(x, 0)
2497
2498        def unsqueeze_last(x):
2499            return torch.unsqueeze(x, -1)
2500
2501        # bdims in canonical order
2502        test(vmap(unsqueeze_0), (torch.rand(B0, B1, 2),))
2503        test(vmap(unsqueeze_last), (torch.rand(B0, B1, 2),))
2504
2505        # wild bdims
2506        test(vmap(unsqueeze_0), (torch.rand(B1, 2, B0),), in_dims=2)
2507        test(vmap(unsqueeze_0, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2)
2508        test(vmap(unsqueeze_last), (torch.rand(B1, 2, B0),), in_dims=2)
2509        test(vmap(unsqueeze_last, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2)
2510
2511    def test_movedim(self):
2512        op = torch.movedim
2513        test = self._vmap_view_test
2514        B0, B1, B2 = 7, 11, 13
2515
2516        # movedim(tensor, int, int) variant
2517        test(op, (torch.rand(B0, 2, 5), 0, 1), in_dims=(0, None, None))
2518        test(op, (torch.rand(2, B0, 5), 0, 1), in_dims=(1, None, None))
2519        test(
2520            vmap(op, in_dims=(0, None, None)),
2521            (torch.rand(B1, 2, B0, 5), 0, 1),
2522            in_dims=(2, None, None),
2523        )
2524        test(
2525            vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)),
2526            (torch.rand(B1, 2, B0, 5, B2), 0, 1),
2527            in_dims=(2, None, None),
2528        )
2529
2530        # movedim(tensor, intlist, intlist) variant
2531        test(op, (torch.rand(B0, 2, 3, 5), [1, 0], [0, 2]), in_dims=(0, None, None))
2532        test(op, (torch.rand(2, 3, B0, 5), [1, 0], [0, 2]), in_dims=(1, None, None))
2533        test(
2534            vmap(op, in_dims=(0, None, None)),
2535            (torch.rand(B1, 2, B0, 5), [0, 1], [1, 0]),
2536            in_dims=(2, None, None),
2537        )
2538        test(
2539            vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)),
2540            (torch.rand(B1, 2, B0, 5, B2), [0, 1], [1, 0]),
2541            in_dims=(2, None, None),
2542        )
2543
2544    def test_mm(self):
2545        op = torch.mm
2546        test = self._vmap_test
2547        B0, B1 = 7, 11
2548
2549        # shape mismatch
2550        msg = "Shape mismatch"
2551        with self.assertRaisesRegex(RuntimeError, msg):
2552            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
2553        with self.assertRaisesRegex(RuntimeError, msg):
2554            vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
2555        with self.assertRaisesRegex(RuntimeError, msg):
2556            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2))
2557
2558        # left arg is vmapped
2559        test(op, (torch.rand(B0, 2, 5), torch.rand(5, 2)), in_dims=(0, None))
2560        test(
2561            vmap(op, in_dims=(0, None)),
2562            (torch.rand(B1, B0, 2, 5), torch.rand(5, 2)),
2563            in_dims=(1, None),
2564        )
2565
2566        # right arg is vmapped
2567        test(op, (torch.rand(2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0))
2568        test(
2569            vmap(op, in_dims=(None, 0)),
2570            (torch.rand(2, 5), torch.rand(B1, B0, 5, 2)),
2571            in_dims=(None, 1),
2572        )
2573
2574        # both args are vmapped
2575        test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5, 2)))
2576        test(
2577            vmap(op),
2578            (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5, 2)),
2579            in_dims=(1, 0),
2580        )
2581        test(
2582            vmap(op, in_dims=(0, None)),
2583            (torch.rand(B1, 2, 5), torch.rand(B0, 5, 2)),
2584            in_dims=(None, 0),
2585        )
2586
2587    def test_mv(self):
2588        op = torch.mv
2589        test = self._vmap_test
2590        B0, B1 = 7, 11
2591
2592        # shape mismatch
2593        msg = ""
2594        with self.assertRaisesRegex(RuntimeError, msg):
2595            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
2596        with self.assertRaisesRegex(RuntimeError, msg):
2597            vmap(op, in_dims=(0, None))(torch.randn(B0, 2, 2), torch.randn(2, 2))
2598        with self.assertRaisesRegex(RuntimeError, msg):
2599            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2))
2600
2601        # left arg is vmapped
2602        test(op, (torch.rand(B0, 2, 5), torch.rand(5)), in_dims=(0, None))
2603        test(
2604            vmap(op, in_dims=(0, None)),
2605            (torch.rand(B1, B0, 2, 5), torch.rand(5)),
2606            in_dims=(1, None),
2607        )
2608
2609        # right arg is vmapped
2610        test(op, (torch.rand(2, 5), torch.rand(B0, 5)), in_dims=(None, 0))
2611        test(
2612            vmap(op, in_dims=(None, 0)),
2613            (torch.rand(2, 5), torch.rand(B1, B0, 5)),
2614            in_dims=(None, 1),
2615        )
2616
2617        # both args are vmapped
2618        test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5)))
2619        test(
2620            vmap(op), (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0)
2621        )
2622        test(
2623            vmap(op, in_dims=(0, None)),
2624            (torch.rand(B1, 2, 5), torch.rand(B0, 5)),
2625            in_dims=(None, 0),
2626        )
2627
2628    def test_narrow(self):
2629        op = torch.narrow
2630        test = self._vmap_view_test
2631        B0, B1, B2 = 7, 11, 13
2632
2633        test(op, (torch.rand(B0, 2, 5), -1, 1, 3), in_dims=(0, None, None, None))
2634        test(op, (torch.rand(2, B0, 5), 1, 1, 3), in_dims=(1, None, None, None))
2635        test(
2636            vmap(op, in_dims=(0, None, None, None)),
2637            (torch.rand(B1, 2, B0, 5), 1, 0, 0),
2638            in_dims=(2, None, None, None),
2639        )
2640        test(
2641            vmap(
2642                vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)
2643            ),
2644            (torch.rand(B1, 2, B0, 5, B2), -1, 2, 3),
2645            in_dims=(2, None, None, None),
2646        )
2647
2648    def test_new_empty(self):
2649        # new_empty returns uninitialized memory, so we just check that the shape of
2650        # the output tensor is what we expect and that the vmap fallback isn't used.
2651        op = Tensor.new_empty
2652
2653        B0, B1 = 7, 11
2654
2655        result = vmap(lambda x: op(x, [2, 3]))(torch.randn(B0))
2656        self.assertEqual(result.shape, [B0, 2, 3])
2657
2658        result = vmap(lambda x: op(x, []))(torch.randn(B0))
2659        self.assertEqual(result.shape, [B0])
2660
2661        result = vmap(vmap(lambda x: op(x, [2, 3])))(torch.randn(B0, B1))
2662        self.assertEqual(result.shape, [B0, B1, 2, 3])
2663
2664    def test_new_empty_strided(self):
2665        # new_empty_strided returns uninitialized memory, so we just check that the
2666        # shape and strides of the output are what we expect and that the vmap fallback isn't used
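        # Rough expectation (as encoded by the checks below): the batched
        # new_empty_strided packs the per-example buffers back to back, so with
        #   S = torch.empty_strided(size, stride).storage().size()
        # the singly-vmapped result should have shape [B0] + size and stride
        # [S] + stride, and the doubly-vmapped result stride [B1 * S, S] + stride.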
2667        B0, B1 = 7, 11
2668
2669        def _test_single_vmap(size, stride, B0):
2670            x = torch.randn(B0)
2671            result = vmap(lambda x: x.new_empty_strided(size, stride))(x)
2672            S = torch.empty_strided(size, stride).storage().size()
2673            self.assertEqual(result.shape, [B0] + size)
2674            self.assertEqual(result.stride(), [S] + stride)
2675
2676        def _test_double_vmap(size, stride, B0, B1):
2677            x = torch.randn(B0, B1)
2678            result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)))(x)
2679            S = torch.empty_strided(size, stride).storage().size()
2680            self.assertEqual(result.shape, [B0, B1] + size)
2681            self.assertEqual(result.stride(), [B1 * S, S] + stride)
2682
2683            x = torch.randn(B1, B0)
2684            result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)), in_dims=1)(
2685                x
2686            )
2687            S = x.new_empty_strided(size, stride).storage().size()
2688            self.assertEqual(result.shape, [B0, B1] + size)
2689            self.assertEqual(result.stride(), [B1 * S, S] + stride)
2690
2691        # contiguous case
2692        _test_single_vmap([2, 3, 5], [3 * 5, 5, 1], B0)
2693        _test_double_vmap([2, 3, 5], [3 * 5, 5, 1], B0, B1)
2694
2695        # expanded
2696        _test_single_vmap([2, 3, 5], [0, 5, 1], B0)
2697        _test_double_vmap([2, 3, 5], [0, 5, 1], B0, B1)
2698
2699        # Some of these cases are pretty strange; we're just verifying that if
2700        # empty_strided allows them, then BatchedTensor.new_empty_strided can
2701        # create them as well
2702        for shape in [[2, 3, 4], [0, 2, 0]]:
2703            for strides in [[12, 4, 1], [2, 4, 6], [0, 0, 0]]:
2704                _test_single_vmap(shape, strides, B0)
2705                _test_double_vmap(shape, strides, B0, B1)
2706
2707    def test_new_zeros(self):
2708        op = Tensor.new_zeros
2709        test = functools.partial(self._vmap_test, check_propagates_grad=False)
2710        B0, B1 = 7, 11
2711
2712        test(lambda x: op(x, 2, 3), (torch.rand(B0),))
2713        test(lambda x: op(x, []), (torch.rand(B0),))
2714        test(vmap(lambda x: op(x, 3, 5)), (torch.rand(B0, B1),))
2715
2716    def test_select(self):
2717        op = torch.select
2718        test = self._vmap_view_test
2719        B0, B1, B2 = 7, 11, 13
2720        test(op, (torch.rand(B0, 2, 5), 0, 0), in_dims=(0, None, None))
2721        test(op, (torch.rand(2, B0, 5), 1, 1), in_dims=(1, None, None))
2722        test(vmap(lambda t: op(t, 1, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
2723        test(
2724            vmap(vmap(lambda t: op(t, 1, 1), in_dims=1)),
2725            (torch.rand(B1, 2, B0, B2, 5),),
2726            in_dims=2,
2727        )
2728
2729    def test_roll_no_dims(self):
2730        op = torch.roll
2731        test = self._vmap_test
2732        B0, B1, B2 = 7, 11, 13
2733        test(op, (torch.rand(B0, 2, 5), 2), in_dims=(0, None))
2734        test(op, (torch.rand(2, B0, 5), 3), in_dims=(1, None))
2735        test(vmap(lambda t: op(t, 3)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
2736        test(
2737            vmap(vmap(lambda t: op(t, 3), in_dims=1)),
2738            (torch.rand(B1, 2, B0, B2, 5),),
2739            in_dims=2,
2740        )
2741
2742    def test_stack(self):
2743        test = self._vmap_test
2744        B0, B1 = 5, 7
2745
2746        # Quick hack b/c vmap can't accept a list of tensors as an argument
2747        def get_op(dim):
2748            def op(*tensors):
2749                return torch.stack(tensors, dim=dim)
2750
2751            return op
2752
2753        test(get_op(0), (torch.rand(B0, 3), torch.rand(B0, 3)))
2754        test(get_op(0), (torch.rand(3), torch.rand(B0, 3)), in_dims=(None, 0))
2755        test(get_op(0), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
2756        test(get_op(-1), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
2757        test(
2758            vmap(get_op(0), in_dims=(0, None)),
2759            (torch.rand(B1, 2), torch.rand(B0, 2)),
2760            in_dims=(None, 0),
2761        )
2762        test(
2763            vmap(get_op(0), in_dims=(0, 0)),
2764            (torch.rand(B1, 2), torch.rand(B0, B1, 2)),
2765            in_dims=(None, 0),
2766        )
2767
2768    def test_slice(self):
2769        test = self._vmap_view_test
2770        B0, B1, B2 = 7, 11, 13
2771        test(lambda t: t[0:1], (torch.rand(B0, 3, 5),))
2772        test(lambda t: t[:, 1:3], (torch.rand(3, 5, B0),), in_dims=2)
2773        test(
2774            vmap(lambda t: t[:, 0:1], in_dims=2), (torch.rand(3, 5, B0, B1),), in_dims=2
2775        )
2776        test(
2777            vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2),
2778            (torch.rand(3, 5, B0, B1, B2),),
2779            in_dims=2,
2780        )
2781
2782    @xfailIfTorchDynamo
2783    def test_squeeze(self):
2784        def verify_behavior(op, min_ndim=1):
2785            test = self._vmap_view_test
2786            B0, B1 = 1, 11
2787            # These tests cannot be used with an operator that requires more
2788            # than 1 dimension after batching.
2789            if min_ndim <= 1:
2790                test(op, (torch.rand(B0),))
2791                test(op, (torch.rand(B1),))
2792                test(vmap(op), (torch.rand(B0, B1, 1),))
2793                test(vmap(op), (torch.rand(B1, 1, B0),), in_dims=2)
2794            test(op, (torch.rand(B0, 3, 5),))
2795            test(op, (torch.rand(1, B0, 5),), in_dims=1)
2796            test(op, (torch.rand(B0, 0, 1, 5, 1),))
2797            test(op, (torch.rand(B0, 1, 1, 1, 1),))
2798            test(vmap(op), (torch.rand(B0, B1, 1, 3, 4),))
2799            test(vmap(op), (torch.rand(B1, 1, B0, 4, 5),), in_dims=2)
2800
2801        verify_behavior(torch.squeeze)
2802        verify_behavior(lambda x: torch.squeeze(x, dim=0), min_ndim=1)
2803        verify_behavior(lambda x: torch.squeeze(x, dim=1), min_ndim=2)
2804        verify_behavior(lambda x: torch.squeeze(x, dim=-1), min_ndim=2)
2805        verify_behavior(lambda x: torch.squeeze(x, dim=-2), min_ndim=3)
2806
2807        msg = ""
2808        try:
2809            torch.squeeze(torch.rand(10), dim=1)
2810        except IndexError as err:
2811            msg = str(err)
2812        with self.assertRaises(RuntimeError, msg=msg):
2813            vmap(lambda x: torch.squeeze(x, dim=1))(torch.rand(10))
2814
2815    def _test_mean_sum_dim(self, op):
2816        test = self._vmap_test
2817        B0, B1 = 5, 7
2818
2819        # Single vmap, various in_dims / out_dims
2820        test(lambda x: op(x, 0), [torch.randn([B0])])
2821        test(lambda x: op(x, -1), [torch.randn([B0])])
2822        test(lambda x: op(x, 0), [torch.randn([B0, 3])])
2823        test(lambda x: op(x, -1), [torch.randn([2, 5, B0, 3])], in_dims=2)
2824        test(lambda x: op(x, 2), [torch.randn([2, 5, B0, 3])], in_dims=2, out_dims=2)
2825
2826        # Doubly nested vmap
2827        test(vmap(lambda x: op(x, 0)), [torch.randn([B0, B1])])
2828        test(vmap(lambda x: op(x, -1)), [torch.randn([B0, B1])])
2829        test(vmap(lambda x: op(x, -2)), [torch.randn([B1, 2, 5, B0, 3])], in_dims=2)
2830        test(
2831            vmap(lambda x: op(x, 2), in_dims=2),
2832            [torch.randn([2, 5, B0, B1, 3])],
2833            in_dims=2,
2834            out_dims=2,
2835        )
2836
2837    def test_sum_dim(self):
2838        self._test_mean_sum_dim(torch.sum)
2839
2840    def test_mean_dim(self):
2841        self._test_mean_sum_dim(torch.mean)
2842
2843    def test_argmax_dim(self):
2844        def test(f, args):
2845            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(f, args, {}):
2846                self.assertEqual(loop_out, batched_out)
2847
2848        B0 = 5
2849        test(lambda x: torch.argmax(x), [torch.randn(B0)])
2850        test(lambda x: torch.argmax(x), [torch.randn(B0, 2, 3)])
2851        test(lambda x: torch.argmax(x, 0), [torch.randn(B0, 2, 3)])
2852        test(lambda x: torch.argmax(x, -1), [torch.randn(B0, 2, 3)])
2853        test(lambda x: torch.argmax(x, 2), [torch.randn(B0, 2, 3)])
2854
2855    def _test_sum_mean(self, op):
2856        test = self._vmap_test
2857        B0, B1 = 5, 7
2858
2859        # Single vmap, various in_dims
2860        test(op, [torch.randn([B0])])
2861        test(op, [torch.randn([B0, 3])])
2862        test(op, [torch.randn([2, 5, B0, 3])], in_dims=2)
2863        test(op, [torch.randn([2, 5, B0, 3])], in_dims=2)
2864
2865        # Doubly nested vmap
2866        test(vmap(op), [torch.randn([B0, B1])])
2867        test(vmap(op), [torch.randn([B1, 2, 5, B0, 3])])
2868        test(vmap(op), [torch.randn([2, 5, B0, B1, 3])], in_dims=2)
2869
2870    def test_sum(self):
2871        self._test_sum_mean(torch.sum)
2872
2873    def test_mean(self):
2874        self._test_sum_mean(torch.mean)
2875
2876    def test_repeat(self):
2877        test = self._vmap_test
2878        B0 = 7
2879        op = Tensor.repeat
2880        test(lambda x: op(x, (2, 3)), (torch.rand(B0, 1, 1),))
2881        test(lambda x: op(x, (2, 3)), (torch.rand(1, B0, 1),), in_dims=1)
2882
2883    @skipIfTorchDynamo()
2884    def test_slogdet(self):
2885        test = functools.partial(self._vmap_test, check_propagates_grad=False)
2886        B0 = 7
2887        op = torch.linalg.slogdet
2888        test(op, (torch.rand(B0, 1, 1),))
2889        test(op, (torch.rand(B0, 2, 2),))
2890        test(op, (torch.rand(B0, 3, 2, 2),))
2891        test(op, (torch.rand(3, 2, 2, B0),), in_dims=3)
2892
2893    def test_reshape(self):
2894        test = self._vmap_test
2895        B0, B1, B2 = 7, 11, 13
2896        op = torch.reshape
2897        test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None), check_view=True)
2898        test(
2899            op, (torch.rand(2, B0, 5), [1, 1, 10]), in_dims=(1, None), check_view=False
2900        )
2901        test(
2902            vmap(lambda t: t.reshape([-1])),
2903            (torch.rand(B0, B1, 2, 5),),
2904            check_view=True,
2905        )
2906        test(
2907            vmap(vmap(lambda t: t.reshape([-1]), in_dims=2), in_dims=1),
2908            (torch.rand(3, B1, 2, B2, 5, B0),),
2909            in_dims=5,
2910            check_view=False,
2911        )
2912
2913    def test_reshape_as(self):
2914        test = self._vmap_test
2915        B0, B1, B2 = 7, 11, 13
2916        op = torch.Tensor.reshape_as
2917        test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)), check_view=True)
2918        test(
2919            op,
2920            (torch.rand(2 * 5), torch.rand(B0, 2, 5)),
2921            in_dims=(None, 0),
2922            check_view=True,
2923        )
2924        test(
2925            op,
2926            (torch.rand(B0, 2 * 5), torch.rand(2, 5)),
2927            in_dims=(0, None),
2928            check_view=True,
2929        )
2930
2931        test(
2932            op,
2933            (torch.rand(2, B0, 5), torch.rand(1, 1, 10)),
2934            in_dims=(1, None),
2935            check_view=False,
2936        )
2937
2938        test(
2939            vmap(op),
2940            (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)),
2941            check_view=True,
2942        )
2943        test(
2944            vmap(vmap(op, in_dims=(2, None)), in_dims=(1, None)),
2945            (torch.rand(3, B1, 2, B2, 5, B0), torch.rand(B0, 3 * 2 * 5)),
2946            in_dims=(5, 0),
2947            check_view=False,
2948        )
2949
2950    def test_result_type(self):
2951        def scalar_tensor_with_dtype(op):
2952            def wrapped(*args, **kwargs):
2953                dtype = op(*args, **kwargs)
2954                return torch.ones([], dtype=dtype)
2955
2956            return wrapped
2957
2958        test = self._vmap_test
2959        op = scalar_tensor_with_dtype(torch.result_type)
2960
2961        B0 = 2
2962
2963        test(
2964            op,
2965            (torch.randn(B0), torch.randn(B0, dtype=torch.float64)),
2966            check_propagates_grad=False,
2967        )
2968        test(
2969            op,
2970            (torch.randn(B0), torch.randint(10, [B0], dtype=torch.int64)),
2971            check_propagates_grad=False,
2972        )
2973
2974        test(lambda x: op(x, 1), (torch.randn(B0),), check_propagates_grad=False)
2975        test(lambda x: op(x, 1.6), (torch.randn(B0),), check_propagates_grad=False)
2976
2977        test(
2978            lambda x: op(x, torch.tensor(1)),
2979            (torch.randn(B0),),
2980            check_propagates_grad=False,
2981        )
2982        test(
2983            lambda x: op(x, torch.tensor(1.6, dtype=torch.double)),
2984            (torch.randn(B0),),
2985            check_propagates_grad=False,
2986        )
2987
2988        test(
2989            op,
2990            (torch.randn(B0, 2), torch.randn(B0, 2, dtype=torch.float64)),
2991            check_propagates_grad=False,
2992        )
2993        test(
2994            op,
2995            (torch.randn(B0, 2), torch.randint(10, [B0, 2], dtype=torch.int64)),
2996            check_propagates_grad=False,
2997        )
2998
2999        test(lambda x: op(x, 1), (torch.randn(B0, 2),), check_propagates_grad=False)
3000        test(lambda x: op(x, 1.6), (torch.randn(B0, 2),), check_propagates_grad=False)
3001
3002        test(
3003            lambda x: op(x, torch.tensor(1)),
3004            (torch.randn(B0, 2),),
3005            check_propagates_grad=False,
3006        )
3007        test(
3008            lambda x: op(x, torch.tensor(1.6, dtype=torch.double)),
3009            (torch.randn(B0, 2),),
3010            check_propagates_grad=False,
3011        )
3012
3013        test(
3014            op,
3015            (torch.randn(B0, 2), torch.randn(B0, dtype=torch.float64)),
3016            check_propagates_grad=False,
3017        )
3018        test(
3019            op,
3020            (torch.randn(B0, 2), torch.randint(10, [B0], dtype=torch.int64)),
3021            check_propagates_grad=False,
3022        )
3023
3024    def test_tensor_split(self):
3025        test = self._vmap_view_test
3026        op = torch.tensor_split
3027        B0, B1, B2 = 7, 11, 13
3028
3029        # tests for torch.tensor_split(self, indices_or_sections: int, dim)
3030        test(op, (torch.rand(B0, 2, 1024), 5, -1), in_dims=(0, None, None))
3031        test(op, (torch.rand(2, B0, 1024), 150, 1), in_dims=(1, None, None))
3032        test(
3033            vmap(op, in_dims=(0, None, None)),
3034            (torch.rand(B1, 1023, B0, 5), 256, 0),
3035            in_dims=(2, None, None),
3036        )
3037        test(
3038            vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
3039            (torch.rand(B1, 2, B0, 64, B2),),
3040            in_dims=2,
3041        )
3042
3043        # tests for torch.tensor_split(self, indices_or_sections: List[int], dim)
3044        test(
3045            op,
3046            (torch.rand(B0, 2, 1024), [50, 100, 378, 890], -1),
3047            in_dims=(0, None, None),
3048        )
3049        test(
3050            op,
3051            (torch.rand(2, B0, 1024), [50, 100, 212, 345, 0, 378, 890], 1),
3052            in_dims=(1, None, None),
3053        )
3054        test(
3055            vmap(op, in_dims=(0, None, None)),
3056            (torch.rand(B1, 1023, B0, 5), [50, 100, 212, 345, 0, 378, 890], 0),
3057            in_dims=(2, None, None),
3058        )
3059        test(
3060            vmap(vmap(lambda t: op(t, [4, 8, 9, 34, 29], 1), in_dims=2)),
3061            (torch.rand(B1, 2, B0, 64, B2),),
3062            in_dims=2,
3063        )
3064
3065    @skipIfTorchDynamo("really slow")
3066    def test_split(self):
3067        test = self._vmap_view_test
3068        op = torch.split
3069        B0, B1, B2 = 7, 11, 13
3070
3071        # tests for torch.split(self, split_size: int, dim)
3072        test(op, (torch.rand(B0, 2, 1024), 101, -1), in_dims=(0, None, None))
3073        test(op, (torch.rand(2, B0, 1024), 130, 1), in_dims=(1, None, None))
3074        test(
3075            vmap(op, in_dims=(0, None, None)),
3076            (torch.rand(B1, 1023, B0, 5), 256, 0),
3077            in_dims=(2, None, None),
3078        )
3079        test(
3080            vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
3081            (torch.rand(B1, 2, B0, 64, B2),),
3082            in_dims=2,
3083        )
3084
3085        # tests for torch.split(self, split_size: List[int], dim)
3086        test(op, (torch.rand(B0, 2, 1024), [1, 1020, 3], -1), in_dims=(0, None, None))
3087        test(
3088            op, (torch.rand(2, B0, 1024), [100] * 10 + [24], 1), in_dims=(1, None, None)
3089        )
3090        test(
3091            vmap(op, in_dims=(0, None, None)),
3092            (torch.rand(B1, 1023, B0, 5), [256] * 3 + [255], 0),
3093            in_dims=(2, None, None),
3094        )
3095        test(
3096            vmap(vmap(lambda t: op(t, [4] * 8 + [8] * 4, 1), in_dims=2)),
3097            (torch.rand(B1, 2, B0, 64, B2),),
3098            in_dims=2,
3099        )
3100
3101    def test_trace(self):
3102        op = torch.trace
3103        test = self._vmap_test
3104        B0, B1, B2 = 7, 11, 13
3105        test(op, (torch.rand(B0, 2, 5),))
3106        test(op, (torch.rand(2, B0, 5),), in_dims=1)
3107        test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
3108        test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)
3109
3110    def test_transpose(self):
3111        op = torch.transpose
3112        test = self._vmap_view_test
3113
3114        B0, B1, B2 = 7, 11, 13
3115        test(lambda x: op(x, 0, 1), (torch.rand(B0, 2, 5),))
3116        test(lambda x: op(x, -1, -2), (torch.rand(B0, 2, 5),))
3117        test(lambda x: op(x, 3, 1), (torch.rand(B0, 2, 5, 4, 6),))
3118        test(lambda x: op(x, 1, 0), (torch.rand(2, B0, 5),), in_dims=1)
3119        test(vmap(lambda x: op(x, 0, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
3120        test(
3121            vmap(vmap(lambda x: op(x, 0, 1), in_dims=2)),
3122            (torch.rand(B1, 2, B0, 5, B2),),
3123            in_dims=2,
3124        )
3125
3126        # Special case: scalar tensor
3127        for dim1, dim2 in itertools.product([0, -1], [0, -1]):
3128            x = torch.rand(B0)
3129            result = vmap(lambda x: op(x, dim1, dim2))(x)
3130            self.assertTrue(result is x)
3131
3132    def test_t(self):
3133        op = torch.t
3134        test = self._vmap_view_test
3135        B0, B1, B2 = 7, 11, 13
3136        test(op, (torch.rand(B0, 2, 5),))
3137        test(op, (torch.rand(2, B0, 5),), in_dims=1)
3138        test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
3139        test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)
3140
3141    def test_T_numpy(self):
3142        def op(t):
3143            return t.T
3144
3145        test = self._vmap_view_test
3146        B0, B1, B2 = 7, 11, 13
3147        test(op, (torch.rand(B0, 2, 3, 5),))
3148        test(op, (torch.rand(2, B0, 3, 5),), in_dims=1)
3149        test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
3150        test(vmap(op), (torch.rand(B1, 2, B0, 3, 5),), in_dims=2)
3151        test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 3, B2, 5),), in_dims=2)
3152
3153    def test_to(self):
3154        test = self._vmap_test
3155        B0, B1 = 7, 11
3156
3157        test(lambda t: t.to("cpu"), (torch.rand(B0),))
3158        test(lambda t: t.to(torch.double), (torch.rand(B0),))
3159        test(
3160            lambda t, o: t.to(o), (torch.rand(B0), torch.randn(B0, dtype=torch.float64))
3161        )
3162        test(
3163            lambda t, o: t.to(o),
3164            (torch.rand(B0), torch.randn(B0, dtype=torch.float64)),
3165            in_dims=(0, None),
3166        )
3167        test(vmap(lambda t: t.to(torch.double)), (torch.rand(B0, B1, 3),))
3168
3169        # also test some casting methods
3170        test(lambda t: t.double(), (torch.rand(B0),))
3171        test(lambda t: t.float(), (torch.rand(B0),))
3172        test(lambda t: t.int(), (torch.rand(B0),), check_propagates_grad=False)
3173        test(lambda t: t.long(), (torch.rand(B0),), check_propagates_grad=False)
3174
3175    def test_unfold(self):
3176        op = torch.Tensor.unfold
3177        test = self._vmap_view_test
3178        B0, B1, B2 = 3, 2, 5
3179
3180        test(op, (torch.rand(B0, 7, 11), 0, 2, 1), in_dims=(0, None, None, None))
3181        test(op, (torch.rand(7, B0, 11), 1, 4, 2), in_dims=(1, None, None, None))
3182        test(
3183            vmap(op, in_dims=(0, None, None, None)),
3184            (torch.rand(B1, 7, B0, 11), 1, 5, 1),
3185            in_dims=(2, None, None, None),
3186        )
3187        test(
3188            vmap(
3189                vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)
3190            ),
3191            (torch.rand(B1, 7, B0, 11, B2), -1, 2, 4),
3192            in_dims=(2, None, None, None),
3193        )
3194
3195    def test_unbind(self):
3196        test = self._vmap_view_test
3197        op = torch.unbind
3198        B0, B1, B2 = 7, 11, 13
3199
3200        test(op, (torch.rand(B0, 2, 1024), -1), in_dims=(0, None))
3201        test(op, (torch.rand(B0, 2, 0),))
3202        test(op, (torch.rand(2, B0, 7), 0), in_dims=(1, None))
3203        test(
3204            vmap(op, in_dims=(0, None)),
3205            (torch.rand(B1, 1023, B0, 5), 1),
3206            in_dims=(2, None),
3207        )
3208        test(
3209            vmap(vmap(lambda t: op(t, dim=1), in_dims=2)),
3210            (torch.rand(B1, 2, B0, 32, B2),),
3211            in_dims=2,
3212        )
3213
3214    def test_view(self):
3215        test = self._vmap_view_test
3216        B0, B1, B2 = 7, 11, 13
3217        op = torch.Tensor.view
3218
3219        # We should error out if the view would produce an incorrect result
3220        with self.assertRaises(RuntimeError):
3221            vmap(op, in_dims=(1, None))(torch.rand(2, B0, 5), [10])
3222
3223        test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None))
3224        test(op, (torch.rand(B0, 4, 5), [1, 2, 1, 10]), in_dims=(0, None))
3225        test(vmap(lambda t: t.view([-1])), (torch.rand(B0, B1, 2, 5, 3),))
3226        test(
3227            vmap(vmap(lambda t: t.reshape([-1])), in_dims=1),
3228            (torch.rand(B2, B0, B1, 3, 2, 5),),
3229            in_dims=1,
3230        )
3231
3232    def test_view_as(self):
3233        test = self._vmap_view_test
3234        B0, B1, B2 = 7, 11, 13
3235        op = torch.Tensor.view_as
3236
3237        # We should error out if the view would produce an incorrect result
3238        with self.assertRaises(RuntimeError):
3239            vmap(op, in_dims=(1, 0))(torch.rand(2, B0, 5), torch.rand(B0, 10))
3240
3241        test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)))
3242        test(op, (torch.rand(2 * 5), torch.rand(B0, 2, 5)), in_dims=(None, 0))
3243        test(op, (torch.rand(B0, 2 * 5), torch.rand(2, 5)), in_dims=(0, None))
3244
3245        test(op, (torch.rand(B0, 4, 5), torch.rand(2, 1, 1, 10)), in_dims=(0, None))
3246
3247        test(vmap(op), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)))
3248        test(
3249            vmap(vmap(op, in_dims=(0, None)), in_dims=(0, None)),
3250            (torch.rand(B1, B2, B0, 3, 2, 5), torch.rand(B0, 3 * 2 * 5)),
3251            in_dims=(2, 0),
3252        )
3253
3254    def test_conv2d(self):
3255        conv_setups = [
3256            (torch.nn.Conv1d, torch.conv1d, [2, 4, 15]),
3257            (torch.nn.Conv2d, torch.conv2d, [2, 4, 15, 20]),
3258            (torch.nn.Conv3d, torch.conv3d, [2, 4, 15, 20, 25]),
3259            # (torch.nn.ConvTranspose2d, torch.conv_transpose2d, [2, 4, 15, 20])
3260        ]
3261        for conv_mod, conv_fn, inp_shape in conv_setups:
3262            mod = conv_mod(4, 8, kernel_size=3)
3263            arg_values = [torch.randn(inp_shape), mod.weight, mod.bias]
3264            kwarg_values = {}
3265            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
3266                conv_fn, arg_values, kwarg_values
3267            ):
3268                self.assertEqual(loop_out, batched_out)
3269
3270            arg_values = [torch.randn(inp_shape), mod.weight, None]
3271            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
3272                conv_fn, arg_values, kwarg_values
3273            ):
3274                self.assertEqual(loop_out, batched_out)
3275
3276            mod2 = conv_mod(
3277                4, 8, kernel_size=3, groups=2, stride=3, padding=1, dilation=2
3278            )
3279            arg_values = [torch.randn(inp_shape), mod2.weight, mod2.bias]
3280            kwarg_values = dict(groups=2, stride=3, padding=1, dilation=2)
3281            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
3282                conv_fn, arg_values, kwarg_values
3283            ):
3284                self.assertEqual(loop_out, batched_out)
3285
3286            arg_values = [torch.randn(inp_shape), mod2.weight, None]
3287            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
3288                conv_fn, arg_values, kwarg_values
3289            ):
3290                self.assertEqual(loop_out, batched_out)
3291
3292    def test_one_hot(self):
3293        sample_inputs = [
3294            (torch.randint(0, 3, []), 3),
3295            (torch.randint(0, 3, [2, 3, 4]), 4),
3296        ]
3297        for args in sample_inputs:
3298            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
3299                F.one_hot, args, {}
3300            ):
3301                self.assertEqual(loop_out, batched_out)
3302
3303    def test_conj_bit(self):
3304        x = torch.tensor([1 + 1j, 2 + 1j])
3305
3306        def foo(x):
3307            assert not x.is_conj()
3308            y = x.conj()
3309            assert y.is_conj()
3310            return y
3311
3312        res = vmap(foo)(x)
3313        self.assertEqual(res, x.conj())
3314
3315    def test_mode_key(self):
3316        def vmap_f(x):
3317            return x + torch.randn(())
3318
3319        def naive_f(x, shape):
3320            return x + torch.randn(shape)
3321
3322        torch.manual_seed(0)
3323        out1 = vmap(vmap(vmap_f, randomness="different"), randomness="different")(
3324            torch.ones(2, 3)
3325        )
3326
3327        torch.manual_seed(0)
3328        out2 = naive_f(torch.ones(2, 3), (2, 3))
3329        self.assertEqual(out1, out2)
3330
3331        torch.manual_seed(0)
3332        out1 = vmap(vmap(vmap_f, randomness="different"), randomness="different")(
3333            torch.ones(2, 3, 4)
3334        )
3335
3336        torch.manual_seed(0)
3337        out2 = naive_f(torch.ones(2, 3, 4), (2, 3, 1))
3338        self.assertEqual(out1, out2)
3339
3340        self.assertTrue(torch.randn(()).dim() == 0)
3341
3342    @parametrize("in_dim", [0, 1, 2])
3343    @parametrize("out_dim", [0, 1, 2])
3344    @parametrize("randomness", ["error", "same"])
3345    def test_chunk_vmap(self, in_dim, out_dim, randomness):
3346        x = torch.randn(4, 5, 6)
3347
3348        def f(x):
3349            y = x.sin()
3350            if randomness != "error":
3351                y = y + torch.rand_like(x)
3352            return y
3353
3354        rs = torch.get_rng_state()
3355        expected = vmap(f, in_dims=in_dim, out_dims=out_dim, randomness=randomness)(x)
3356
3357        for chunks in [1, 2, 3, 4, 7, 10, 16]:
3358            torch.set_rng_state(rs)
3359            output = chunk_vmap(
3360                f,
3361                in_dims=in_dim,
3362                out_dims=out_dim,
3363                randomness=randomness,
3364                chunks=chunks,
3365            )(x)
3366            self.assertEqual(output, expected)
3367
3368    @parametrize("in_dim", [0, 1, 2])
3369    @parametrize("out_dim", [0, 1, 2])
3370    @parametrize("randomness", ["error", "same"])
3371    def test_vmap_chunksize(self, in_dim, out_dim, randomness):
3372        x = torch.randn(4, 5, 6)
3373        y = torch.randn_like(x)
3374
3375        # fn: Single Input/Single Output
3376        def f(x):
3377            y = x.sin()
3378            if randomness != "error":
3379                y = y + torch.rand_like(x)
3380            return y
3381
3382        f_args = (x,)
3383        f_kwargs = {"in_dims": in_dim, "out_dims": out_dim, "randomness": randomness}
3384
3385        # fn: Nested Input/Single Output
3386        def f1(pair):
3387            x, y = pair
3388            z = x.sin() + y.cos()
3389            if randomness != "error":
3390                z = z + torch.rand_like(z)
3391            return z
3392
3393        f1_args = ((x, y),)
3394        f1_kwargs = {
3395            "in_dims": ((in_dim,) * 2,),
3396            "out_dims": out_dim,
3397            "randomness": randomness,
3398        }
3399
3400        # fn: Single Input/Nested Output
3401        def f2(x):
3402            y = x.sin()
3403            if randomness != "error":
3404                y = y + torch.rand_like(x)
3405            return {"out": y, "out1": y + 2}
3406
3407        f2_args = (x,)
3408        f2_kwargs = {"in_dims": in_dim, "out_dims": out_dim, "randomness": randomness}
3409
3410        # fn: Nested Input/Nested Output (first tensor is not vmapped).
3411        def f3(inp_dict):
3412            x = inp_dict["inp"]
3413            y = inp_dict["inp1"]
3414            z = x.sin() + y.cos()
3415            if randomness != "error":
3416                z = z + torch.rand_like(z)
3417            return {"z": z, "tuple": (z, z + 1)}
3418
3419        f3_args = (
3420            {
3421                "inp": x.index_select(in_dim, torch.tensor([0])).squeeze(in_dim),
3422                "inp1": y,
3423            },
3424        )
3425        f3_kwargs = {
3426            "in_dims": ({"inp": None, "inp1": in_dim},),
3427            "out_dims": out_dim,
3428            "randomness": randomness,
3429        }
3430
3431        # fn: Nested Input/Nested Output (first argument is not a Tensor).
3432        def f4(inp_dict):
3433            x = inp_dict["inp"]
3434            y = inp_dict["inp1"]
3435            z = x + y.cos()
3436            if randomness != "error":
3437                z = z + torch.rand_like(z)
3438            return {"z": z, "tuple": (z, z + 1)}
3439
3440        f4_args = ({"inp": 2.0, "inp1": y},)
3441        f4_kwargs = {
3442            "in_dims": ({"inp": None, "inp1": in_dim},),
3443            "out_dims": out_dim,
3444            "randomness": randomness,
3445        }
3446
3447        fns_and_args = (
3448            (f, f_args, f_kwargs),
3449            (f1, f1_args, f1_kwargs),
3450            (f2, f2_args, f2_kwargs),
3451            (f3, f3_args, f3_kwargs),
3452            (f4, f4_args, f4_kwargs),
3453        )
3454        for fn, args, kwargs in fns_and_args:
3455            rs = torch.get_rng_state()
3456            expected_vmap = vmap(fn, **kwargs)(*args)
3457            for chunk_size in (1, 2, 3, 4, 7, 10, 16, 100):
3458                torch.set_rng_state(rs)
3459                output = vmap(fn, chunk_size=chunk_size, **kwargs)(*args)
3460                self.assertEqual(output, expected_vmap)
3461
3462    @parametrize("in_dim", [0, 1])
3463    @parametrize("out_dim", [0, 1])
3464    @parametrize("randomness", ["error", "same"])
3465    def test_vmap_chunksize_error(self, in_dim, out_dim, randomness):
3466        x = torch.randn(4, 5, 6)
3467
3468        def f(x):
3469            y = x.sin()
3470            if randomness != "error":
3471                y = y + torch.rand_like(x)
3472            return y
3473
3474        # Incorrect `chunk_size`
3475        for chunk_size in (-1, 0):
3476            with self.assertRaisesRegex(
3477                ValueError, "vmap: chunk_size should be None or greater than 0."
3478            ):
3479                vmap(
3480                    f,
3481                    in_dims=in_dim,
3482                    out_dims=out_dim,
3483                    randomness=randomness,
3484                    chunk_size=chunk_size,
3485                )(x)
3486
3487        # Incorrect `out_dims`
3488        msg = "out_dims is not compatible with the structure of `outputs`"
3489        with self.assertRaisesRegex(ValueError, msg):
3490            vmap(
3491                f,
3492                in_dims=in_dim,
3493                out_dims=(out_dim, out_dim),
3494                randomness=randomness,
3495                chunk_size=2,
3496            )(x)
3497
3498    @parametrize("in_dim", [0, 1])
3499    @parametrize("out_dim", [0, 1])
3500    @parametrize("randomness", ["error", "same"])
3501    def test_vmap_chunksize_composition(self, in_dim, out_dim, randomness):
3502        x = torch.randn(4, 5, 6)
3503        y = torch.randn_like(x)
3504
3505        # fn: Single Input/Single Output
3506        def f(x):
3507            y = x.sin()
3508            if randomness != "error":
3509                y = y + torch.rand_like(x)
3510            return y
3511
3512        f_args = (x,)
3513
3514        # fn: Nested Input/Single Output
3515        def f1(pair):
3516            x, y = pair
3517            z = x.sin() + y.cos()
3518            if randomness != "error":
3519                z = z + torch.rand_like(z)
3520            return z
3521
3522        f1_args = ((x, y),)
3523
3524        # fn: Single Input/Nested Output
3525        def f2(x):
3526            y = x.sin()
3527            if randomness != "error":
3528                y = y + torch.rand_like(x)
3529            return {"out": y, "out1": y + 2}
3530
3531        f2_args = (x,)
3532
3533        # fn: Nested Input/Nested Output
3534        def f3(inp_dict):
3535            x = inp_dict["inp"]
3536            y = inp_dict["inp1"]
3537            z = x.sin() + y.cos()
3538            if randomness != "error":
3539                z = z + torch.rand_like(z)
3540            return {"z": z, "tuple": (z, z + 1)}
3541
3542        f3_args = ({"inp": x, "inp1": y},)
3543
3544        for fn, args in ((f, f_args), (f1, f1_args), (f2, f2_args), (f3, f3_args)):
3545            rs = torch.get_rng_state()
3546            expected = vmap(
3547                vmap(fn, in_dims=in_dim, out_dims=out_dim, randomness=randomness),
3548                in_dims=in_dim,
3549                out_dims=out_dim,
3550                randomness=randomness,
3551            )(*args)
3552            for chunk_size in (1, 2, 3, 4, 7, 10, 16, 100):
3553                torch.set_rng_state(rs)
3554                actual = vmap(
3555                    vmap(
3556                        fn,
3557                        in_dims=in_dim,
3558                        out_dims=out_dim,
3559                        randomness=randomness,
3560                        chunk_size=chunk_size,
3561                    ),
3562                    in_dims=in_dim,
3563                    out_dims=out_dim,
3564                    randomness=randomness,
3565                    chunk_size=chunk_size,
3566                )(*args)
3567                self.assertEqual(actual, expected)
3568
3569
3570instantiate_parametrized_tests(TestVmapOperators)
3571
3572
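# Builds a batch of cotangent vectors of shape (batch_size, *output.shape).
# With contig=False the batch dimension is created last and moved to the front
# via movedim, so the result is batched but non-contiguous; this way both
# memory layouts are exercised by the batched-gradient tests below.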
3573def construct_v(output, batch_size, contig=False):
3574    if contig:
3575        return torch.randn(
3576            batch_size, *output.shape, dtype=output.dtype, device=output.device
3577        )
3578    result = torch.randn(
3579        *output.shape, batch_size, dtype=output.dtype, device=output.device
3580    )
3581    return result.movedim(-1, 0)
3582
3583
3584def as_tuple(x):
3585    if isinstance(x, tuple):
3586        return x
3587    elif isinstance(x, list):
3588        return tuple(x)
3589    else:
3590        return (x,)
3591
3592
3593def differentiable(args):
3594    return tuple(
3595        arg
3596        for arg in as_tuple(args)
3597        if isinstance(arg, torch.Tensor) and arg.requires_grad
3598    )
3599
3600
3601def _get_rand_no_zeros(*args, **kwargs):
3602    requires_grad = kwargs.get("requires_grad", False)
3603    kwargs_without_requires_grad = kwargs.copy()
3604    kwargs_without_requires_grad["requires_grad"] = False
3605    result = torch.rand(*args, **kwargs_without_requires_grad)
3606    return result.clamp_min_(0.1).requires_grad_(requires_grad)
3607
3608
3609@markDynamoStrictTest
3610class TestVmapBatchedGradient(Namespace.TestVmapBase):
3611    def _vmap_test(self, *args, **kwargs):
3612        return _vmap_test(self, *args, **kwargs)
3613
3614    # Tests batched gradient computation of outputs = op(*args, **kwargs)
3615    # by comparing it to a sequential map+stack fallback.
3616    #
3617    # output_process_fn: a function that maps the outputs to the part
3618    #       that should be differentiated.
3619    # batch_size: the batch dim size for the batched grad
3620    def _batched_grad_test(
3621        self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3
3622    ):
3623        if kwargs is None:
3624            kwargs = {}
3625        outputs = op(*args, **kwargs)
3626        outputs = differentiable(output_process_fn(outputs))
3627        for contig in [True, False]:
3628            batched_vectors = tuple(
3629                construct_v(out, batch_size, contig) for out in outputs
3630            )
3631
3632            def vector_jacobian_product(*vectors):
3633                return torch.autograd.grad(
3634                    outputs, differentiable(args), vectors, retain_graph=True
3635                )
3636
3637            self._vmap_test(
3638                vector_jacobian_product, batched_vectors, check_propagates_grad=False
3639            )
3640
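    # Hedged sketch (not exercised by the suite): the pattern the helper above
    # checks is replacing a Python loop of vector-Jacobian products with a single
    # vmap'd call over a batch of cotangents. The helper name below is
    # hypothetical and purely illustrative.
    @staticmethod
    def _example_batched_vjp(batch_size=5):
        x = torch.randn(3, requires_grad=True)
        y = x.sin()
        cotangents = torch.randn(batch_size, 3)

        def vjp(v):
            (g,) = torch.autograd.grad(y, x, v, retain_graph=True)
            return g

        # Sequential map+stack fallback vs. the batched computation; they agree.
        expected = torch.stack([vjp(v) for v in cotangents])
        batched = vmap(vjp)(cotangents)
        assert torch.allclose(batched, expected)
        return batched
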
3641    # Tests batched second grad computation of outputs = op(*args, **kwargs)
3642    # by comparing it to a sequential map+stack fallback.
3643    #
3644    # output_process_fn: a function that maps the outputs to the part
3645    #       that should be differentiated.
3646    # batch_size: the batch dim size for the batched grad
3647    #
3648    # NB: we only test computing batched gradients in the second gradient
3649    # computation. One specific use case that does this is computing the Hessian
3650    # matrix of a scalar-valued function; this is useful in Bayesian logistic
3651    # regression.
3652    # It might be useful to have a test that computes batched first gradients and
3653    # then uses those to compute batched second gradients in the future.
3654    def _batched_grad_grad_test(
3655        self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3
3656    ):
3657        if kwargs is None:
3658            kwargs = {}
3659        outputs = op(*args, **kwargs)
3660        outputs = differentiable(output_process_fn(outputs))
3661        ones = tuple(torch.ones_like(out) for out in outputs)
3662        # Same thing as summing together all of the outputs and calling .backward()
3663        first_grads = torch.autograd.grad(
3664            outputs, differentiable(args), ones, create_graph=True
3665        )
3666        first_grads = differentiable(first_grads)
3667        self.assertNotEqual(
3668            len(first_grads), 0, "None of the first grads depend on the input!"
3669        )
3670
3671        for contig in [True, False]:
3672            batched_vectors = tuple(
3673                construct_v(grad, batch_size, contig) for grad in first_grads
3674            )
3675
3676            def vector_hessian_product(*vectors):
3677                outputs = torch.autograd.grad(
3678                    first_grads,
3679                    differentiable(args),
3680                    vectors,
3681                    retain_graph=True,
3682                    allow_unused=True,
3683                )
3684                outputs = tuple(out for out in outputs if out is not None)
3685                assert len(outputs) > 0
3686                return outputs
3687
3688            self._vmap_test(
3689                vector_hessian_product, batched_vectors, check_propagates_grad=False
3690            )
3691
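    # Hedged sketch (illustrative only): batching the second gradient, i.e. the
    # Hessian-vector products of a scalar-valued function described above. The
    # helper name is hypothetical and is not used by the suite.
    @staticmethod
    def _example_batched_hvp(batch_size=4):
        def f(x):
            return (x.sin() ** 2).sum()

        x = torch.randn(3, requires_grad=True)
        (first_grad,) = torch.autograd.grad(f(x), x, create_graph=True)
        vectors = torch.randn(batch_size, 3)

        def hvp(v):
            (out,) = torch.autograd.grad(first_grad, x, v, retain_graph=True)
            return out

        expected = torch.stack([hvp(v) for v in vectors])
        batched = vmap(hvp)(vectors)
        assert torch.allclose(batched, expected)
        return batched
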
3692    def _test_arithmetic(self, op, device, test_grad_grad=True):
3693        x = torch.randn(2, 3, requires_grad=True, device=device)
3694        y = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
3695        scalar = 3.14
3696        self._batched_grad_test(op, (x, y))
3697        self._batched_grad_test(op, (scalar, y))
3698        self._batched_grad_test(op, (x, scalar))
3699
3700        if test_grad_grad:
3701            self._batched_grad_grad_test(op, (x, y))
3702
3703    def test_add(self, device):
3704        self._test_arithmetic(torch.add, device, test_grad_grad=False)
3705        self._test_arithmetic(lambda x, y: x + y, device, test_grad_grad=False)
3706
3707    def test_sub(self, device):
3708        self._test_arithmetic(torch.sub, device, test_grad_grad=False)
3709        self._test_arithmetic(lambda x, y: x - y, device, test_grad_grad=False)
3710
3711    def test_mul(self, device):
3712        self._test_arithmetic(torch.mul, device)
3713        self._test_arithmetic(lambda x, y: x * y, device)
3714
3715    def test_div(self, device):
3716        self._test_arithmetic(torch.div, device)
3717        self._test_arithmetic(lambda x, y: x / y, device)
3718
3719    def test_binary_cross_entropy(self, device):
3720        x = F.sigmoid(torch.randn(3, 2, device=device, requires_grad=True))
3721        target = torch.rand(3, 2, device=device)
3722
3723        op = functools.partial(F.binary_cross_entropy, target=target)
3724
3725        self._batched_grad_test(op, (x,), {})
3726        self._batched_grad_grad_test(op, (x,), {})
3727
3728    def test_log_softmax(self, device):
3729        op = functools.partial(torch.log_softmax, dim=-1)
3730        x = torch.randn(3, 2, device=device, requires_grad=True)
3731
3732        self._batched_grad_test(op, (x,), {})
3733        self._batched_grad_grad_test(op, (x,), {})
3734
3735    def test_expand(self, device):
3736        x = torch.randn(2, 3, device=device, requires_grad=True)
3737
3738        def op(x):
3739            return x.expand(5, 5, 2, 3)
3740
3741        self._batched_grad_test(op, (x,))
3742
3743    @allowVmapFallbackUsage
3744    def test_index(self, device):
3745        x = torch.randn(2, 3, requires_grad=True, device=device)
3746        index = torch.tensor([[0, 0], [1, 1]], device=device)
3747
3748        def op(x):
3749            y = x * x
3750            return y[index]
3751
3752        self._batched_grad_test(op, (x,))
3753        self._batched_grad_grad_test(op, (x,))
3754
3755    def test_lgamma(self, device):
3756        x = torch.randn(2, 3, requires_grad=True, device=device)
3757        self._batched_grad_test(Tensor.lgamma, (x,))
3758        self._batched_grad_grad_test(Tensor.lgamma, (x,))
3759
3760    def test_log(self, device):
3761        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
3762        self._batched_grad_test(torch.log, (x,))
3763        self._batched_grad_grad_test(torch.log, (x,))
3764
3765    def test_logsumexp(self, device):
3766        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
3767
3768        def op(x):
3769            return torch.logsumexp(x, -1)
3770
3771        self._batched_grad_test(op, (x,))
3772        self._batched_grad_grad_test(op, (x,))
3773
3774    def test_log1p(self, device):
3775        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
3776        self._batched_grad_test(torch.log1p, (x,))
3777        self._batched_grad_grad_test(torch.log1p, (x,))
3778
3779    @allowVmapFallbackUsage
3780    def test_max(self, device):
3781        x = torch.randn(2, 3, requires_grad=True, device=device)
3782        self._batched_grad_test(torch.max, (x,))
3783
3784    @allowVmapFallbackUsage
3785    def test_median(self, device):
3786        x = torch.randn(2, 3, requires_grad=True, device=device)
3787        self._batched_grad_test(torch.median, (x,))
3788
3789    @allowVmapFallbackUsage
3790    def test_min(self, device):
3791        x = torch.randn(2, 3, requires_grad=True, device=device)
3792        self._batched_grad_test(torch.min, (x,))
3793
3794    def test_permute(self, device):
3795        x = torch.randn(2, 3, 5, requires_grad=True, device=device)
3796
3797        def op(x):
3798            return x.permute(2, 0, 1)
3799
3800        self._batched_grad_test(op, (x,))
3801
3802    def test_reshape(self, device):
3803        x = torch.randn(2, 3, 5, requires_grad=True, device=device)
3804
3805        def op(x):
3806            return x.reshape([2 * 3, 5])
3807
3808        self._batched_grad_test(op, (x,))
3809
3810    def test_sigmoid(self, device):
3811        x = torch.randn(2, 3, requires_grad=True, device=device)
3812        self._batched_grad_test(Tensor.sigmoid, (x,))
3813        self._batched_grad_grad_test(Tensor.sigmoid, (x,))
3814
3815    def test_stack(self, device):
3816        x = torch.randn(2, 3, device=device, requires_grad=True)
3817        y = torch.randn(2, 3, device=device, requires_grad=True)
3818
3819        def op(x, y):
3820            return torch.stack([x, y])
3821
3822        self._batched_grad_test(op, (x, y))
3823
3824    def test_select(self, device):
3825        x = torch.randn(2, 3, device=device, requires_grad=True)
3826        self._batched_grad_test(lambda x: x[1], (x,))
3827        self._batched_grad_test(lambda x: x.select(1, 2), (x,))
3828        self._batched_grad_test(lambda x: x.select(-1, 0), (x,))
3829
3830    def test_slice(self, device):
3831        x = torch.randn(2, 3, 5, device=device, requires_grad=True)
3832        self._batched_grad_test(lambda x: x[0:1], (x,))
3833        self._batched_grad_test(lambda x: x[:, 1:3], (x,))
3834        self._batched_grad_test(lambda x: x[..., 1:3], (x,))
3835
3836    def test_trace(self, device):
3837        x = torch.randn(2, 3, device=device, requires_grad=True)
3838        self._batched_grad_test(Tensor.trace, (x,))
3839
3840        x = torch.randn(3, 2, 2, device=device)
3841
3842        def sum_grad_trace(x):
3843            return grad(torch.trace)(x).sum()
3844
3845        output = vmap(grad(sum_grad_trace))(x)
3846        self.assertEqual(output, torch.zeros_like(output))
3847
3848    def test_where(self, device):
3849        x = torch.randn(3, 2, device=device)
3850        y = torch.ones(3, 2, device=device)
3851
3852        def f(x, y):
3853            return torch.where(x > 0, x, y)
3854
3855        # Check that there is no runtime error; exactness tests are done with OpInfo
3856        vmap(f)(x, y)
3857
3858        x = torch.randint(0, 2, size=(4, 3), dtype=torch.float)
3859
3860        def f(t):
3861            return torch.where(t)
3862
3863        with self.assertRaisesRegex(
3864            RuntimeError, r"Attempted to vmap over aten::where"
3865        ):
3866            vmap(f)(x)
3867
3868    def test_threshold(self, device):
3869        x = torch.randn(2, 3, device=device, requires_grad=True)
3870        self._batched_grad_test(lambda x: F.threshold(x, 0.5, 0.0), (x,))
3871
3872    @parametrize("backend", PLATFORM_SPECIFIC_SDPA)
3873    def test_sdpa(self, device, backend):
3874        if device == "cpu":
3875            raise unittest.SkipTest("This test is only for CUDA for now")
3876
3877        def T(*args):
3878            return torch.randn(*args, dtype=torch.float16, device=device)
3879
3880        backend_ctx = sdpa_kernel([backend])
3881        with backend_ctx:
3882            for batching in [
3883                (True, True, True),
3884                (True, False, False),
3885                (False, True, True),
3886            ]:
3887                size = [8, 4, 128, 64]
3888                if batching[0]:
3889                    query = T(3, *size)
3890                else:
3891                    query = T(*size)
3892                if batching[1]:
3893                    key = T(3, *size)
3894                else:
3895                    key = T(*size)
3896                if batching[2]:
3897                    value = T(3, *size)
3898                else:
3899                    value = T(*size)
3900                in_dims = tuple(0 if b else None for b in batching)
3901                attention = F.scaled_dot_product_attention
3902
3903                self._vmap_test(
3904                    attention,
3905                    (query, key, value),
3906                    in_dims=in_dims,
3907                )
3908                # Backwards test doesn't work yet
3909                # self._batched_grad_test(
3910                #     lambda query, key, value: F.scaled_dot_product_attention(
3911                #         query, key, value
3912                #     ),
3913                #     (query, key, value),
3914                # )
3915
3916            B = 4
3917            query = torch.rand(4, 32, B, 8, 128, dtype=torch.float16, device=device)
3918            key = torch.rand(4, B, 32, 8, 128, dtype=torch.float16, device=device)
3919            value = torch.rand(4, 32, 8, 128, dtype=torch.float16, device=device)
3920            self._vmap_test(
3921                F.scaled_dot_product_attention,
3922                (query, key, value),
3923                in_dims=(2, 1, None),
3924            )
3925
3926    @parametrize("backend", PLATFORM_SPECIFIC_SDPA)
3927    @parametrize("randomness", ["error", "same", "different"])
3928    def test_randomness(self, device, randomness, backend):
3929        if device == "cpu":
3930            raise unittest.SkipTest("This test is only for CUDA for now")
3931        backend_ctx = sdpa_kernel([backend])
3932        with backend_ctx:
3933            B = 4
3934            query = torch.rand(B, 4, 32, 8, 128, dtype=torch.float16, device=device)
3935            key = torch.rand(B, 4, 32, 8, 128, dtype=torch.float16, device=device)
3936            value = torch.rand(B, 4, 32, 8, 128, dtype=torch.float16, device=device)
3937
3938            def f(q, k, v, dropout):
3939                return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout)
3940
3941            # No matter the randomness mode, dropout=0.0 should pass
3942            vmap(
3943                functools.partial(f, dropout=0.0),
3944                in_dims=(0, 0, 0),
3945                randomness=randomness,
3946            )(query, key, value)
3947
3948            fail_with_randomness = randomness == "error"
3949            if backend != SDPBackend.MATH:
3950                fail_with_randomness |= randomness == "same"
3951            context = (
3952                self.assertRaises(RuntimeError)
3953                # randomness == "same" is unsupported for non-math backends, and "error" must always raise when the op uses randomness
3954                if fail_with_randomness
3955                else contextlib.nullcontext()
3956            )
3957            with context:
3958                vmap(
3959                    functools.partial(f, dropout=0.5),
3960                    in_dims=(0, 0, 0),
3961                    randomness=randomness,
3962                )(query, key, value)
3963
3964    @allowVmapFallbackUsage
3965    def test_inplace_view(self, device):
3966        leaf = torch.randn(4, 5, requires_grad=True)
3967
3968        def func(leaf):
3969            # Make sure the function is non-trivially twice differentiable
3970            base = leaf * leaf
3971            view = base[0]
3972            view.cos_()
3973            return view
3974
3975        self._batched_grad_test(func, (leaf,), {})
3976        self._batched_grad_grad_test(func, (leaf,), {})
3977
3978    @allowVmapFallbackUsage
3979    def test_inplace_manyview(self, device):
3980        leaf = torch.randn(4, 4, 5, requires_grad=True)
3981
3982        def func(leaf):
3983            # Make sure the function is non-trivially twice differentiable
3984            base = leaf * leaf
3985            view = base.transpose(0, 2)
3986            view = view[1]
3987            view = view.diagonal()
3988            view = view[::2]
3989            view.cos_()
3990            return view
3991
3992        self._batched_grad_test(func, (leaf,), {})
3993        self._batched_grad_grad_test(func, (leaf,), {})
3994
3995    def test_diagonal(self, device):
3996        x = torch.randn(4, 5, device=device, requires_grad=True)
3997        self._batched_grad_test(lambda x: x.diagonal(1, 0, 1), (x,))
3998
3999        x = torch.randn(3, 4, 5, device=device, requires_grad=True)
4000        self._batched_grad_test(lambda x: x.diagonal(0, -1, -2), (x,))
4001
4002    @allowVmapFallbackUsage
4003    def test_unrelated_output(self, device):
4004        B0 = 3
4005        x = torch.randn([], requires_grad=True)
4006        y = torch.randn([], requires_grad=True)
4007        gy = torch.randn(B0, requires_grad=True)
4008
4009        def vjp(v):
4010            (res,) = torch.autograd.grad(y, x, v, allow_unused=True)
4011            return torch.zeros_like(x) if res is None else res
4012
4013        result = vmap(vjp)(gy)
4014        self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))
4015
4016    @allowVmapFallbackUsage
4017    def test_unrelated_output_multiple_grad(self, device):
4018        B0 = 3
4019        x = torch.randn([], requires_grad=True)
4020        y = torch.randn([], requires_grad=True)
4021        gy = torch.randn(B0, requires_grad=True)
4022
4023        def vjp(v):
4024            (res,) = torch.autograd.grad(y, x, v, allow_unused=True)
4025            return torch.zeros_like(x) if res is None else res
4026
4027        _ = vjp(gy[0])
4028        result = vmap(vjp)(gy)
4029        self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))
4030
4031
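# Collects an OpInfo's callable variants: the op itself plus any aliases, and
# the in-place variants of each, so the vmap tests below cover all of them.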
4032def discover_variants(opinfo):
4033    aliases = []
4034    inplace_variants = []
4035
4036    if opinfo.inplace_variant:
4037        inplace_variants.append(opinfo.inplace_variant)
4038
4039    aliases.append(opinfo.op)
4040    for alias in opinfo.aliases:
4041        aliases.append(alias.op)
4042        if alias.inplace_variant:
4043            inplace_variants.append(alias.inplace_variant)
4044    return aliases, inplace_variants
4045
4046
4047# TODO: enable this when we get a bit closer to getting torch.vmap x torch.compile working.
4048# @markDynamoStrictTest
4049@unMarkDynamoStrictTest
4050class TestVmapOperatorsOpInfo(TestCase):
4051    def vmap_outplace_test(
4052        self,
4053        func,
4054        args,
4055        kwargs,
4056        in_dims,
4057        check_shape_only=False,
4058        postprocess_fn=None,
4059        out_dim=0,
4060    ):
4061        for vmap_out, loop_out in compute_quantities_for_vmap_test(
4062            func, args, kwargs, in_dims, out_dim=out_dim
4063        ):
4064            if postprocess_fn is not None:
4065                loop_out = postprocess_fn(loop_out)
4066                vmap_out = postprocess_fn(vmap_out)
4067            if check_shape_only:
4068                self.assertEqual(vmap_out.shape, loop_out.shape)
4069                continue
4070            self.assertEqual(vmap_out, loop_out)
4071
4072    def vmap_inplace_test(
4073        self, func, args, kwargs, in_dims, postprocess_fn=None, out_dim=0
4074    ):
4075        # NB: This test assumes that the first argument is being modified.
4076        # This is OK because it's what every other OpInfo-based test assumes,
4077        # but it is going to need a more robust solution eventually.
4078        if in_dims[0] is None:
4079            # Check that we correctly raise an error when vmap is impossible
4080            # on the in-place operation
4081            with self.assertRaises(RuntimeError):
4082                for _ in compute_quantities_for_vmap_test(
4083                    func,
4084                    args,
4085                    kwargs,
4086                    in_dims,
4087                    out_dim=out_dim,
4088                    compute_loop_out=False,
4089                    clone_inputs=True,
4090                ):
4091                    pass
4092            return
4093        for vmap_out, loop_out in compute_quantities_for_vmap_test(
4094            func,
4095            args,
4096            kwargs,
4097            in_dims,
4098            clone_inputs=True,
4099            out_dim=out_dim,
4100        ):
4101            if postprocess_fn is not None:
4102                loop_out = postprocess_fn(loop_out)
4103                vmap_out = postprocess_fn(vmap_out)
4104            self.assertEqual(vmap_out, loop_out)
4105
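    # Hedged illustration of the error case handled above: when the in-place
    # destination is not batched (in_dims[0] is None) but another argument is,
    # vmap cannot write per-sample results into the shared destination and is
    # expected to raise. The helper name is hypothetical and unused by the suite.
    @staticmethod
    def _example_inplace_unbatched_destination():
        dest = torch.zeros(3)  # not batched
        src = torch.randn(5, 3)  # batched along dim 0
        try:
            vmap(torch.Tensor.add_, in_dims=(None, 0))(dest, src)
        except RuntimeError:
            return True  # the expected outcome
        return False
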
4106    def opinfo_vmap_test(
4107        self,
4108        device,
4109        dtype,
4110        op,
4111        check_has_batch_rule,
4112        skip_inplace=(),
4113        postprocess_fn=None,
4114    ):
4115        def test():
4116            # Error inputs check
4117            if op.error_inputs_func is not None:
4118                error_inputs = op.error_inputs(device)
4119                for error_input in error_inputs:
4120                    sample_input = error_input.sample_input
4121                    args = (sample_input.input,) + tuple(sample_input.args)
4122                    kwargs = sample_input.kwargs
4123                    for batched_args, in_dims, _ in generate_vmap_inputs(args, {}):
4124                        with self.assertRaises(Exception):
4125                            vmap(op, in_dims)(*batched_args, **kwargs)
4126
4127            # Sample inputs check
4128            sample_inputs_op = {
4129                # These ops take too long with reference inputs, so use sample inputs instead
4130                "special.chebyshev_polynomial_t",
4131                "special.chebyshev_polynomial_u",
4132                "special.chebyshev_polynomial_v",
4133                "special.chebyshev_polynomial_w",
4134                "special.hermite_polynomial_he",
4135                "special.laguerre_polynomial_l",
4136                "special.legendre_polynomial_p",
4137                "special.shifted_chebyshev_polynomial_t",
4138                "special.shifted_chebyshev_polynomial_u",
4139                "special.shifted_chebyshev_polynomial_v",
4140                "special.shifted_chebyshev_polynomial_w",
4141            }
4142            if op.name in sample_inputs_op:
4143                sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
4144            else:
4145                sample_inputs_itr = op.reference_inputs(
4146                    device, dtype, requires_grad=False
4147                )
4148            aliases, inplace_aliases = discover_variants(op)
4149            check_shape_only = op.name in ("empty_like", "new_empty")
4150            for sample_input in sample_inputs_itr:
4151                args = (sample_input.input,) + sample_input.args
4152                if not any(isinstance(arg, torch.Tensor) for arg in args):
4153                    # At least one tensor is required for vmap.
4154                    continue
4155                kwargs = sample_input.kwargs
4156                is_batch_norm_and_training = is_batch_norm_training(op.name, kwargs)
4157                out_dim = 0
4158                if op.name == "NumpySplitCopyWithIntCustomOp":
4159                    # special case for this custom op
4160                    def sample_vmap_out_dim_numpy_split_copy_with_int(x, splits, dim):
4161                        return [0 for _ in range(len(splits) + 1)], None
4162
4163                    out_dim = sample_vmap_out_dim_numpy_split_copy_with_int(*args)
4164                for batched_args, in_dims, _ in generate_vmap_inputs(
4165                    args, {}, is_batch_norm_and_training=is_batch_norm_and_training
4166                ):
4167                    for func in aliases:
4168                        self.vmap_outplace_test(
4169                            func,
4170                            batched_args,
4171                            kwargs,
4172                            in_dims,
4173                            check_shape_only,
4174                            postprocess_fn,
4175                            out_dim=out_dim,
4176                        )
4177                    if op.name in skip_inplace:
4178                        continue
4179                    if not is_valid_inplace_sample_input(
4180                        sample_input, op, op.inplace_variant
4181                    ):
4182                        continue
4183                    for func in inplace_aliases:
4184                        self.vmap_inplace_test(
4185                            func, batched_args, kwargs, in_dims, postprocess_fn
4186                        )
4187
4188        if check_has_batch_rule:
4189            check_vmap_fallback(self, test, op)
4190        else:
4191            test()
4192
4193    vmap_fail = {
4194        # -------------------- ALLOWED FAILURES --------------------------------
4195        # These are things that we either cannot fix or are not actually problems
4196        xfail("resize_"),
4197        xfail("resize_as_"),
4198        xfail("to_sparse"),
4199        xfail("__getitem__"),  # dynamic mask
4200        xfail("index_put"),  # dynamic mask
4201        xfail(
4202            "nn.functional.dropout"
4203        ),  # works, but can't be checked against the for-loop fallback because of randomness inconsistency
4204        xfail("nn.functional.scaled_dot_product_attention"),  # randomness
4205        xfail("nn.functional.multi_head_attention_forward"),  # randomness
4206        xfail("masked_select"),  # dynamic op
4207        xfail("nonzero"),  # dynamic op
4208        xfail("unique", ""),  # dynamic op
4209        xfail("unique_consecutive", ""),  # dynamic op
4210        xfail("allclose"),  # returns a boolean
4211        xfail("uniform"),  # randomness is tested separately
4212        xfail("rand_like"),  # randomness is tested separately
4213        xfail("randint_like"),  # randomness is tested separately
4214        xfail("randn_like"),  # randomness is tested separately
4215        xfail("bernoulli", ""),  # randomness is tested separately
4216        xfail("normal", ""),  # randomness is tested separately
4217        xfail("normal", "number_mean"),  # randomness is tested separately
4218        xfail("multinomial", ""),  # randomness
4219        xfail("nn.functional.embedding", ""),  # we only support some cases
4220        xfail("nn.functional.rrelu"),  # randomness
4221        xfail("nn.functional.dropout2d", ""),  # randomness
4222        xfail("nn.functional.dropout3d", ""),  # randomness
4223        xfail("nn.functional.alpha_dropout", ""),  # randomness
4224        xfail("nn.functional.feature_alpha_dropout", "with_train"),  # randomness
4225        xfail("as_strided"),  # Our test runner can't handle this; manual test exists
4226        xfail("as_strided_copy"),
4227        xfail(
4228            "as_strided_scatter"
4229        ),  # no batching rule implemented, default fallback doesn't work
4230        skip(
4231            "new_empty_strided"
4232        ),  # empty tensor data is garbage so it's hard to make comparisons with it
4233        xfail("nn.functional.fractional_max_pool3d"),  # randomness
4234        xfail("nn.functional.fractional_max_pool2d"),  # randomness
4235        xfail("pca_lowrank", ""),  # random operation
4236        xfail("svd_lowrank", ""),  # random operation
4237        xfail("sparse.sampled_addmm"),  # sparse
4238        xfail("sparse.mm", "reduce"),  # sparse
4239        xfail(
4240            "NumpyCubeNotComposableAutogradFunction"
4241        ),  # Not composable autograd.Function
4242        skip("_softmax_backward_data"),
4243        skip(
4244            "linalg.eigh", ""
4245        ),  # does not always return the same result for the same input; see test_linalg_eigh for a manual test
4246        skip("to"),  # RuntimeError: required rank 4 tensor to use channels_last format
4247        # UnimplementedError: data-dependent operators cannot be vmapped
4248        xfail("NumpyNonzeroCustomOp"),
4249        xfail("NumpyNMSCustomOp"),
4250        # ----------------------------------------------------------------------
4251        # ---------------------------- BUGS ------------------------------------
4252        # Entries in here don't work and need to be fixed.
4253        # Each one of these is a bug.
4254        decorate("frexp", decorator=skipIfTorchDynamo()),
4255        xfail("clamp_min", ""),  # Exception not raised on error input
4256        xfail("clamp_max", ""),  # Exception not raised on error input
4257        xfail(
4258            "view_as_complex"
4259        ),  # RuntimeError: Tensor must have a last dimension with stride 1
4260        xfail("tensor_split"),  # data_ptr
4261        xfail(
4262            "histogramdd"
4263        ),  # expected Tensor as element 0 in argument 0, but got tuple
4264        xfail("nn.functional.gaussian_nll_loss"),  # data-dependent control flow error
4265        xfail(
4266            "nn.functional.embedding_bag"
4267        ),  # embedding renorm vmap inplace incompatible
4268        xfail("narrow"),  # Batching rule not implemented for aten::narrow.Tensor
4269        # required rank 4 tensor to use channels_last format
4270        xfail("bfloat16"),
4271        xfail("bool"),
4272        xfail("byte"),
4273        xfail("char"),
4274        xfail("double"),
4275        xfail("float"),
4276        xfail("half"),
4277        xfail("int"),
4278        xfail("long"),
4279        xfail("short"),
4280        xfail("cdouble"),
4281        xfail("cfloat"),
4282        xfail(
4283            "jiterator_binary", device_type="cuda"
4284        ),  # NYI: querying is_contiguous inside of vmap
4285        xfail(
4286            "jiterator_binary_return_by_ref", device_type="cuda"
4287        ),  # NYI: querying is_contiguous inside of vmap
4288        xfail(
4289            "jiterator_4inputs_with_extra_args", device_type="cuda"
4290        ),  # NYI: querying is_contiguous inside of vmap
4291        xfail(
4292            "equal", ""
4293        ),  # TypeError: object of type 'bool' has no len(); likely a test-runner problem
4294        xfail(
4295            "jiterator_unary", device_type="cuda"
4296        ),  # NYI: querying is_contiguous inside of vmap
4297        xfail(
4298            "jiterator_2inputs_2outputs", device_type="cuda"
4299        ),  # NYI: querying is_contiguous inside of vmap
4300        # ---------------------------------------------------------------------
4301        # TypeError: expected Tensor as element 0 in argument 0, but got NotImplementedType
4302        xfail("__rsub__"),
4303        # RuntimeError: Batching rule not implemented for aten::moveaxis.int;
4304        # the fallback path doesn't work on out= or view ops.
4305        xfail("movedim"),
4306        # RuntimeError: NYI: querying is_contiguous inside of vmap for
4307        # memory_format other than torch.contiguous_format
4308        xfail("contiguous"),
4309        # RuntimeError: NYI: Tensor.clone(memory_format) inside vmap is only supported
4310        # with memory_format torch.preserve_format or torch.contiguous_format (got ChannelsLast)
4311        xfail("clone"),
4312        # RuntimeError: When vmap-ing torch.nn.functional.one_hot,
4313        # please provide an explicit positive num_classes argument.
4314        xfail("nn.functional.one_hot"),
4315        # RuntimeError: Expected all tensors to be on the same device,
4316        # but found at least two devices, cuda:0 and cpu!
4317        xfail("eq", device_type="cuda"),
4318        xfail("ge", device_type="cuda"),
4319        xfail("gt", device_type="cuda"),
4320        xfail("le", device_type="cuda"),
4321        xfail("lt", device_type="cuda"),
4322        xfail("ne", device_type="cuda"),
4323        # RuntimeError: aten::_flash_attention_forward hit the vmap fallback which is currently disabled
4324        xfail("torch.ops.aten._flash_attention_forward"),
4325    }
4326
4327    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
4328    @ops(
4329        op_db + additional_op_db + autograd_function_db + custom_op_db,
4330        dtypes=OpDTypes.any_one,
4331    )
4332    @opsToleranceOverride(
4333        "TestVmapOperatorsOpInfo",
4334        "test_vmap_exhaustive",
4335        (
4336            tol1(
4337                "linalg.det",
4338                {torch.float32: tol(atol=1e-04, rtol=1e-04)},
4339                device_type="cuda",
4340            ),
4341            # The following is often flaky, but only on Windows.
4342            # We should investigate whether it's actually a problem.
4343            tol1(
4344                "nn.functional.conv_transpose3d",
4345                {torch.float32: tol(atol=1e-04, rtol=1e-02)},
4346                device_type="cuda",
4347            ),
4348        ),
4349    )
4350    @toleranceOverride(
4351        {
4352            torch.float32: tol(atol=1e-04, rtol=1e-04),
4353            torch.complex64: tol(atol=1e-04, rtol=1e-04),
4354        }
4355    )
4356    @skipOps(
4357        "TestVmapOperatorsOpInfo",
4358        "test_vmap_exhaustive",
4359        vmap_fail.union(
4360            {
4361                # RuntimeError: Batch norm got a batched tensor as input while the running_mean or running_var,
4362                # which will be updated in place, were not batched.
4363                xfail("native_batch_norm"),
4364                xfail("_native_batch_norm_legit"),
4365                # TODO: implement batching rule
4366                xfail("_batch_norm_with_update"),
4367                xfail("tril"),  # Exception not raised on error input
4368                xfail("triu"),  # Exception not raised on error input
4369                xfail("as_strided", "partial_views"),
4370                # RuntimeError: output with shape [4, 4] doesn't match the broadcast shape [1, 4, 4]
4371                xfail("addcdiv"),
4372                xfail("addcmul"),
4373                xfail("clamp"),
4374                xfail("torch.ops.aten._efficient_attention_forward"),  # outputs ints
4375                # TypeError: expected Tensor as element 0 in argument 0, but got float
4376                xfail("item"),
4377            }
4378        ),
4379    )
4380    def test_vmap_exhaustive(self, device, dtype, op):
4381        # needs to be fixed
4382        inplace_failure_list = ()
4383        self.opinfo_vmap_test(
4384            device,
4385            dtype,
4386            op,
4387            check_has_batch_rule=False,
4388            skip_inplace=inplace_failure_list,
4389        )
4390
4391    @with_tf32_off
4392    @ops(
4393        op_db + additional_op_db + autograd_function_db + custom_op_db,
4394        dtypes=OpDTypes.any_one,
4395    )
4396    @opsToleranceOverride(
4397        "TestVmapOperatorsOpInfo",
4398        "test_op_has_batch_rule",
4399        (
4400            tol1(
4401                "linalg.det",
4402                {torch.float32: tol(atol=1e-04, rtol=1e-04)},
4403                device_type="cuda",
4404            ),
4405        ),
4406    )
4407    @toleranceOverride(
4408        {
4409            torch.float32: tol(atol=1e-04, rtol=1e-04),
4410            torch.complex64: tol(atol=1e-04, rtol=1e-04),
4411        }
4412    )
4413    @skipOps(
4414        "TestVmapOperatorsOpInfo",
4415        "test_op_has_batch_rule",
4416        vmap_fail.union(
4417            {
4418                xfail("as_strided", "partial_views"),
4419                skip(
4420                    "to"
4421                ),  # RuntimeError: required rank 4 tensor to use channels_last format
4422                xfail("fill"),
4423                # Batch norm got a batched tensor as input while the running_mean or running_var,
4424                # which will be updated in place, were not batched.
4425                xfail("native_batch_norm"),
4426                xfail("_native_batch_norm_legit"),
4427                # TODO: implement batching rule
4428                xfail("_batch_norm_with_update"),
4429                xfail("histogram"),
4430                xfail("scatter_reduce", "sum"),
4431                xfail("scatter_reduce", "mean"),
4432                xfail("scatter_reduce", "amax"),
4433                xfail("scatter_reduce", "amin"),
4434                # The `index_put` OpInfo in pytorch/pytorch has a
4435                # masked index as input, which is not supported
4436                xfail("index_put", ""),
4437                xfail("isin"),
4438                xfail("masked_fill"),
4439                xfail("masked_scatter"),
4440                xfail("masked_select"),
4441                xfail("nanquantile"),
4442                xfail("ormqr"),
4443                xfail("put"),
4444                xfail("quantile"),
4445                xfail("renorm"),
4446                xfail("resize_as_"),
4447                xfail("take"),
4448                xfail("tensor_split"),
4449                xfail("to_sparse"),
4450                # TypeError: expected Tensor as element 0 in argument 0, but got float
4451                xfail("item"),
4452                xfail("tril"),  # Exception not raised on error input
4453                xfail("triu"),  # Exception not raised on error input
4454                xfail("__getitem__", ""),
4455                xfail("count_nonzero"),
4456                xfail(
4457                    "nn.functional.dropout"
4458                ),  # works, can't check against for loop because of randomness inconsistency
4459                xfail("nn.functional.scaled_dot_product_attention"),  # randomness
4460                xfail("nn.functional.multi_head_attention_forward"),  # randomness
4461                xfail("torch.ops.aten._efficient_attention_forward"),  # outputs ints
4462                xfail("resize_"),
4463                xfail("view_as_complex"),
4464                xfail("matrix_exp"),
4465                xfail("fft.ihfft2"),
4466                xfail("fft.ihfftn"),
4467                xfail("allclose"),
4468                xfail("argwhere"),
4469                xfail("unique_consecutive"),
4470                xfail("unique"),
4471                xfail("nn.functional.ctc_loss"),
4472                xfail("nn.functional.gaussian_nll_loss"),
4473                xfail("histc"),
4474                xfail("as_strided"),
4475                xfail("as_strided_copy"),
4476                xfail("t_copy"),
4477                xfail("unsqueeze_copy"),
4478                xfail("istft"),
4479                xfail("nonzero"),
4480                xfail("nn.functional.fractional_max_pool2d"),
4481                xfail("stft"),
4482                xfail("isclose"),
4483                xfail("nn.functional.fractional_max_pool3d"),
4484                xfail("nn.functional.bilinear"),
4485                xfail("nn.functional.embedding_bag"),
4486                xfail("linalg.tensorsolve"),
4487                xfail("bernoulli", ""),
4488                xfail("nn.functional.feature_alpha_dropout", "with_train"),
4489                xfail("native_dropout_backward"),
4490                xfail("nn.functional.kl_div", ""),
4491                xfail("multinomial", ""),
4492                xfail("pca_lowrank", ""),
4493                xfail("normal", ""),
4494                xfail("nn.functional.dropout2d", ""),
4495                xfail("normal", "number_mean"),
4496                xfail("svd_lowrank", ""),
4497                xfail("diagflat", ""),
4498                xfail("special.log_ndtr"),
4499                xfail(
4500                    "narrow"
4501                ),  # Batching rule not implemented for aten::narrow.Tensor
4502                xfail("nn.functional.triplet_margin_loss", ""),
4503                xfail("nn.functional.pdist", ""),
4504                xfail("scatter_reduce", "sum"),
4505                xfail("scatter_reduce", "amax"),
4506                xfail("nn.functional.max_unpool1d", "grad"),
4507                xfail("nn.functional.multi_margin_loss", ""),
4508                xfail("scatter_reduce", "prod"),
4509                xfail("nn.functional.multilabel_margin_loss", ""),
4510                xfail("scatter_reduce", "amin"),
4511                xfail("nn.functional.max_unpool3d", "grad"),
4512                xfail("nn.functional.max_unpool2d", ""),
4513                xfail("nn.functional.max_unpool2d", "grad"),
4514                xfail("nn.functional.margin_ranking_loss", ""),
4515                xfail("nn.functional.max_unpool1d", ""),
4516                xfail("nn.functional.soft_margin_loss", ""),
4517                xfail("scatter_reduce", "mean"),
4518                xfail("nn.functional.max_unpool3d", ""),
4519                xfail("linalg.ldl_solve", "", device_type="cpu"),
4520                xfail("chalf", ""),
4521                xfail("clamp_max", ""),
4522                xfail("jiterator_binary_return_by_ref", device_type="cuda"),
4523                xfail("jiterator_unary", device_type="cuda"),
4524                xfail("jiterator_2inputs_2outputs", device_type="cuda"),
4525                xfail("special.airy_ai"),
4526                xfail("clamp_min", ""),
4527                xfail("sparse.sampled_addmm"),
4528                xfail("sparse.mm", "reduce"),
4529                xfail("special.chebyshev_polynomial_u"),
4530                xfail("_segment_reduce", "offsets"),
4531                xfail("index_reduce", "prod"),
4532                xfail("index_reduce", "mean"),
4533                xfail("index_reduce", "amin"),
4534                xfail("index_reduce", "amax"),
4535                xfail("special.laguerre_polynomial_l"),
4536                xfail("special.hermite_polynomial_h"),
4537                xfail("jiterator_binary", device_type="cuda"),
4538                xfail("jiterator_4inputs_with_extra_args", device_type="cuda"),
4539                xfail("_segment_reduce", "lengths"),
4540                xfail("lu_solve", ""),
4541                xfail("special.hermite_polynomial_he"),
4542                xfail("nn.functional.dropout3d", ""),
4543                xfail("special.chebyshev_polynomial_t"),
4544                xfail("as_strided_scatter", ""),
4545                xfail("equal", ""),
4546                xfail("linalg.lu", ""),
4547                skip("linalg.ldl_solve", ""),
4548                skip("_softmax_backward_data"),
4549                # One or more of the overloads doesn't have a batching rule.
4550                xfail("bincount"),
4551                # RuntimeError: Expected all tensors to be on the same device,
4552                # but found at least two devices, cuda:0 and cpu!
4553                xfail("ge", device_type="cuda"),
4554                xfail(
4555                    "searchsorted"
4556                ),  # aten::searchsorted.Scalar hit the vmap fallback which is currently disabled
4557            }
4558        ),
4559    )
4560    def test_op_has_batch_rule(self, device, dtype, op):
4561        # needs to be fixed
4562        inplace_failures = (
4563            "addbmm",
4564            "addcdiv",
4565            "addcmul",
4566            "addmm",
4567            "addmv",
4568            "addr",
4569            "baddbmm",
4570            "clamp",
4571            "conj_physical",
4572            "cumprod",
4573            "cumsum",
4574            "floor_divide",
4575            "fmod",
4576            "heaviside",
4577            "hypot",
4578            "igamma",
4579            "igammac",
4580            "index_copy",
4581            "ldexp",
4582            "lerp",
4583            "neg",
4584            "nextafter",
4585            "polygamma",
4586            "pow",
4587            "remainder",
4588            "scatter_add",
4589            "scatter",
4590            "square",
4591            "sub",
4592            "trunc",
4593            "xlogy",
4594        )
4595        self.opinfo_vmap_test(
4596            device, dtype, op, check_has_batch_rule=True, skip_inplace=inplace_failures
4597        )
4598
4599    def test_linalg_svd(self, device):
4600        # linalg_svd returns a tuple of three tensors, (U, S, Vh).
4601        # Given the same input, it may return different tensors,
4602        # because svd isn't unique. To test that the svd is correct, we multiply
4603        # U @ diag(S) @ Vh and check that the output from vmap matches the
4604        # output from a for-loop.
4605        def compute_A(out):
4606            U, S, Vh = out
4607            m = U.shape[-1]
4608            n = Vh.shape[-2]
4609            diag_S = S.new_zeros(*S.shape[:-1], m, n)
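            # Embed the singular values on the diagonal of an (m, n) zero matrix so
            # that U @ diag_S @ Vh reconstructs the original input; the reconstruction
            # is unique even though the (U, S, Vh) factors are not.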
4610            diag_S.diagonal(offset=0, dim1=-2, dim2=-1).copy_(S)
4611            return U @ diag_S @ Vh
4612
4613        opinfos = [op for op in op_db if op.name == "linalg.svd"]
4614        assert len(opinfos) > 0
4615
4616        for op in opinfos:
4617            self.opinfo_vmap_test(
4618                device,
4619                torch.float,
4620                op,
4621                check_has_batch_rule=True,
4622                postprocess_fn=compute_A,
4623            )
4624
4625    def test_linalg_eigh(self, device):
4626        # linalg_eigh returns two tensors, (L, Q).
4627        # Given the same input, it may return different tensors,
4628        # because the eigendecomposition isn't unique.
4629        # To test that eigh is correct, we multiply
4630        # Q @ diag(L) @ Qh and check that the output from vmap matches the
4631        # output from a for-loop.
4632        def compute_A(out):
4633            L, Q = out
4634            n = Q.shape[-1]
4635            diag_L = L.new_zeros(*L.shape[:-1], n, n)
4636            diag_L.diagonal(offset=0, dim1=-2, dim2=-1).copy_(L)
4637            Qh = Q.transpose(-2, -1).conj()
4638            return Q @ diag_L @ Qh
4639
4640        opinfos = [op for op in op_db if op.name == "linalg.eigh"]
4641        assert len(opinfos) > 0
4642
4643        for op in opinfos:
4644            self.opinfo_vmap_test(
4645                device,
4646                torch.float,
4647                op,
4648                check_has_batch_rule=True,
4649                postprocess_fn=compute_A,
4650            )
4651
4652    @skipIfTorchDynamo()
4653    def test_slogdet(self, device):
4654        # There's no OpInfo for this
4655        def test():
4656            B = 2
4657            x = torch.randn(B, 5, 5, device=device)
4658            self.vmap_outplace_test(torch.slogdet, (x,), {}, (0,))
4659
4660        check_vmap_fallback(self, test, torch.slogdet)
4661
4662    def test_index_fill(self, device):
4663        # There's no OpInfo for these tests
4664
4665        B = 2
4666
4667        def test1():
4668            # negative dim
4669            x = torch.randn(B, 5, 5, device=device)
4670            dim = -2
4671            index = torch.tensor([[2, 3], [0, 4]], device=device)
4672            value = 5.0
4673            self.vmap_outplace_test(
4674                torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None)
4675            )
4676
4677        def test2():
4678            # self batched, self logical rank 1, index logical rank 1
4679            x = torch.zeros(B, 3, device=device)
4680            dim = 0
4681            index = torch.tensor([[0], [1]], device=device)
4682            for value in (1.0, torch.rand((), device=device)):
4683                self.vmap_outplace_test(
4684                    torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None)
4685                )
4686
4687        def test3():
4688            # self batched, self logical rank 1, index logical rank 0
4689            x = torch.zeros(B, 3, device=device)
4690            dim = 0
4691            index = torch.tensor([0, 1], device=device)
4692            for value in (1.0, torch.rand((), device=device)):
4693                self.vmap_outplace_test(
4694                    torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None)
4695                )
4696
4697        def test4():
4698            # self not batched, self logical rank 0, index logical rank 1
4699            x = torch.zeros([], device=device)
4700            dim = 0
4701            index = torch.tensor([[0], [0]], device=device)
4702            for value in (1.0, torch.rand((), device=device)):
4703                self.vmap_outplace_test(
4704                    torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None)
4705                )
4706
4707        def test5():
4708            # self not batched, self logical rank 0, index logical rank 0
4709            x = torch.zeros([], device=device)
4710            dim = 0
4711            index = torch.tensor([0, 0], device=device)
4712            for value in (1.0, torch.rand((), device=device)):
4713                self.vmap_outplace_test(
4714                    torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None)
4715                )
4716
4717        def test6():
4718            # self not batched, self logical rank 1, index logical rank 1
4719            x = torch.zeros(3, device=device)
4720            dim = 0
4721            index = torch.tensor([[0], [1]], device=device)
4722            for value in (1.0, torch.rand((), device=device)):
4723                self.vmap_outplace_test(
4724                    torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None)
4725                )
4726
4727        def test7():
4728            # self not batched, self logical rank 1, index logical rank 0
4729            x = torch.zeros(3, device=device)
4730            dim = 0
4731            index = torch.tensor([0, 1], device=device)
4732            for value in (1.0, torch.rand((), device=device)):
4733                self.vmap_outplace_test(
4734                    torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None)
4735                )
4736
4737        def test8():
4738            # self batched, self logical rank > 1, index logical rank 0
4739            x = torch.zeros(B, 3, 3, device=device)
4740            dim = 0
4741            index = torch.tensor([0, 1], device=device)
4742            for value in (1.0, torch.rand((), device=device)):
4743                self.vmap_outplace_test(
4744                    torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None)
4745                )
4746
4747        for test in (test1, test2, test3, test4, test5, test6, test7, test8):
4748            check_vmap_fallback(self, test, torch.index_fill)
4749
4750    def test_fill__Tensor(self, device):
4751        # There's no OpInfo for fill_.Tensor, so here's an extra test for it.
4752        def test():
4753            B = 2
4754            args = (torch.randn(B, 3, device=device), torch.randn(B))
4755            self.vmap_inplace_test(Tensor.fill_, args, {}, (0, 0))
4756
4757            args = (torch.randn(3, B, device=device), torch.randn(B))
4758            self.vmap_inplace_test(Tensor.fill_, args, {}, (-1, 0))
4759
4760            args = (torch.randn(3, device=device), torch.randn(B))
4761            self.vmap_inplace_test(Tensor.fill_, args, {}, (None, 0))
4762
4763            args = (torch.randn(3, B, device=device), torch.randn([]))
4764            self.vmap_inplace_test(Tensor.fill_, args, {}, (1, None))
4765
4766        check_vmap_fallback(self, test, Tensor.fill_)
4767
4768    def test_conv_double_backward(self, device):
4769        images = torch.randn(2, 1, 5, 5, device=device)
4770        weight = torch.randn(2, 1, 2, 2, device=device)
4771        bias = torch.randn(2, device=device)
4772        ggI = torch.randn_like(images)
4773        ggW = torch.randn_like(weight)
4774        ggb = torch.randn_like(bias)
4775        stride = (1, 1)
4776        padding = (0, 0)
4777        dilation = (1, 1)
4778        transposed = False
4779        output_padding = (0, 0)
4780        groups = 1
4781        output_mask = (True, True, True)
4782        gO = torch.randn_like(
4783            F.conv2d(images, weight, bias, stride, padding, dilation, groups)
4784        )
4785
4786        args = (
4787            ggI,
4788            ggW,
4789            ggb,
4790            gO,
4791            weight,
4792            images,
4793            stride,
4794            padding,
4795            dilation,
4796            transposed,
4797            output_padding,
4798            groups,
4799            output_mask,
4800        )
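        # args follows the argument order expected by
        # aten::_convolution_double_backward: (ggI, ggW, ggb, gO, weight, self,
        # stride, padding, dilation, transposed, output_padding, groups, output_mask).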
4801        op = torch.ops.aten._convolution_double_backward
4802
4803        generator = get_fallback_and_vmap_exhaustive(op, args, {})
4804        is_cuda_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability(
4805            0
4806        ) == (8, 6)
4807        atol, rtol = (1e-3, 1e-3) if is_cuda_sm86 else (1e-4, 1e-4)
4808
4809        def test():
4810            for loop_out, batched_out in generator:
4811                self.assertEqual(loop_out, batched_out, atol=atol, rtol=rtol)
4812
4813        check_vmap_fallback(self, test, op)
4814
4815    def test_isnan(self, device):
4816        test = functools.partial(_vmap_test, check_propagates_grad=False)
4817
4818        B, N, C, H, W = 2, 3, 24, 5, 7
4819        op = torch.isnan
4820
4821        x = torch.randn(B, N, C, H, W)
4822        x[x > 0] = float("nan")
4823        test(self, op, (x,), in_dims=(0))
4824
4825    def test_sum_scalar(self, device):
4826        x = torch.tensor([10.0], device=device)
4827        y = vmap(torch.sum)(x)
4828        self.assertEqual(y, x)
4829
4830        y = vmap(lambda x: x.sum(0))(x)
4831        self.assertEqual(y, x)
4832
4833        y = vmap(lambda x: x.sum(-1))(x)
4834        self.assertEqual(y, x)
4835
4836    def test_isinf(self, device):
4837        test = functools.partial(_vmap_test, check_propagates_grad=False)
4838
4839        B, N, C, H, W = 2, 3, 24, 5, 7
4840        op = torch.isinf
4841
4842        x = torch.randn(B, N, C, H, W)
4843        x[x > 0] = float("inf")
4844        test(self, op, (x,), in_dims=(0))
4845
4846    def test_foo_like(self, device):
4847        # vfdev-5: Probably we can remove this line; Flake8 reports it as unused.
4848        # test = functools.partial(_vmap_test, check_propagates_grad=False)
4849
4850        B, N, C, H, W = 2, 3, 24, 5, 7
4851        for op in [torch.ones_like, torch.zeros_like]:
4852            x = torch.randn(B, N, C, H, W)
4853            # todo(chilli): test these better
4854            # Not testing correctness, just that they run
4855            vmap(op, in_dims=(0,))(
4856                x,
4857            )
4858
4859    def test_flatten(self, device):
4860        test = functools.partial(_vmap_test, check_propagates_grad=False)
4861
4862        op = torch.flatten
4863
4864        x = torch.randn(2, 3, 4, 5)
4865        test(self, op, (x, 1, 2), in_dims=(0, None, None))
4866
4867    def test_group_norm(self, device):
4868        test = functools.partial(_vmap_test, check_propagates_grad=False)
4869
4870        B, N, C, H, W = 2, 3, 24, 5, 7
4871        op = F.group_norm
4872
4873        x = torch.randn(B, N, C, H, W)
4874        weight = torch.randn(C)
4875        bias = torch.randn(C)
4876        test(self, op, (x, 3, weight, bias), in_dims=(0, None, None, None))
4877
4878        x = torch.randn(B, N, C, H, W)
4879        weight = torch.randn(B, C)
4880        bias = torch.randn(B, C)
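        # Here weight and bias carry their own batch dim, so every example in the
        # batch is normalized with its own affine parameters.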
4881        test(self, op, (x, 4, weight, bias), in_dims=(0, None, 0, 0))
4882
4883    def test_index_put(self, device):
4884        def test(f, t, idx, values):
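            # Each check compares slice 0 of the vmapped result against the
            # unbatched reference computed from the first slice of every argument.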
4885            base = f(t[0], idx[0], values[0])
4886            self.assertEqual(vmap(f, in_dims=(0, 0, 0))(t, idx, values)[0], base)
4887            self.assertEqual(
4888                vmap(f, in_dims=(0, None, None))(t, idx[0], values[0])[0], base
4889            )
4890            self.assertEqual(vmap(f, in_dims=(0, None, 0))(t, idx[0], values)[0], base)
4891            self.assertEqual(vmap(f, in_dims=(0, 0, None))(t, idx, values[0])[0], base)
4892
4893        def f(x, y, z):
4894            x[y] = z
4895            return x
4896
4897        x = torch.randn(3, 4, 5, device=device)
4898        y = torch.zeros((3, 2), device=device).long()
4899        z = torch.randn(3, 2, 5, device=device)
4900        test(f, x, y, z)
4901
4902        # indexing innermost dim
4903        def f(t, idx, values):
4904            t[:, idx] = values
4905            return t
4906
4907        t = torch.zeros((3, 2, 3))
4908        values = torch.ones((3, 1, 2))
4909        idx = torch.tensor([[1, 2]]).expand((3, 2))
4910        test(f, t, idx, values)
4911
4912        # indexing middle dim
4913        def f(t, idx, values):
4914            t[:, idx, :] = values
4915            return t
4916
4917        t = torch.zeros((3, 2, 3, 3))
4918        values = torch.ones((3, 1, 2, 3))
4919        idx = torch.tensor([[0, 2]]).expand((3, 2))
4920        test(f, t, idx, values)
4921
4922        # indexing with slices
4923        def f(t, values):
4924            t[:, :2, :] = values
4925            return t
4926
4927        base = f(t[0], values[0])
4928        self.assertEqual(vmap(f, in_dims=(0, 0))(t, values)[0], base)
4929        self.assertEqual(vmap(f, in_dims=(0, None))(t, values[0])[0], base)
4930
4931        # index_put_
4932        tensor = torch.zeros(3, 3, 4)
4933        value = torch.ones(3, 2)
4934        idxs = (
4935            torch.tensor([[0], [1], [2]]),
4936            torch.tensor([[0]]),
4937            torch.tensor([1, 2]),
4938        )
4939        expected = torch.index_put_(tensor.clone(), idxs, value)
4940
4941        def f(t, idx, v):
4942            torch.index_put_(t, idx, v)
4943            return t
4944
4945        self.assertEqual(
4946            vmap(f, in_dims=(0, (None, None), 0))(tensor, idxs[1:], value), expected
4947        )
4948        self.assertEqual(
4949            vmap(f, in_dims=(0, (None, None), None))(tensor, idxs[1:], value[0]),
4950            expected,
4951        )
4952
4953        # boolean mask
4954        B = 2
4955        x = torch.randn(1, 3, 3)
4956        gy = torch.randn(B, 1, 3, 3)
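        # Only gy is vmapped (in_dims=(None, 0)); the boolean mask built from the
        # unbatched x is reused for every batch element inside index_put.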
4957
4958        def f(x, gy):
4959            mask = x < 1e-09
4960            zeros = torch.zeros([])
4961            index_put = torch.ops.aten.index_put.default(gy, [mask], zeros)
4962            return index_put
4963
4964        self.vmap_outplace_test(f, (x, gy), {}, in_dims=(None, 0))
4965
4966    @onlyCUDA
4967    @parametrize("inplace", [True, False])
4968    def test_0d_tensor_index_put(self, device, inplace):
4969        def f(t, idx, v):
4970            fn = torch.index_put_ if inplace else torch.index_put
4971            return fn(t, idx, v)
4972
4973        N = 2
4974        t = torch.zeros((N, 5), device="cuda")
4975        idx = torch.tensor([1, 3])
4976        v = torch.tensor(1, dtype=t.dtype, device="cpu")
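        # v deliberately stays on CPU: this test checks that a 0-d CPU value
        # tensor keeps working with a CUDA self tensor under vmap.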
4977
4978        expected = torch.tensor([[0, 1, 0, 1, 0], [0, 1, 0, 1, 0]], dtype=t.dtype)
4979        self.assertEqual(expected, vmap(f, in_dims=(0, None, None))(t, (idx,), v))
4980
4981    @parametrize("training", [True, False])
4982    @parametrize("track_running_stats", [True, False])
4983    @parametrize("affine", [True, False])
4984    def test_batch_norm(self, device, affine, track_running_stats, training):
4985        if not track_running_stats and not training:
4986            return
4987
4988        test = functools.partial(_vmap_test, check_propagates_grad=False)
4989        BN = torch.nn.BatchNorm2d
4990        ensemble_size = 10
4991        hidden_dim = 3
4992
4993        weights, buffers, _, _, _ = functional_init_with_buffers(BN, [ensemble_size])(
4994            hidden_dim, affine=affine, track_running_stats=track_running_stats
4995        )
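        # functional_init_with_buffers builds `ensemble_size` BatchNorm2d modules
        # and stacks their parameters and buffers along dim 0, so vmap can
        # evaluate the whole ensemble in one call.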
4996
4997        inputs = [torch.randn(ensemble_size, 32, hidden_dim, 16, 16, device=device)]
4998        in_dims = [0]
4999
5000        def append(inp, in_dim):
5001            inputs.append(inp)
5002            in_dims.append(in_dim)
5003
5004        if track_running_stats:
5005            running_mean, running_var, _ = buffers
5006            append(running_mean.to(device), 0)
5007            append(running_var.to(device), 0)
5008        else:
5009            append(None, None)
5010            append(None, None)
5011
5012        if affine:
5013            weight, bias = weights
5014            append(weight.to(device), 0)
5015            append(bias.to(device), 0)
5016        else:
5017            append(None, None)
5018            append(None, None)
5019
5020        append(training, None)
5021
5022        def op(inp, running_mean, running_var, weight, bias, training):
5023            res = F.batch_norm(inp, running_mean, running_var, weight, bias, training)
5024            if track_running_stats:
5025                return res, running_mean, running_var
5026            return res
5027
5028        test(self, op, tuple(inputs), in_dims=tuple(in_dims))
5029
5030    def test_torch_return_types_returns(self, device):
5031        t = torch.randn(3, 2, 2, device=device)
5032        self.assertTrue(
5033            isinstance(vmap(torch.min, (0, None))(t, 0), torch.return_types.min)
5034        )
5035        self.assertTrue(
5036            isinstance(vmap(torch.max, (0, None))(t, 0), torch.return_types.max)
5037        )
5038        self.assertTrue(
5039            isinstance(
5040                vmap(torch.topk, (0, None, None))(t, 1, 0), torch.return_types.topk
5041            )
5042        )
5043        self.assertTrue(
5044            isinstance(vmap(torch.linalg.eig, (0))(t), torch.return_types.linalg_eig)
5045        )
5046
5047    def test_namedtuple_returns(self, device):
5048        Point = namedtuple("Point", ["x", "y"])
5049
5050        def f(x, y):
5051            return Point(x=x, y=y)
5052
5053        x = torch.randn(2, 5, device=device)
5054        y = torch.randn(2, 3, device=device)
5055        self.assertTrue(isinstance(vmap(f)(x, y), Point))
5056
5057    def test_inplace_on_view(self, device):
5058        def func(leaf):
5059            base = leaf * leaf
5060            view = base.transpose(0, 1)
5061            view[2:4, 2:4] *= 2
5062            view[0:2, 0:2].diagonal().sin_()
5063            view = view[1:3, 1:3]
5064            view.cos_()
5065            return view
5066
5067        def push_vjp(leaf, gout):
5068            _, vjp_fn = vjp(func, leaf)
5069            (result,) = vjp_fn(gout)
5070            return result
5071
5072        leaf = torch.randn(4, 4, device=device)
5073        gout = torch.randn(2, 2, device=device)
5074        args = (leaf, gout)
5075
5076        for (
5077            batched_args,
5078            in_dims,
5079            _,
5080        ) in generate_vmap_inputs(args, {}):
5081            if in_dims[1] is None:
5082                # triggers some composite compliance problem
5083                continue
5084            self.vmap_outplace_test(push_vjp, batched_args, {}, in_dims)
5085
5086    def test_advanced_indexing(self, device):
5087        def test(f, args):
5088            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(f, args, {}):
5089                self.assertEqual(loop_out, batched_out)
5090
5091        def f(x, idx):
5092            return x[:, idx]
5093
5094        def f2(x, idx):
5095            return x[idx, :]
5096
5097        def f3(x, idx):
5098            return x[:, :, idx]
5099
5100        inps = (
5101            torch.randn(5, 5, 5, device=device),
5102            torch.randn(5, 5, 5, 5, device=device),
5103            torch.randn(5, 5, 5, 5, 5, device=device),
5104        )
5105        idxes = (
5106            torch.tensor([0, 1, 2], device=device),
5107            torch.tensor([0, 1, 2], device=device).reshape(3, 1),
5108            torch.tensor([0, 1, 2], device=device).reshape(3, 1, 1),
5109        )
5110        for inp, idx in itertools.product(inps, idxes):
5111            test(f, (inp, idx))
5112            test(f2, (inp, idx))
5113            test(f3, (inp, idx))
5114
5115    def test_nested_advanced_indexing(self, device):
5116        e = torch.rand(7, 4, device=device)
5117        idx = torch.tensor([0, 1], device=device).view(2, 1)
5118
5119        # simple reference implementation for comparison
5120        def _fake_vmap(f, in_dims=0, out_dims=0):
5121            def w(input):
5122                r = [f(input.select(in_dims, i)) for i in range(input.size(in_dims))]
5123                return torch.stack(r, out_dims)
5124
5125            return w
5126
5127        def with_vmap(_vmap):
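            # The outer map runs over the rows of `idx`; for each row, the inner
            # map indexes every column of `e` (in_dims=1) with that row.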
5128            def g(idx_):
5129                def f(e_):
5130                    return e_[idx_]
5131
5132                return _vmap(f, in_dims=1)(e)
5133
5134            r = _vmap(g)(idx)
5135            return r
5136
5137        a = with_vmap(vmap)
5138        b = with_vmap(_fake_vmap)
5139        self.assertEqual(a, b)
5140
5141    @ops(
5142        filter(lambda op: "linalg" in op.name, op_db + additional_op_db),
5143        allowed_dtypes=(torch.float,),
5144    )
5145    @skipOps(
5146        "TestVmapOperatorsOpInfo",
5147        "test_vmap_linalg_failure_1D_input",
5148        {
5149            xfail("linalg.vector_norm"),  # can accept vector inputs
5150            xfail("linalg.norm"),  # can accept vector inputs
5151            xfail("linalg.norm", "subgradients_at_zero"),  # can accept vector inputs
5152            xfail("linalg.vander"),  # can accept vector inputs
5153            skip(
5154                "linalg.multi_dot"
5155            ),  # accepts list of tensor inputs, has its own special test
5156            xfail("linalg.vecdot"),
5157            # throws in vmap on CUDA
5158            # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got -2)
5159            # https://github.com/pytorch/pytorch/runs/8110653462?check_suite_focus=true
5160            # but it passes locally
5161            xfail("linalg.diagonal"),
5162            skip("linalg.matrix_norm", ""),
5163            skip("linalg.ldl_solve", ""),
5164        },
5165    )
5166    def test_vmap_linalg_failure_1D_input(self, device, dtype, op):
5167        for sample in op.sample_inputs(device, dtype, requires_grad=False):
5168            if sample.input.dim() != 2 or sample.input.shape[0] == 0:
5169                continue
5170            test_input = sample.input[
5171                0
5172            ]  # using the sample input avoids numerical inconsistency issues
5173            with self.assertRaisesRegex(RuntimeError, "dimension"):
5174                op(test_input, *sample.args, **sample.kwargs)
5175
5176            def op_wrapper(inp):
5177                return op(inp, *sample.args, **sample.kwargs)
5178
5179            # square inputs are more likely to pass linalg checks
5180            test_input = test_input.expand(test_input.shape[0], test_input.shape[0])
5181            with self.assertRaisesRegex(RuntimeError, "dimension"):
5182                return vmap(op_wrapper)(test_input)
5183
5184    def test_vmap_multi_dot_failure_1D_input(self):
5185        # multi_dot special-cases the first and last tensors, so passing 3 items (with the 1D tensor in the middle) avoids those special cases
5186        inputs = (torch.randn(3, 3), torch.randn(3), torch.randn(3, 3))
5187        with self.assertRaisesRegex(RuntimeError, "tensor 1 must be 2D but got 1D"):
5188            torch.linalg.multi_dot(inputs)
5189
5190        # square inputs are more likely to pass linalg checks
5191        inputs = tuple(i.expand(i.shape[0], i.shape[0]) for i in inputs)
5192        with self.assertRaisesRegex(RuntimeError, "tensor 1 must be 2D but got 1D"):
5193            return vmap(torch.linalg.multi_dot)(inputs)
5194
5195    def test_vmap_escaped_error(self):
5196        escaped = None
5197
5198        def f(x):
5199            nonlocal escaped
5200            escaped = x
5201            return x**2
5202
5203        x = torch.randn([3, 3, 3, 3, 3])
5204        vmap(f)(x)
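        # `escaped` holds the input as seen inside vmap; its vmap level has
        # already exited, so any later operation on it should hit
        # vmap_check_escaped in the batching-rule entry points exercised below.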
5205
5206        common_message = (
5207            r"your tensor may have escaped from inside a function being vmapped.*{0}.*"
5208        )
5209
5210        # Note: this is not a complete set of tests for all possible functions calling 'vmap_check_escaped'
5211
5212        with self.assertRaisesRegex(
5213            RuntimeError, common_message.format("gen_vmap_plumbing")
5214        ):
5215            escaped.sin()
5216
5217        with self.assertRaisesRegex(
5218            RuntimeError, common_message.format("boxed_tensor_inputs_batch_rule")
5219        ):
5220            escaped.sin_()
5221
5222        with self.assertRaisesRegex(
5223            RuntimeError, common_message.format("gen_vmap_inplace_plumbing")
5224        ):
5225            escaped.mul_(1)
5226
5227        with self.assertRaisesRegex(
5228            RuntimeError, common_message.format("binary_cross_entropy_plumbing")
5229        ):
5230            torch.nn.functional.binary_cross_entropy(escaped, torch.zeros([3, 3, 3, 3]))
5231
5232        with self.assertRaisesRegex(
5233            RuntimeError, common_message.format("boxed_existing_bdim_all_batch_rule")
5234        ):
5235            torch.nn.functional.adaptive_max_pool2d(escaped, output_size=(1, 1))
5236
5237        with self.assertRaisesRegex(
5238            RuntimeError, common_message.format("boxed_reduction_batch_rule")
5239        ):
5240            escaped.argmin()
5241
5242        a = torch.zeros([4, 4, 4, 4])
5243        b = torch.zeros([4, 4, 4, 4], dtype=torch.long)
5244        with self.assertRaisesRegex(
5245            RuntimeError, common_message.format("boxed_all_tensors_have_optional_bdim")
5246        ):
5247            torch.ops.aten.adaptive_max_pool2d_backward(escaped, a, b)
5248
5249        vmap(f)(torch.tensor([[0, 0], [0, 0]], dtype=torch.int))
5250        with self.assertRaisesRegex(
5251            RuntimeError, common_message.format("gen_vmap_plumbing_no_returns")
5252        ):
5253            torch.ops.aten._linalg_check_errors(escaped, "linalg.inv", is_matrix=False)
5254
5255    def test_vmap_with_anomaly_detection(self):
5256        with torch.autograd.set_detect_anomaly(True):
5257            x = torch.zeros(3) - 1
5258
5259            def fn(x):
5260                return x.sum()
5261
5262            per_sample_grad = vmap(grad(fn))(x)
5263            self.assertEqual(per_sample_grad, torch.ones_like(x))
5264
5265            def bad_fn(x):
5266                return x.sqrt().sum()
5267
5268            err_msg = "Function 'SqrtBackward0' returned nan values in its 0th output."
5269            with self.assertRaisesRegex(RuntimeError, err_msg):
5270                vmap(grad(bad_fn))(x)
5271
5272    def test_searchsorted_bucketize(self, device):
5273        # The OpInfo tests generate inputs with repeated samples along the batch dim,
5274        # so here we explicitly test with different samples across the batch.
5275
5276        def test():
5277            boundaries = torch.tensor(
5278                [[1, 4, 5, 7, 9], [1, 2, 6, 8, 10]], device=device
5279            )
5280            v = torch.tensor(3, device=device)
5281            self.vmap_outplace_test(torch.searchsorted, (boundaries, v), {}, (0, None))
5282            self.vmap_outplace_test(torch.bucketize, (v, boundaries), {}, (None, 0))
5283            boundaries = torch.tensor([[1, 4, 5, 7, 9], [1, 2, 4, 8, 9]], device=device)
5284            v = torch.tensor([3, 4], device=device)
5285            self.vmap_outplace_test(torch.searchsorted, (boundaries, v), {}, (0, 0))
5286            self.vmap_outplace_test(torch.bucketize, (v, boundaries), {}, (0, 0))
5287
5288        test()
5289
5290
5291@markDynamoStrictTest
5292class TestRandomness(TestCase):
5293    def _reset_random(self, generator, orig_state, use_generator, seed):
5294        return (
5295            generator.set_state(orig_state)
5296            if use_generator
5297            else torch.manual_seed(seed)
5298        )
5299
5300    def _get_image(self, batched_input, batch_size, device):
5301        if batched_input == "first":
5302            return torch.ones([batch_size, 3, 3, 14, 14], device=device)
5303        if batched_input == "last":
5304            return torch.ones([3, 3, 14, 14, batch_size], device=device)
5305        assert batched_input == "none"
5306        return torch.ones([3, 3, 14, 14], device=device)
5307
5308    def _assert_all_slices_equal(self, tensor):
5309        expected = tensor[0]
5310        self.assertTrue((tensor == expected).all())
5311
5312    def _assert_all_slices_unique(self, tensor):
5313        B0 = tensor.shape[0]
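        # The nested vmap builds a (B0, B0) matrix whose (i, j) entry records
        # whether slice i equals slice j; after zeroing the diagonal, everything
        # must be False for the slices to be pairwise distinct.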
5314        slices_equal = vmap(vmap(lambda x, y: (x == y).all(), (0, None)), (None, 0))(
5315            tensor, tensor
5316        )
5317        assert slices_equal.shape == (B0, B0)
5318        slices_equal.diagonal().zero_()
5319        self.assertEqual(slices_equal, torch.zeros_like(slices_equal))
5320
5321    def _assert_throws_in_error_mode(self, fn, args, in_dims):
5322        with self.assertRaisesRegex(
5323            RuntimeError, r"called random operation while in randomness error mode"
5324        ):
5325            vmap(fn, in_dims=in_dims, randomness="error")(*args)
5326
5327    def _assert_throws_in_different_mode_inplace(self, fn, args, in_dims):
5328        with self.assertRaisesRegex(
5329            RuntimeError, r"different inplace randomness on an unbatched tensor"
5330        ):
5331            vmap(fn, in_dims=in_dims, randomness="different")(*args)
5332
5333    def _assert_throws_in_same_mode_batched(self, fn, args, in_dims):
5334        with self.assertRaisesRegex(
5335            RuntimeError,
5336            r"Vmap does not currently support same randomness with a batched tensor input",
5337        ):
5338            vmap(fn, in_dims=in_dims, randomness="same")(*args)
5339
5340    def _in_dims(self, *batched_strings):
5341        def get_in_dim(batched_string):
5342            if batched_string == "first":
5343                return 0
5344            if batched_string == "last":
5345                return -1
5346            assert batched_string == "none"
5347            return None
5348
5349        batched_strings = batched_strings + (
5350            "first",
5351        )  # for the dummy argument that is always batched along the first dim
5352        return tuple(get_in_dim(batched_string) for batched_string in batched_strings)
5353
5354    @parametrize("randomness", ["same", "different", "error"])
5355    @parametrize("use_generator", [True, False])
5356    def test_factory_ops(self, device, randomness, use_generator):
5357        generator = torch.Generator(device=device)
5358        orig_state = generator.get_state()
5359        kwargs = (
5360            {"device": device, "generator": generator}
5361            if use_generator
5362            else {"device": device}
5363        )
5364        ops = [
5365            lambda _, shape: torch.randn(shape, **kwargs),
5366            lambda _, shape: torch.rand(shape, **kwargs),
5367            lambda _, shape: torch.randint(100, shape, **kwargs),
5368            lambda _, shape: torch.randint(5, 100, shape, **kwargs),
5369            lambda _, shape: torch.normal(0.0, 1.0, shape, **kwargs),
5370        ]
5371        B0 = 4
5372        shape = (3, 3)
5373        seed = 1234567
5374
5375        for op in ops:
5376            passed = torch.randn(B0, device=device)
5377            if randomness == "error":
5378                self._assert_throws_in_error_mode(
5379                    op, (passed, shape), in_dims=(0, None)
5380                )
5381                return
5382
5383            generator = self._reset_random(generator, orig_state, use_generator, seed)
5384            vmap_result = vmap(op, in_dims=(0, None), randomness=randomness)(
5385                passed, shape
5386            )
5387
5388            generator = self._reset_random(generator, orig_state, use_generator, seed)
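            # "different": vmap should match a single draw of shape (B0, *shape);
            # "same": every slice should match one unbatched draw of `shape`.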
5389            if randomness == "different":
5390                expected = op(passed, [B0, *shape])
5391                self._assert_all_slices_unique(vmap_result)
5392                self.assertEqual(vmap_result, expected)
5393            else:
5394                expected = op(passed, shape)
5395                self._assert_all_slices_equal(vmap_result)
5396                for i in range(B0):
5397                    self.assertEqual(vmap_result[i], expected)
5398
5399    @parametrize("randomness", ["same", "different", "error"])
5400    @parametrize("use_generator", [True, False])
5401    def test_randperm(self, device, randomness, use_generator):
5402        # needs a special case because randperm doesn't take a batch size
5403        B0 = 4
5404        seed = 1234567
5405        passed = torch.randn(B0, device=device)
5406
5407        torch.manual_seed(seed)
5408        generator = torch.Generator(device=device)
5409        orig_state = generator.get_state()
5410
5411        kwargs = (
5412            {"device": device, "generator": generator}
5413            if use_generator
5414            else {"device": device}
5415        )
5416
5417        if randomness == "error":
5418            with self.assertRaisesRegex(
5419                RuntimeError, r"called random operation while in randomness error mode"
5420            ):
5421                vmap(lambda _: torch.randperm(10, **kwargs), randomness=randomness)(
5422                    passed
5423                )
5424            return
5425
5426        vmap_result = vmap(
5427            lambda _: torch.randperm(10, **kwargs), randomness=randomness
5428        )(passed)
5429        generator = generator.set_state(orig_state)
5430        torch.manual_seed(seed)
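        # "different": each batch element gets its own permutation drawn in
        # sequence from the restored RNG state; "same": all elements share one.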
5431        if randomness == "different":
5432            for i in range(B0):
5433                expected = torch.randperm(10, **kwargs)
5434                # RNG differs between eager and dynamo-traced execution on CUDA
5435                if TEST_WITH_TORCHDYNAMO and torch.device(device).type == "cuda":
5436                    self._assert_all_slices_unique(vmap_result)
5437                else:
5438                    self.assertEqual(vmap_result[i], expected)
5439        else:
5440            expected = torch.randperm(10, **kwargs)
5441            # RNG differs between eager and dynamo-traced execution on CUDA
5442            if TEST_WITH_TORCHDYNAMO and torch.device(device).type == "cuda":
5443                self._assert_all_slices_equal(vmap_result)
5444            else:
5445                for i in range(B0):
5446                    self.assertEqual(vmap_result[i], expected)
5447
5448    @parametrize("randomness", ["error", "same", "different"])
5449    @parametrize("batched_input", ["first", "last", "none"])
5450    def test_dropout(self, device, randomness, batched_input):
5451        def op(t, ignored):
5452            return torch.nn.functional.dropout(torch.ones_like(t), training=True)
5453
5454        B0 = 4
5455        always_batched = torch.randn((B0,))
5456        passed = self._get_image(batched_input, B0, device)
5457        in_dims = self._in_dims(batched_input)
5458
5459        if randomness == "error":
5460            with self.assertRaisesRegex(
5461                RuntimeError, r"called random operation while in randomness error mode"
5462            ):
5463                vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
5464            return
5465
5466        vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(
5467            passed, always_batched
5468        )
5469
5470        # Sanity-check the randomness: p_estimate approximates the keep
5471        # probability, which should be close to 0.5 for p=0.5 dropout.
5472        p_estimate = vmap_result.mean() / 2
5473        self.assertTrue(p_estimate < 0.75)
5474        self.assertTrue(p_estimate > 0.25)
5475
5476        if randomness == "different":
5477            self._assert_all_slices_unique(vmap_result)
5478            return
5479
5480        assert randomness == "same"
5481        self._assert_all_slices_equal(vmap_result)
5482
5483    @parametrize("randomness", ["error", "same", "different"])
5484    @parametrize("batched_input", ["first", "last", "none"])
5485    def test_alpha_dropout(self, device, randomness, batched_input):
5486        def op(t, ignored):
5487            return torch.nn.functional.alpha_dropout(torch.ones_like(t), training=True)
5488
5489        B0 = 4
5490        always_batched = torch.randn((B0,))
5491        passed = self._get_image(batched_input, B0, device)
5492        in_dims = self._in_dims(batched_input)
5493
5494        if randomness == "error":
5495            with self.assertRaisesRegex(
5496                RuntimeError, r"called random operation while in randomness error mode"
5497            ):
5498                vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
5499            return
5500
5501        # I have no clue how to actually test correctness of alpha dropout because the docs
5502        # seem wrong: https://github.com/pytorch/pytorch/issues/74004
5503        vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(
5504            passed, always_batched
5505        )
5506        if randomness == "different":
5507            self._assert_all_slices_unique(vmap_result)
5508            return
5509
5510        assert randomness == "same"
5511        self._assert_all_slices_equal(vmap_result)
5512
5513    @parametrize("randomness", ["error", "same", "different"])
5514    @parametrize("batched_input", ["first", "last", "none"])
5515    @parametrize("dim", [2, 3])
5516    def test_feature_dropout(self, device, randomness, batched_input, dim):
5517        def op(t, ignored):
5518            f = (
5519                torch.nn.functional.dropout2d
5520                if dim == 2
5521                else torch.nn.functional.dropout3d
5522            )
5523            return f(torch.ones_like(t), training=True)
5524
5525        B0 = 4
5526        always_batched = torch.randn((B0,))
5527        passed = self._get_image(batched_input, B0, device)
5528        if dim == 3:
5529            unsqueeze_dim = -2 if batched_input == "last" else -1
5530            passed = passed.unsqueeze(unsqueeze_dim)
5531        in_dims = self._in_dims(batched_input)
5532
5533        if randomness == "error":
5534            with self.assertRaisesRegex(
5535                RuntimeError, r"called random operation while in randomness error mode"
5536            ):
5537                vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
5538            return
5539
5540        vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(
5541            passed, always_batched
5542        )
5543
5544        # Check the "feature" pattern
5545        dims = [-1, -2] if dim == 2 else [-1, -2, -3]
5546        planes_numel = (
5547            2
5548            * vmap_result.numel()
5549            / (vmap_result.shape[0] * vmap_result.shape[1] * vmap_result.shape[2])
5550        )
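        # Feature dropout zeroes whole channel planes. With an all-ones input and
        # p=0.5, a kept plane sums to (#elements) / (1 - p) == planes_numel and a
        # dropped plane sums to 0, so every plane must hit exactly one of the two.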
5551        planes = vmap_result.sum(dims)
5552        result = (planes == 0) ^ (planes == planes_numel)
5553        self.assertEqual(result, torch.ones_like(result, dtype=torch.bool))
5554
5555        if randomness == "different":
5556            self._assert_all_slices_unique(vmap_result)
5557            return
5558
5559        assert randomness == "same"
5560        self._assert_all_slices_equal(vmap_result)
5561
5562    @parametrize("randomness", ["error", "same", "different"])
5563    @parametrize("batched_input", ["first", "last", "none"])
5564    def test_feature_alpha_dropout(self, device, randomness, batched_input):
5565        def op(t, ignored):
5566            return torch.nn.functional.feature_alpha_dropout(
5567                torch.ones_like(t), training=True
5568            )
5569
5570        B0 = 4
5571        always_batched = torch.randn((B0,))
5572        passed = self._get_image(batched_input, B0, device)
5573        unsqueeze_dim = -2 if batched_input == "last" else -1
5574        passed = passed.unsqueeze(unsqueeze_dim)
5575        in_dims = self._in_dims(batched_input)
5576
5577        if randomness == "error":
5578            with self.assertRaisesRegex(
5579                RuntimeError, r"called random operation while in randomness error mode"
5580            ):
5581                vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
5582            return
5583
5584        vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(
5585            passed, always_batched
5586        )
5587
5588        # I have no clue how to actually test correctness of alpha dropout because the docs
5589        # seem wrong: https://github.com/pytorch/pytorch/issues/74004
5590
5591        # Check the "feature" pattern
5592        dims = [-1, -2, -3]
5593        planes = vmap_result.sum(dims)
5594        max_elt = planes.max()
5595        min_elt = planes.min()
5596        result = (planes == min_elt) ^ (planes == max_elt)
5597        self.assertEqual(result, torch.ones_like(result, dtype=torch.bool))
5598
5599        if randomness == "different":
5600            self._assert_all_slices_unique(vmap_result)
5601            return
5602
5603        assert randomness == "same"
5604        self._assert_all_slices_equal(vmap_result)
5605
5606    @parametrize("randomness", ["error", "same", "different"])
5607    @parametrize("batched_input", ["first", "last", "none"])
5608    def test_like_functions(self, device, randomness, batched_input):
5609        seed = 1234567
5610        supported_ops = [
5611            lambda t, _: torch.randint_like(t, 20),
5612            lambda t, _: torch.randint_like(t, 0, 20),
5613            lambda t, _: torch.rand_like(t),
5614            lambda t, _: torch.randn_like(t),
5615        ]
5616        B0 = 4
5617
5618        for op in supported_ops:
5619            always_batched = torch.randn(B0)
5620            passed = self._get_image(batched_input, B0, device)
5621            in_dims = self._in_dims(batched_input)
5622
5623            if randomness == "error":
5624                with self.assertRaisesRegex(
5625                    RuntimeError,
5626                    r"called random operation while in randomness error mode",
5627                ):
5628                    vmap(op, in_dims=in_dims, randomness=randomness)(
5629                        passed, always_batched
5630                    )
5631                return
5632
5633            torch.manual_seed(seed)
5634            vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(
5635                passed, always_batched
5636            )
5637
5638            torch.manual_seed(seed)
5639
5640            if batched_input == "last":
5641                passed = passed.movedim(-1, 0)
5642            if randomness == "different":
5643                if batched_input == "none":
5644                    passed = passed.expand(B0, *passed.shape)
5645                expected = op(passed, 0)
5646
5647                self._assert_all_slices_unique(vmap_result)
5648                # RNG differs between eager mode and the dynamo trace on CUDA
5649                if not (TEST_WITH_TORCHDYNAMO and torch.device(device).type == "cuda"):
5650                    self.assertEqual(expected, vmap_result)
5651                return
5652
5653            assert randomness == "same"
5654            if batched_input != "none":
5655                passed = passed[0]
5656            expected = op(passed, 0)
5657            self._assert_all_slices_equal(vmap_result)
5658            # RNG differs between eager mode and the dynamo trace on CUDA
5659            if not (TEST_WITH_TORCHDYNAMO and torch.device(device).type == "cuda"):
5660                for i in range(B0):
5661                    self.assertEqual(expected, vmap_result[i])
5662
5663    @parametrize("use_generator", [True, False])
5664    @parametrize("randomness", ["error", "same", "different"])
5665    @parametrize("batched_input", ["first", "last", "none"])
5666    def test_random_unary_inplace(
5667        self, device, use_generator, randomness, batched_input
5668    ):
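        # Added note: in-place random unary ops (random_, normal_, bernoulli_, ...).
        # The vmap result is compared against an eager call with the RNG reset to
        # the same state; "different" randomness with an unbatched input is
        # expected to throw.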
5669        generator = torch.Generator(device=device)
5670        orig_state = generator.get_state()
5671        kwargs = {"generator": generator} if use_generator else {}
5672        ops = [
5673            lambda t, _: t.random_(**kwargs),
5674            lambda t, _: t.random_(100, **kwargs),
5675            lambda t, _: t.random_(-5, 100, **kwargs),
5676            lambda t, _: t.normal_(**kwargs),
5677            lambda t, _: t.bernoulli_(**kwargs),
5678            lambda t, _: t.cauchy_(**kwargs),
5679            lambda t, _: t.exponential_(**kwargs),
5680            lambda t, _: t.geometric_(0.5, **kwargs),
5681            lambda t, _: t.log_normal_(**kwargs),
5682            lambda t, _: t.uniform_(**kwargs),
5683        ]
5684        B0 = 4
5685        seed = 1234567
5686        in_dims = self._in_dims(batched_input)
5687
5688        for op in ops:
5689            # Because of in-place updates, clone the inputs.
5690            always_batched = torch.randn(B0, device=device)
5691            passed = self._get_image(batched_input, B0, device)
5692            passed_expected = passed.clone()
5693
5694            if randomness == "error":
5695                self._assert_throws_in_error_mode(
5696                    op, (passed, always_batched), in_dims=in_dims
5697                )
5698                return
5699            if randomness == "different" and batched_input == "none":
5700                self._assert_throws_in_different_mode_inplace(
5701                    op, (passed, always_batched), in_dims=in_dims
5702                )
5703                return
5704
5705            generator = self._reset_random(generator, orig_state, use_generator, seed)
5706            vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(
5707                passed, always_batched
5708            )
5709
5710            if batched_input == "last":
5711                passed_expected = passed_expected.movedim(-1, 0)
5712            generator = self._reset_random(generator, orig_state, use_generator, seed)
5713            if randomness == "different":
5714                expected = op(passed_expected, always_batched)
5715                self._assert_all_slices_unique(vmap_result)
5716                self.assertEqual(vmap_result, expected)
5717            else:
5718                if batched_input != "none":
5719                    # bug in pytorch: normal_ on views doesn't work,
5720                    # so clone the slice before applying the in-place op
5721                    passed_expected = passed_expected[0].clone()
5722                expected = op(passed_expected, always_batched)
5723                self._assert_all_slices_equal(vmap_result)
5724                for i in range(B0):
5725                    self.assertEqual(vmap_result[i], expected)
5726
5727    @parametrize("use_generator", [True, False])
5728    @parametrize("randomness", ["error", "same", "different"])
5729    @parametrize("batched_input", ["first", "last", "none"])
5730    @parametrize("batched_probability", ["first", "last", "none"])
5731    def test_bernoulli_in_place(
5732        self, device, use_generator, randomness, batched_input, batched_probability
5733    ):
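        # Added note: bernoulli_ with a (possibly batched) probability tensor.
        # Covers the error cases (error mode, "same" with batched probability,
        # unbatched self with batched probability, "different" with unbatched
        # self) and otherwise compares against an eager call with the RNG reset.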
5734        B0 = 4
5735        seed = 1234567
5736        generator = torch.Generator(device=device)
5737        orig_state = generator.get_state()
5738        kwargs = {"generator": generator} if use_generator else {}
5739        in_dims = self._in_dims(batched_input, batched_probability)
5740
5741        def op(t, p, ignored):
5742            return t.bernoulli_(p, **kwargs)
5743
5744        # Because of in-place updates, clone the inputs.
5745        always_batched = torch.randn(B0, device=device)
5746        input = self._get_image(batched_input, B0, device)
5747        input_expected = input.clone()
5748        probability = self._get_image(batched_probability, B0, device) - 0.5
5749
5750        if randomness == "error":
5751            self._assert_throws_in_error_mode(
5752                op, (input, probability, always_batched), in_dims=in_dims
5753            )
5754            return
5755        if randomness == "same" and batched_probability != "none":
5756            self._assert_throws_in_same_mode_batched(
5757                op, (input, probability, always_batched), in_dims=in_dims
5758            )
5759            return
5760        if batched_input == "none" and batched_probability != "none":
5761            regex = r"there exists a Tensor `other` in extra_args that has more elements than `self`"
5762            with self.assertRaisesRegex(RuntimeError, regex):
5763                vmap(op, in_dims=in_dims, randomness=randomness)(
5764                    input, probability, always_batched
5765                )
5766            return
5767        if randomness == "different" and batched_input == "none":
5768            self._assert_throws_in_different_mode_inplace(
5769                op, (input, probability, always_batched), in_dims=in_dims
5770            )
5771            return
5772
5773        self._reset_random(generator, orig_state, use_generator, seed)
5774        vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(
5775            input, probability, always_batched
5776        )
5777
5778        self._reset_random(generator, orig_state, use_generator, seed)
5779        if batched_input == "last":
5780            input_expected = input_expected.movedim(-1, 0)
5781        if batched_probability == "last":
5782            probability = probability.movedim(-1, 0)
5783        if randomness == "different":
5784            expected = op(input_expected, probability, always_batched)
5785            self._assert_all_slices_unique(vmap_result)
5786            self.assertEqual(vmap_result, expected)
5787        else:
5788            if batched_input != "none":
5789                input_expected = input_expected[0]
5790            expected = op(input_expected, probability, always_batched)
5791            self._assert_all_slices_equal(vmap_result)
5792            for i in range(B0):
5793                self.assertEqual(vmap_result[i], expected)
5794
5795    @parametrize("use_generator", [True, False])
5796    @parametrize("randomness", ["error", "same", "different"])
5797    @parametrize("batched_input", ["first", "last", "none"])
5798    @parametrize("batched_other", ["first", "last", "none"])
5799    def test_random_binary_out_of_place(
5800        self, device, use_generator, randomness, batched_input, batched_other
5801    ):
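        # Added note: out-of-place binary random ops (normal, binomial). "same"
        # randomness with any batched argument should throw; otherwise the vmap
        # result should match an eager call after resetting the generator /
        # global RNG state.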
5802        generator = torch.Generator(device=device)
5803        orig_state = generator.get_state()
5804        kwargs = {"generator": generator} if use_generator else {}
5805        ops = [
5806            lambda t, o, _: torch.normal(t, o, **kwargs),
5807            lambda t, o, _: torch.binomial(t, (o - 0.5), **kwargs),
5808        ]
5809
5810        B0 = 4
5811        seed = 1234567
5812        in_dims = self._in_dims(batched_input, batched_other)
5813
5814        for op in ops:
5815            always_batched = torch.randn(B0, device=device)
5816            input = self._get_image(batched_input, B0, device)
5817            other = self._get_image(batched_other, B0, device)
5818
5819            if randomness == "error":
5820                self._assert_throws_in_error_mode(
5821                    op, (input, other, always_batched), in_dims=in_dims
5822                )
5823                return
5824            if randomness == "same" and (
5825                batched_input != "none" or batched_other != "none"
5826            ):
5827                self._assert_throws_in_same_mode_batched(
5828                    op, (input, other, always_batched), in_dims=in_dims
5829                )
5830                return
5831
5832            generator = self._reset_random(generator, orig_state, use_generator, seed)
5833            vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(
5834                input, other, always_batched
5835            )
5836
5837            if batched_input == "last":
5838                input = input.movedim(-1, 0)
5839            if batched_other == "last":
5840                other = other.movedim(-1, 0)
5841
5842            generator = self._reset_random(generator, orig_state, use_generator, seed)
5843            if randomness == "different":
5844                if batched_input == "none":
5845                    input = input.expand(B0, *input.shape)
5846                expected = op(input, other, always_batched)
5847                self._assert_all_slices_unique(vmap_result)
5848                self.assertEqual(vmap_result, expected)
5849            else:
5850                assert batched_input == "none" and batched_other == "none"
5851                expected = op(input, other, always_batched)
5852                self._assert_all_slices_equal(vmap_result)
5853                for i in range(B0):
5854                    self.assertEqual(vmap_result[i], expected)
5855
5856    @parametrize("use_generator", [True, False])
5857    @parametrize("randomness", ["error", "same", "different"])
5858    @parametrize("batched_input", ["first", "last", "none"])
5859    def test_random_unary_out_of_place(
5860        self, device, use_generator, randomness, batched_input
5861    ):
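        # Added note: out-of-place unary random ops (normal, bernoulli, poisson,
        # ...). "same" randomness with a batched input should throw; otherwise
        # compare against an eager call with the RNG reset to the same state.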
5862        generator = torch.Generator(device=device)
5863        orig_state = generator.get_state()
5864        kwargs = {"generator": generator} if use_generator else {}
5865        ops = [
5866            lambda t, _: torch.normal(0.0, torch.abs(t), **kwargs),
5867            lambda t, _: torch.normal(t, 1.0, **kwargs),
5868            lambda t, _: torch.bernoulli(t - 0.5, **kwargs),
5869            lambda t, _: torch.bernoulli(t, 0.5, **kwargs),
5870            lambda t, _: torch._standard_gamma(t, **kwargs),
5871            lambda t, _: torch._sample_dirichlet(t, **kwargs),
5872            lambda t, _: torch.poisson(t, **kwargs),
5873        ]
5874
5875        B0 = 4
5876        seed = 1234567
5877        in_dims = self._in_dims(batched_input)
5878
5879        for op in ops:
5880            always_batched = torch.randn(B0, device=device)
5881            passed = self._get_image(batched_input, B0, device)
5882            if randomness == "error":
5883                self._assert_throws_in_error_mode(
5884                    op, (passed, always_batched), in_dims=in_dims
5885                )
5886                return
5887            if randomness == "same" and batched_input != "none":
5888                self._assert_throws_in_same_mode_batched(
5889                    op, (passed, always_batched), in_dims=in_dims
5890                )
5891                return
5892
5893            generator = self._reset_random(generator, orig_state, use_generator, seed)
5894            vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(
5895                passed, always_batched
5896            )
5897
5898            generator = self._reset_random(generator, orig_state, use_generator, seed)
5899            if randomness == "different":
5900                if batched_input == "none":
5901                    passed = passed.expand(B0, *passed.shape)
5902                if batched_input == "last":
5903                    passed = passed.movedim(-1, 0)
5904                expected = op(passed, always_batched)
5905                self._assert_all_slices_unique(vmap_result)
5906                self.assertEqual(vmap_result, expected)
5907            else:
5908                expected = op(passed, always_batched)
5909                self._assert_all_slices_equal(vmap_result)
5910                for i in range(B0):
5911                    self.assertEqual(vmap_result[i], expected)
5912
5913    @parametrize("use_generator", [True, False])
5914    @parametrize("randomness", ["error", "same", "different"])
5915    @parametrize("batched_call", [True, False])
5916    @parametrize("batched_input", ["first", "last", "none"])
5917    def test_multinomial(
5918        self, device, use_generator, randomness, batched_call, batched_input
5919    ):
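        # Added note: multinomial needs 1D or 2D probabilities, so flatten_input
        # collapses the image input to [N], [B, N], or [B0, B, N] depending on
        # whether the call and/or the input are batched; the vmap result is then
        # checked against an eager call on the appropriately reshaped input.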
5920        def flatten_input(input, batch_call, batch_location):
5921            if batch_call and batch_location != "none":
5922                final_size = 3  # [B0, B, N]
5923            elif not batch_call and batch_location == "none":
5924                final_size = 1  # [N]
5925            else:
5926                final_size = 2  # [B0, N] or [B, N]
5927
5928            start_idx = final_size - 1
5929            end_idx = -1
5930            if batch_location == "last":
5931                start_idx -= 1
5932                # decrementing end_idx (a negative index) excludes the trailing
5933                # batch dim and still gives the correct final size
5934                end_idx -= 1
5935
5936            ret = input.flatten(start_idx, end_idx)
5937            assert ret.dim() == final_size
5938            return ret
5939
5940        def op(input, _):
5941            return torch.multinomial(input, 10, **kwargs)
5942
5943        generator = torch.Generator(device=device)
5944        orig_state = generator.get_state()
5945        kwargs = {"generator": generator} if use_generator else {}
5946
5947        B0 = 4
5948        seed = 1234567
5949        in_dims = self._in_dims(batched_input)
5950
5951        always_batched = torch.randn(B0, device=device)
5952        passed = self._get_image(batched_input, B0, device)
5953        passed = flatten_input(passed, batched_call, batched_input)
5954        if randomness == "error":
5955            self._assert_throws_in_error_mode(
5956                op, (passed, always_batched), in_dims=in_dims
5957            )
5958            return
5959        if randomness == "same" and batched_input != "none":
5960            self._assert_throws_in_same_mode_batched(
5961                op, (passed, always_batched), in_dims=in_dims
5962            )
5963            return
5964
5965        generator = self._reset_random(generator, orig_state, use_generator, seed)
5966        vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(
5967            passed, always_batched
5968        )
5969
5970        generator = self._reset_random(generator, orig_state, use_generator, seed)
5971
5972        if randomness == "different":
5973            if batched_input == "none":
5974                passed = passed.expand(B0, *passed.shape)
5975            if batched_input == "last":
5976                passed = passed.movedim(-1, 0)
5977            orig_passed_size = passed.shape[:2] if batched_call else passed.shape[:1]
5978            passed = passed.flatten(0, 1) if batched_call else passed
5979            expected = op(passed, always_batched)
5980            expected = expected.reshape(*orig_passed_size, 10)
5981            self._assert_all_slices_unique(vmap_result)
5982            self.assertEqual(vmap_result, expected)
5983        else:
5984            expected = op(passed, always_batched)
5985            self._assert_all_slices_equal(vmap_result)
5986            for i in range(B0):
5987                self.assertEqual(vmap_result[i], expected)
5988
5989    def test_unsupported_random(self, device):
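        # Added note: out= variants of random ops and rrelu are not supported
        # under vmap randomness; the inner functions below use the closed-over
        # tensors (x, y, z) rather than their vmapped arguments.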
5990        x = torch.randn(3, device=device)
5991        y = x.abs()
5992        z = x.abs()
5993        with self.assertRaisesRegex(RuntimeError, "calling out variants"):
5994
5995            def f(x):
5996                return torch.randn(3, device=device, out=y)
5997
5998            vmap(f, randomness="same")(x)
5999        with self.assertRaisesRegex(RuntimeError, "calling out variants"):
6000
6001            def f(x0, x1):
6002                return torch.normal(x, y, out=x)
6003
6004            vmap(f, randomness="same")(z, z)
6005        with self.assertRaisesRegex(RuntimeError, "do not yet support"):
6006
6007            def f(z):
6008                return torch.rrelu(x)
6009
6010            vmap(f, randomness="same")(z)
6011
6012    @parametrize("in_dim", [0, 1, 2])
6013    @parametrize("out_dim", [0, 1, 2])
6014    def test_chunk_vmap(self, in_dim, out_dim):
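        # Added note: chunk_vmap with "different" randomness should still produce
        # unique values for every slice, regardless of the chunk count (including
        # counts larger than the batch size).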
6015        randomness = "different"
6016
6017        x = torch.randn(4, 5, 6)
6018
6019        def f(x):
6020            y = x.sin() + torch.rand_like(x)
6021            return y
6022
6023        for chunks in [1, 2, 3, 4, 7, 10, 16]:
6024            output = chunk_vmap(
6025                f,
6026                in_dims=in_dim,
6027                out_dims=out_dim,
6028                randomness=randomness,
6029                chunks=chunks,
6030            )(x)
6031            self._assert_all_slices_unique(output)
6032
6033    @parametrize("in_dim", [0, 1, 2])
6034    @parametrize("out_dim", [0, 1, 2])
6035    def test_vmap_chunksize(self, in_dim, out_dim):
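        # Added note: same check as above, but through vmap's chunk_size argument
        # instead of chunk_vmap; slices must stay unique under "different"
        # randomness.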
6036        randomness = "different"
6037
6038        x = torch.randn(4, 5, 6)
6039
6040        def f(x):
6041            y = x.sin() + torch.rand_like(x)
6042            return y
6043
6044        for chunk_size in [1, 2, 3, 4, 7, 10, 16, 100]:
6045            output = vmap(
6046                f,
6047                in_dims=in_dim,
6048                out_dims=out_dim,
6049                randomness=randomness,
6050                chunk_size=chunk_size,
6051            )(x)
6052            self._assert_all_slices_unique(output)
6053
6054    def test_jacfwd_with_random(self):
6055        # Behavior checks are above; this just checks that jacfwd respects
6056        # the randomness param.
6057
6058        x = torch.rand(3, 4)
6059        with self.assertRaisesRegex(
6060            RuntimeError, r"called random operation while in randomness error mode"
6061        ):
6062            jacfwd(torch.bernoulli)(x)
6063
6064        # x isn't batched, so use bernoulli since it doesn't do in-place randomness
6065        jacfwd(torch.bernoulli, randomness="same")(x)
6066        jacfwd(torch.bernoulli, randomness="different")(x)
6067
6068    @parametrize("randomness", ["error", "same", "different"])
6069    def test_dropout_unbatched(self, device, randomness):
6070        x = torch.randn(3, device=device)
6071        y = torch.randn(1, 3, device=device)
6072
6073        def fn(x, y):
6074            # output from dropout should be a Tensor[B, 1, 3] (B=3)
6075            return x + torch.nn.functional.dropout(y, p=0.5).mean(1)
6076
6077        # We just verify that this doesn't raise an error for
6078        # `same` and `different` randomness.
6079        # Ref: https://github.com/pytorch/pytorch/issues/92283
6080        context = (
6081            self.assertRaises(RuntimeError)
6082            if randomness == "error"
6083            else contextlib.nullcontext()
6084        )
6085        with context:
6086            vmap(fn, in_dims=(0, None), randomness=randomness)(x, y)
6087
6088
6089@markDynamoStrictTest
6090class TestTransformFailure(TestCase):
6091    @skipIfTorchDynamo()
6092    @parametrize(
6093        "transform",
6094        ["vmap", "grad", "grad_and_value", "vjp", "jvp", "jacrev", "jacfwd"],
6095    )
6096    def test_fails_with_autograd_function(self, device, transform):
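        # Added note: a plain autograd.Function (forward/backward only, no vmap
        # support) should raise an "autograd.Function" error under each of the
        # parametrized functorch transforms.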
6097        failed_build_envs = ("linux-focal-py3.8-clang10", "linux-focal-py3.11-clang10")
6098        if (
6099            device == "cpu"
6100            and transform in ["grad", "vmap"]
6101            and TEST_WITH_TORCHDYNAMO
6102            and os.getenv("BUILD_ENVIRONMENT", "") in failed_build_envs
6103        ):
6104            raise unittest.SkipTest(
6105                "Unexpected successes on focal with dynamo,"
6106                + " see https://github.com/pytorch/pytorch/issues/107173"
6107            )
6108
6109        class Test(torch.autograd.Function):
6110            @staticmethod
6111            def forward(_, input):
6112                return input
6113
6114            @staticmethod
6115            def backward(_, grad_input):
6116                return grad_input
6117
6118        transform = getattr(functorch, transform)
6119
6120        def f(x):
6121            return Test.apply(x)
6122
6123        if transform in (grad, grad_and_value):
6124            input = torch.tensor(4.0)
6125        else:
6126            input = torch.randn(5)
6127
6128        if transform == vjp:
6129            transform = functools.partial(transform, f)
6130        elif transform == jvp:
6131            input = (input,)
6132            transform = functools.partial(transform, f, input)
6133        else:
6134            transform = transform(f)
6135
6136        with self.assertRaisesRegex(RuntimeError, "autograd.Function"):
6137            transform(input)
6138
6139
6140@markDynamoStrictTest
6141class TestVmapDeviceType(Namespace.TestVmapBase):
6142    def _vmap_test(self, *args, **kwargs):
6143        return _vmap_test(self, *args, **kwargs)
6144
6145    def test__is_all_true(self, device):
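        # Added note: aten._is_all_true on a batched tensor should return an
        # unbatched scalar bool that reflects the entire underlying tensor.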
6146        def test():
6147            def f(x, *, expected_result):
6148                result = torch.ops.aten._is_all_true(x)
6149                self.assertFalse(torch._C._functorch.is_batchedtensor(result))
6150                self.assertEqual(result.shape, torch.Size([]))
6151                self.assertEqual(result.item(), expected_result)
6152                return result
6153
6154            x = torch.rand(10, device=device)
6155            vmap(f)(x >= 0, expected_result=True)
6156            vmap(f)(x < 0, expected_result=False)
6157
6158            x[random.choice(range(10))] *= -1
6159            vmap(f)(x >= 0, expected_result=False)
6160            vmap(f)(x < 0, expected_result=False)
6161
6162            x = -torch.rand(10, device=device)
6163            vmap(f)(x > 0, expected_result=False)
6164            vmap(f)(x <= 0, expected_result=True)
6165
6166        check_vmap_fallback(self, test, torch._is_all_true)
6167
6168    def test__is_any_true(self, device):
6169        def test():
6170            def f(x, *, expected_result):
6171                result = torch.ops.aten._is_any_true(x)
6172                self.assertFalse(torch._C._functorch.is_batchedtensor(result))
6173                self.assertEqual(result.shape, torch.Size([]))
6174                self.assertEqual(result.item(), expected_result)
6175                return result
6176
6177            x = torch.zeros(10, device=device, dtype=torch.bool)
6178            vmap(f)(x > 0, expected_result=False)
6179
6180            x[5] = True
6181            vmap(f)(x > 0, expected_result=True)
6182            vmap(f)(x[1::2], expected_result=True)
6183            vmap(f)(x[0::2], expected_result=False)
6184
6185        check_vmap_fallback(self, test, torch._is_any_true)
6186
6187    def test_check_tensor(self, device):
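        # Added note: torch._test_check_tensor (TORCH_CHECK_TENSOR_ALL) should
        # pass when every element satisfies the predicate and raise the test
        # message when any element fails, under one or two levels of vmap.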
6188        def test():
6189            test_sizes = [
6190                (1,),
6191                (10,),
6192                (1, 1),
6193                (1, 10),
6194                (10, 1),
6195                (10, 10),
6196                (1, 1, 1),
6197                (10, 1, 1),
6198                (1, 10, 1),
6199                (10, 10, 10),
6200            ]
6201
6202            def check_gte_0(t):
6203                return torch._test_check_tensor(t >= 0)
6204
6205            error_message = "Test message for TORCH_CHECK_TENSOR_ALL"
6206
6207            for size in test_sizes:
6208                t_all_gte_0 = torch.rand(size, device=device)
6209                t_all_lt_0 = t_all_gte_0 - 1
6210
6211                vmap(check_gte_0)(t_all_gte_0)
6212
6213                if len(size) >= 2:
6214                    vmap(vmap(check_gte_0))(t_all_gte_0)
6215
6216                with self.assertRaisesRegex(RuntimeError, error_message):
6217                    vmap(check_gte_0)(t_all_lt_0)
6218
6219                if len(size) >= 2:
6220                    with self.assertRaisesRegex(RuntimeError, error_message):
6221                        vmap(vmap(check_gte_0))(t_all_lt_0)
6222
6223                if t_all_gte_0.numel() > 1:
6224                    t_all_gte_0_but_one = t_all_gte_0.clone()
6225                    idx = (random.choice(range(dim_size)) for dim_size in size)
6226                    t_all_gte_0_but_one[(..., *idx)] = -1
6227
6228                    with self.assertRaisesRegex(RuntimeError, error_message):
6229                        vmap(check_gte_0)(t_all_gte_0_but_one)
6230
6231                    if len(size) >= 2:
6232                        with self.assertRaisesRegex(RuntimeError, error_message):
6233                            vmap(vmap(check_gte_0))(t_all_gte_0_but_one)
6234
6235        check_vmap_fallback(self, test, torch._test_check_tensor)
6236
6237
6238@markDynamoStrictTest
6239class TestVmapNestedTensor(Namespace.TestVmapBase):
6240    def _vmap_test(self, *args, **kwargs):
6241        return _vmap_test(self, *args, **kwargs)
6242
6243    # dims should be something like [5, None, 10], with None indicating that a
6244    # random ragged structure should be used
6245    def _create_nt(self, dims, device):
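        # e.g. dims=[3, None, 2] -> 3 components of shape (r_i, 2), where each
        # r_i is drawn independently from [2, 10).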
6246        sizes = [
6247            [
6248                d if d is not None else torch.randint(2, 10, size=(1,)).item()
6249                for d in dims[1:]
6250            ]
6251            for d in range(dims[0])
6252        ]
6253        return torch.nested.nested_tensor(
6254            [torch.randn(*size) for size in sizes], device=device
6255        )
6256
6257    # Creates an NT matching another NT's number of components and
6258    # shape / ragged structure for all dims specified to be -1.
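    # e.g. dims=[-1, 4, -1] keeps the component count and the ragged dim of
    # `other` but fixes the first dim of every component (NT dim 1) to size 4.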
6259    def _nt_from_similar(self, other, dims):
6260        assert len(dims) == other.dim()
6261        assert dims[0] == -1 or dims[0] == other.size(0)
6262
6263        ret_sizes = []
6264        for t in other.unbind():
6265            other_size = t.shape
6266            ret_size = []
6267            for i, d in enumerate(dims[1:]):
6268                if d == -1:
6269                    ret_size.append(other_size[i])
6270                else:
6271                    ret_size.append(d)
6272            ret_sizes.append(ret_size)
6273
6274        return torch.nested.nested_tensor(
6275            [torch.randn(*size) for size in ret_sizes], device=other.device
6276        )
6277
6278    @allowVmapFallbackUsage
6279    def test_fallback_unary(self, device):
6280        def f(x):
6281            return x.sin() * 5.0 + 4.0
6282
6283        nt = self._create_nt([4, None, 3], device=device)
6284        self._vmap_test(f, (nt,))
6285
6286    @allowVmapFallbackUsage
6287    def test_fallback_binary(self, device):
6288        def f(x, y):
6289            return x @ y
6290
6291        x = self._create_nt([5, None, 3], device=device)
6292        y = self._create_nt([5, 3, None], device=device)
6293        self._vmap_test(f, (x, y))
6294
6295    @allowVmapFallbackUsage
6296    def test_fallback_binary_nt_and_unbatched_dense(self, device):
6297        def f(x, y):
6298            return x @ y
6299
6300        x = self._create_nt([5, None, 3], device=device)
6301        y = torch.randn(3, 4, device=device)
6302        self._vmap_test(f, (x, y), in_dims=(0, None))
6303
6304    @allowVmapFallbackUsage
6305    def test_fallback_binary_nt_and_batched_dense(self, device):
6306        def f(x, y):
6307            return x @ y
6308
6309        x = self._create_nt([5, None, 3], device=device)
6310        y = torch.randn(5, 3, 4, device=device)
6311        self._vmap_test(f, (x, y))
6312
6313    def test_nt_acts_as_dense_in_vmap(self, device):
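        # Added note: inside vmap over dim 0, each nested tensor component should
        # be seen as a regular dense tensor (is_nested is False).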
6314        def f(x):
6315            assert not x.is_nested
6316            return x
6317
6318        x = self._create_nt([5, None, 3], device=device)
6319        self._vmap_test(f, (x,))
6320
6321    def test_cat_batching_rule(self, device):
6322        def f(x, y, dim):
6323            return torch.cat([x, y], dim=dim)
6324
6325        # Different nested structure, same other dims
6326        x = self._create_nt([3, None, 2], device=device)
6327        y = self._create_nt([3, None, 2], device=device)
6328        self._vmap_test(functools.partial(f, dim=0), (x, y))
6329
6330        x = self._create_nt([3, 2, None], device=device)
6331        y = self._create_nt([3, 2, None], device=device)
6332        self._vmap_test(functools.partial(f, dim=1), (x, y))
6333
6334        # Same nested structure, different other dims
6335        x = self._create_nt([3, 2, None], device=device)
6336        y = self._nt_from_similar(x, [-1, 4, -1])
6337        self._vmap_test(functools.partial(f, dim=0), (x, y))
6338
6339        x = self._create_nt([3, None, 2], device=device)
6340        y = self._nt_from_similar(x, [-1, -1, 4])
6341        self._vmap_test(functools.partial(f, dim=1), (x, y))
6342
6343    # .shape calls don't work on NTs
6344    # TODO: Fix this somehow?
6345    @unittest.expectedFailure
6346    def test_shape_call(self, device):
6347        def f(x):
6348            x.shape[0]
6349            return x
6350
6351        x = self._create_nt([3, None, 2], device=device)
6352        self._vmap_test(f, (x,))
6353
6354    def test_nt_with_nonzero_in_dim_raises(self, device):
6355        def f(x):
6356            return x
6357
6358        x = self._create_nt([3, None, 2], device=device)
6359        with self.assertRaisesRegex(
6360            RuntimeError, "Nested tensors can only be vmapped over dim=0"
6361        ):
6362            vmap(f, in_dims=2)(x)
6363
6364    def test_nt_with_nonzero_out_dim_raises(self, device):
6365        def f(x):
6366            return x
6367
6368        x = self._create_nt([3, None, 2], device=device)
6369        with self.assertRaisesRegex(
6370            RuntimeError, "Nested tensors can only be vmapped over dim=0"
6371        ):
6372            vmap(f, out_dims=2)(x)
6373
6374    def test_fallback_with_nt_and_batched_dense_with_nonzero_bdim_raises(self, device):
6375        def f(x, y):
6376            return x @ y
6377
6378        x = self._create_nt([5, None, 3], device=device)
6379        y = torch.randn(3, 5, 4, device=device)
6380
6381        with self.assertRaisesRegex(
6382            RuntimeError,
6383            "Fallback not supported for mixed nested / non-nested arguments without bdim=0",
6384        ):
6385            vmap(f, in_dims=(0, 1))(x, y)
6386
6387    def test_multilevel_vmap_raises(self, device):
6388        def f(x):
6389            return x.sin() * 4.0 + 3.0
6390
6391        x = self._create_nt([2, 2, 2, None], device=device)
6392
6393        with self.assertRaisesRegex(
6394            RuntimeError, "Only one level of vmap is supported"
6395        ):
6396            vmap(vmap(f))(x)
6397
6398        with self.assertRaisesRegex(
6399            RuntimeError, "Only one level of vmap is supported"
6400        ):
6401            vmap(vmap(vmap(f)))(x)
6402
6403
6404only_for = ("cpu", "cuda")
6405instantiate_device_type_tests(TestVmapOperatorsOpInfo, globals(), only_for=only_for)
6406
6407instantiate_device_type_tests(
6408    TestVmapBatchedGradient,
6409    globals(),
6410    only_for=only_for,
6411)
6412instantiate_device_type_tests(TestTransformFailure, globals(), only_for=only_for)
6413instantiate_device_type_tests(TestRandomness, globals(), only_for=only_for)
6414instantiate_device_type_tests(TestVmapDeviceType, globals(), only_for=only_for)
6415instantiate_device_type_tests(TestVmapNestedTensor, globals(), only_for=only_for)
6416
6417if __name__ == "__main__":
6418    run_tests()
6419