# Owner(s): ["module: vmap"]

import functools
import itertools
import types
import warnings

import torch
import torch.nn.functional as F
from torch import Tensor
from torch._vmap_internals import vmap
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase


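# Substring of the warning emitted by the vmap fallback path; tests below match
# captured warning messages against this to detect whether the slow fallback ran.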
FALLBACK_REGEX = r"There is a performance drop"


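# Context manager that enables the debug-only vmap fallback warnings for the
# duration of a `with` block and restores the previous setting on exit.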
class EnableVmapFallbackWarnings:
    def __enter__(self):
        self.prev_state = torch._C._debug_only_are_vmap_fallback_warnings_enabled()
        torch._C._debug_only_display_vmap_fallback_warnings(True)

    def __exit__(self, *ignored):
        torch._C._debug_only_display_vmap_fallback_warnings(self.prev_state)


class TestVmapAPILegacy(TestCase):
    def test_non_tensor_output_raises(self):
        with self.assertRaisesRegex(
            ValueError, "got type <class 'float'> as the return"
        ):
            output = vmap(lambda x: 3.14)(torch.ones(3))

        def multiple_outputs(x):
            return x, 3

        with self.assertRaisesRegex(ValueError, "got type <class 'int'> for return 1"):
            vmap(multiple_outputs)(torch.ones(3))

    def test_different_map_dim_size_raises(self):
        x = torch.randn(2)
        y = torch.randn(3)
        expected_msg = (
            "Expected all tensors to have the same size in the mapped dimension"
        )
        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(torch.mul)(x, y)
        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y))
        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(lambda z: z["x"] + z["y"], in_dims=({"x": 0, "y": 0},))(
                {"x": x, "y": y}
            )

    def test_func_with_no_inputs(self):
        expected_msg = "got no inputs"

        def foo():
            return torch.randn(3)

        def bar(x):
            return torch.randn(3)

        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(foo)()

        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(bar)()

    def test_constant_function(self):
        output = vmap(lambda x: torch.tensor(3.14))(torch.ones(3))
        self.assertEqual(output, torch.tensor([3.14, 3.14, 3.14]))

    def test_single_input(self):
        x = torch.randn(2, 3)

        def square(x):
            return x * x

        output = vmap(square)(x)
        self.assertEqual(output, x * x)

    def test_multiple_inputs(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        output = vmap(torch.mul)(x, y)
        self.assertEqual(output, x * y)

    def test_multiple_outputs(self):
        def foo(x):
            return x * x, x * x * x

        x = torch.randn(3)
        outputs = vmap(foo)(x)
        self.assertEqual(outputs[0], x * x)
        self.assertEqual(outputs[1], x * x * x)

    def test_multiple_outputs_error_cases(self):
        # This is the same thing as
        # def returns_tuple_of_tensors(x):
        #     return x, x
        def returns_tuple_of_tensors(x):
            return (x, x)

        def returns_list_of_two_tensors(x):
            return [x, x]

        def returns_list_of_one_tensor(x):
            return [x]

        x = torch.randn(3)

        # should not throw
        vmap(returns_tuple_of_tensors)(x)

        # jax supports these, but we don't yet
        msg = "must only return Tensors, got type <class 'list'>"
        with self.assertRaisesRegex(ValueError, msg):
            vmap(returns_list_of_two_tensors)(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(returns_list_of_one_tensor)(x)

    def test_nested_with_same_map_dim(self):
        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        output = vmap(vmap(torch.mul))(x, y)
        self.assertEqual(output, x * y)

        output = vmap(vmap(vmap(torch.mul)))(x, y)
        self.assertEqual(output, x * y)

    def test_nested_with_different_map_dim(self):
        x = torch.randn(2, 3)
        y = torch.randn(5, 3)
        output = vmap(lambda x: vmap(lambda y: x * y)(y))(x)
        self.assertEqual(output.shape, (2, 5, 3))
        self.assertEqual(output, x.view(2, 1, 3) * y)

        z = torch.randn(7, 3)
        output = vmap(lambda x: vmap(lambda y: vmap(lambda z: x * y * z)(z))(y))(x)
        self.assertEqual(output.shape, (2, 5, 7, 3))
        self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z)

    def test_noop_in_inner_vmap(self):
        x = torch.randn(3)
        y = torch.randn(5)
        output = vmap(lambda x: vmap(lambda y: x)(y))(x)
        self.assertEqual(output, x.view(3, 1).expand(3, 5))

    def test_unsupported_op_err_msg(self):
        # Unsupported view op
        tensor = torch.randn(2, 3)
        msg = (
            r"Batching rule not implemented for aten::.+; the "
            r"fallback path doesn't work on out= or view ops"
        )
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(torch.ravel)(tensor)

        def out_op(x, y):
            return torch.abs(x, out=y)

        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(out_op)(tensor, tensor)

        tensor = torch.randn(2)
        # The fallback doesn't support TensorList
        with self.assertRaisesRegex(RuntimeError, "Batching rule not implemented"):
            vmap(lambda t: torch.atleast_1d([t]))(tensor)

        # Don't support non-tensor returns. This is a limitation of vmap;
        # functions that don't return tensors must be special cased
        with self.assertRaisesRegex(RuntimeError, "Batching rule not implemented"):
            vmap(torch.Tensor.item)(tensor)

    def test_nonzero_out_dims(self):
        # Basic test
        tensor = torch.randn(2, 3)
        result = vmap(lambda x: x, out_dims=1)(tensor)
        self.assertEqual(result, tensor.permute(1, 0))
        self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # Test that the batch dimension gets permuted to dim 2
        tensor = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x: x, out_dims=2)(tensor)
        self.assertEqual(result, tensor.permute(1, 2, 0, 3))
        self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # negative out_dim
        tensor = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x: x, out_dims=-1)(tensor)
        self.assertEqual(result, tensor.permute(1, 2, 3, 0))
        self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # check that out_dims works on ALL outputs
        tensor = torch.randn(2, 3, 5, 7)
        other = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x, y: (x, y), out_dims=2)(tensor, other)
        self.assertEqual(
            result, (tensor.permute(1, 2, 0, 3), other.permute(1, 2, 0, 3))
        )

        # use out_dims with the maximum vmap-able tensor dims (64 dims)
        ndims = 64
        shape = [2] + [1] * (ndims - 1)
        expected_shape = [1, 1, 2] + [1] * (ndims - 3)
        tensor = torch.randn(shape)
        result = vmap(lambda x: x, out_dims=2)(tensor)
        self.assertEqual(result.shape, expected_shape)

        # test something that is not the identity function
        def foo(x, y):
            return x, x * y, x * y * y

        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        result = vmap(foo, out_dims=1)(x, y)
        self.assertEqual(
            result,
            (
                x.permute(1, 0, 2),
                (x * y).permute(1, 0, 2),
                (x * y * y).permute(1, 0, 2),
            ),
        )

    def test_multiple_out_dims(self):
        def foo(x):
            return x, x

        def bar(x, y):
            return x, x, x, x * y

        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        result = vmap(foo, out_dims=(0, 1))(x)
        self.assertEqual(result, (x, x.permute(1, 0, 2)))

        result = vmap(bar, out_dims=(-1, 0, 1, 2))(x, y)
        expected = (
            x.permute(1, 2, 0),
            x,
            x.permute(1, 0, 2),
            (x * y).permute(1, 2, 0),
        )
        self.assertEqual(result, expected)

    def test_nested_out_dims(self):
        y = torch.randn(2, 3, 5, 7)

        # Inner vmap has non-zero out_dim
        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y)
        self.assertEqual(result.shape, (2, 5, 3, 7))
        self.assertEqual(result, y.permute(0, 2, 1, 3))

        # all vmaps have non-zero out_dim
        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y), out_dims=1)(y)
        self.assertEqual(result.shape, (5, 2, 3, 7))
        self.assertEqual(result, y.permute(2, 0, 1, 3))

        # throwing in some negative out_dims
        result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y), out_dims=-1)(y)
        self.assertEqual(result.shape, (5, 7, 3, 2))
        self.assertEqual(result, y.permute(2, 3, 1, 0))

        # testing fn that isn't the identity
        x = torch.randn(2, 3)
        y = torch.randn(5, 3)
        result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x), out_dims=-1)(y)
        self.assertEqual(result.shape, (3, 2, 5))
        self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0))

    def test_out_dims_edge_case(self):
        def foo(x):
            return x

        # Test that we accept out_dims=(1,) for a function with one output.
        tensor = torch.randn(2, 3)
        expected = vmap(foo, out_dims=1)(tensor)
        result = vmap(foo, out_dims=(1,))(tensor)
        self.assertEqual(result, expected)

    def test_out_dims_must_be_int_or_tuple_of_int_err_msg(self):
        msg = "`out_dims` must be an int or a tuple of int"
        tensor = torch.randn(2, 3)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims="lol")(tensor)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=("lol",))(tensor)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=None)(tensor)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=(None,))(tensor)

    def test_out_dims_and_num_outputs_mismatch_err_msg(self):
        msg = "`out_dims` must have one dim per output"
        x = torch.randn(2, 3, 5)

        # Too many out_dims
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=(0, 0))(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: (x, x, x), out_dims=(0, 0, 0, 0))(x)

        # Too few out_dims
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: (x, x), out_dims=(0,))(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: (x, x, x), out_dims=(0, 0))(x)

    def test_out_dim_out_of_bounds_err_msg(self):
        # TODO(rzou): This error message isn't that great. It comes straight
        # from maybe_wrap_dim. Consider doing a try-catch-(add some context) to
        # the error message in the future in C++
        msg = "Dimension out of range"
        x = torch.randn(2, 3, 5)
        with self.assertRaisesRegex(IndexError, msg):
            vmap(lambda x: x, out_dims=3)(x)
        with self.assertRaisesRegex(IndexError, msg):
            vmap(lambda x: x, out_dims=-4)(x)

    def test_non_zero_in_dims(self):
        tensor = torch.randn(2, 3, 5)

        # Implicit out_dims = 0; vmap will move the batch dim to the front.
        output = vmap(lambda x: x, (1,))(tensor)
        self.assertEqual(output, tensor.permute(1, 0, 2))
        self.assertEqual(output.data_ptr(), tensor.data_ptr())

        x = torch.randn(2, 3)
        y = torch.randn(3, 2)
        output = vmap(torch.mul, (0, 1))(x, y)
        self.assertEqual(output, x * y.t())
        output = vmap(torch.mul, (1, 0))(x, y)
        self.assertEqual(output, x.t() * y)

    def test_none_in_dims(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        # None in_dim for a Tensor means we don't map over it
        output = vmap(torch.mul, (0, None))(x, y)
        self.assertEqual(output.shape, (2, 2, 3))
        self.assertEqual(output, x.view(2, 1, 3) * y)

        # None in_dim for non-tensor arguments
        output = vmap(torch.mul, (0, None))(x, 2)
        self.assertEqual(output, x * 2)

    def test_nested_non_default_in_dims(self):
        x = torch.rand(5, 2, 3)
        y = torch.rand(3, 5, 2)
        result = vmap(vmap(vmap(torch.mul), (1, 0)), (1, 2))(x, y)
        self.assertEqual(result, x.permute(1, 2, 0) * y.permute(2, 0, 1))

    def test_non_default_in_dims_out_dims(self):
        x = torch.randn(2, 3, 5)

        # Same in_dim as out_dim, vmap over identity
        result = vmap(lambda x: x, in_dims=1, out_dims=1)(x)
        self.assertEqual(result, x)
        self.assertEqual(result.data_ptr(), x.data_ptr())

        # Different in_dim from out_dim, vmap over identity
        result = vmap(lambda x: x, in_dims=2, out_dims=1)(x)
        self.assertEqual(result.shape, (2, 5, 3))
        self.assertEqual(result, x.transpose(1, 2))
        self.assertEqual(result.data_ptr(), x.data_ptr())

        def foo(x):
            return x * 2

        # Same in_dim as out_dim, vmap over operation
        result = vmap(foo, in_dims=1, out_dims=1)(x)
        self.assertEqual(result, x * 2)

        # Different in_dim as out_dim, vmap over operation
        result = vmap(foo, in_dims=2, out_dims=1)(x)
        self.assertEqual(result.shape, (2, 5, 3))
        self.assertEqual(result, (x * 2).transpose(1, 2))

        # Basic nested test.
        result = vmap(vmap(foo, 1, 1), 1, 1)(x)
        self.assertEqual(result, x * 2)

    def test_accepts_nested_inputs(self):
        B0 = 2
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        # Single layer of nesting
        out = vmap(lambda z: z[0] + z[1])((x, y))
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=(0,))((x, y))
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y))
        self.assertEqual(out, x + y)

        out = vmap(lambda z: z[0] + z[1])([x, y])
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=(0,))([x, y])
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, y])
        self.assertEqual(out, x + y)

        out = vmap(lambda z: z["x"] + z["y"])({"x": x, "y": y})
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z["x"] + z["y"], in_dims=(0,))({"x": x, "y": y})
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z["x"] + z["y"], in_dims=({"x": 0, "y": 0},))(
            {"x": x, "y": y}
        )
        self.assertEqual(out, x + y)

        # Multiple layers of nesting
        out_fn = vmap(lambda z: z["x"][0] + z["x"][1][0] + z["y"][0] + z["y"][1])
        out = out_fn({"x": [x, (x,)], "y": [y, y]})
        self.assertEqual(out, x + x + y + y)

    def test_in_dims_wrong_type_err_msg(self):
        x = torch.randn(3)
        y = torch.randn(3)
        msg = r"expected `in_dims` to be int or a \(potentially nested\) tuple"
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, [0, 0])(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, set({0}))(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, "lol")(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=[0, 0])([x, y])
        # The following should not throw
        vmap(torch.mul, (0, 0))(x, y)

    def test_not_enough_in_dims_err_msg(self):
        x = torch.randn(3)
        y = torch.randn(3)
        msg = r"in_dims is not compatible with the structure of `inputs`"

        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, (0,))(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, (0, 0, 0))(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=([0],))([x, y])
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))([x, y])
        # The following should not throw
        vmap(torch.mul, (0, 0))(x, y)

    def test_integer_in_dim_but_not_tensor_input_err_msg(self):
        def foo(xy):
            return xy[0] * xy[1]

        def bar(x, yz):
            return x * yz[0] * yz[1]

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        # the following are errors in jax (and will always be errors)
        msg = "Got in_dim=0 for an input but the input is of type"
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.sum)(x, 0)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.sum, (0, 0))(x, 0)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, 1])
        # The following should not throw
        vmap(torch.sum, (0, None))(x, 0)

    def test_in_dim_not_in_tensor_err_msg(self):
        def foo(x):
            return x * x

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        msg = r"Got in_dim=-?\w for some input, but that input is a Tensor of dimensionality \w"
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo)(torch.randn([]))
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo, in_dims=(0,))(torch.randn([]))
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo, in_dims=(-1,))(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo, in_dims=(2,))(y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=([3, 0],))([x, y])
        # the following should not throw
        vmap(foo, in_dims=(0,))(torch.randn(2, 3))
        vmap(foo, in_dims=(1,))(torch.randn(2, 3))

    def test_fallback_does_not_warn_by_default(self):
        # NB: One day we will implement a batching rule for torch.atan2.
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = torch.atan2
        x = torch.randn(11)
        y = torch.randn(11)
        with warnings.catch_warnings(record=True) as wa:
            result = vmap(op)(x, y)
            # The single warning here is the "vmap is experimental"
            # warning, not a warning from the vmap fallback path.
            self.assertEqual(len(wa), 1)

    def test_fallback_warns_when_warnings_are_enabled(self):
        # NB: One day we will implement a batching rule for torch.atan2.
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = torch.atan2
        x = torch.randn(11)
        y = torch.randn(11)
        with warnings.catch_warnings(record=True) as wa:
            with EnableVmapFallbackWarnings():
                result = vmap(op)(x, y)
            self.assertEqual(len(wa), 2)
            self.assertRegex(str(wa[-1].message), FALLBACK_REGEX)

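    # Asserts that vmap(*vmap_args)(*inputs) goes through the slow fallback path:
    # with fallback warnings enabled we expect exactly two warnings (the
    # "vmap is experimental" warning plus the fallback performance-drop warning).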
    def _assert_uses_vmap_fallback(self, vmap_args, inputs):
        with warnings.catch_warnings(record=True) as wa:
            with EnableVmapFallbackWarnings():
                result = vmap(*vmap_args)(*inputs)
            self.assertEqual(len(wa), 2)
            self.assertRegex(str(wa[-1].message), FALLBACK_REGEX)

    def test_fallback_zero_dim(self):
        # NB: One day we will implement a batching rule for torch.atan2.
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = torch.atan2
        x = torch.randn(11)
        y = torch.randn(11)
        self._assert_uses_vmap_fallback((op,), (x, y))

        B0, B1 = 0, 3
        x = torch.randn(B0, 11)
        y = torch.randn(11)

        msg = "The fallback path does not support vmap over dims of size 0"

        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, (0, None))(x, y)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, (None, 0))(y, x)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op)(x, x)

        x = torch.randn(B0, B1, 11)
        y = torch.randn(B1, 11)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, (0, None))(x, y)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, (None, 0))(y, x)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op)(x, x)

    def test_fallback_atan2(self):
        # NB: One day we will implement a batching rule for torch.atan2.
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = torch.atan2

        x = torch.randn(5, 7, 11)
        y = torch.randn(5, 7, 11)

        self._assert_uses_vmap_fallback((op,), (x, y))

        # fallback on torch.atan2
        x = torch.randn(7, 11, 5)
        y = torch.randn(5, 7, 11)
        result = vmap(op, (2, 0))(x, y)
        self.assertEqual(result, op(x.permute(2, 0, 1), y))

        # fallback on torch.atan2, nested vmap
        x = torch.randn(7, 11, 5)
        y = torch.randn(5, 7, 11)
        result = vmap(vmap(op), (2, 0))(x, y)
        self.assertEqual(result, op(x.permute(2, 0, 1), y))

        # big batch size (total 10000)
        x = torch.randn(100, 10, 10, 5)
        y = torch.randn(100, 10, 10)
        result = vmap(vmap(vmap(op)))(x, y)
        self.assertEqual(result, op(x, y.view(100, 10, 10, 1)))

    def test_fallback_masked_fill(self):
        # NB: This test exercises the fallback path on torch.index_add. One day
        # we will implement a batching rule for it; if/when we do, this test
        # should be replaced to test the fallback path on another operator to
        # avoid bitrot.
        def run_test(batch_size):
            B0 = batch_size
            x = torch.randn(B0, 7, 11, 13)
            dim = 0
            index = torch.tensor([0, 4, 2])
            values = torch.randn(B0, 3, 11, 13)

            self._assert_uses_vmap_fallback(
                (torch.index_add, (0, None, None, 0)), (x, dim, index, values)
            )

            result = vmap(torch.index_add, (0, None, None, 0))(x, dim, index, values)
            expected = torch.index_add(x, dim + 1, index, values.view(B0, 3, 11, 13))
            self.assertEqual(result, expected)

        run_test(batch_size=5)
        run_test(batch_size=1237)

    def test_fallback_multiple_returns(self):
        # NB: One day we will implement a batching rule for torch.var_mean
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        B0, B1, B2 = 2, 3, 1237
        tensor = torch.randn(B0, 10)

        self._assert_uses_vmap_fallback((torch.var_mean,), (tensor,))

        # fallback correctness on torch.var_mean
        result = vmap(torch.var_mean)(tensor)
        expected = torch.var_mean(tensor, dim=1)
        self.assertEqual(result, expected)

        # nested vmap
        tensor = torch.randn(B0, B1, 10)
        result = vmap(vmap(torch.var_mean))(tensor)
        expected = torch.var_mean(tensor, dim=2)
        self.assertEqual(result, expected)

        # big batch size, nested vmap
        tensor = torch.randn(B0, B1, B2, 10)
        result = vmap(vmap(vmap(torch.var_mean)))(tensor)
        expected = torch.var_mean(tensor, dim=3)
        self.assertEqual(result, expected)

    def test_inplace_fallback_unary(self):
        # Test the in-place fallback on an in-place method that takes no
        # additional Tensor arguments. This is the simplest case of the fallback.
        # NB: One day we will implement a batching rule for acos_.
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = Tensor.acos_
        B0, B1, B2 = 2, 3, 10000

        x = torch.randn(B0, 5)
        self._assert_uses_vmap_fallback((op,), (x,))

        # Single vmap
        x_orig = torch.rand(B0, 5)
        x = x_orig.clone()
        result = vmap(op)(x)
        self.assertTrue(result is x)
        self.assertEqual(result, x_orig.acos())

        # Single vmap + different out_dim produces a view(!)
        x_orig = torch.rand(B0, 5)
        x = x_orig.clone()
        result = vmap(op, out_dims=(1,))(x)
        self.assertTrue(result._base is x)
        self.assertEqual(result, x_orig.t().acos())

        # Nested vmap
        x_orig = torch.randn(B0, B1, 5)
        x = x_orig.clone()
        result = vmap(vmap(op))(x)
        self.assertTrue(result is x)
        self.assertEqual(result, x_orig.acos())

        # Nested vmap, large batch size
        x_orig = torch.randn(B0, B1, B2, 5)
        x = x_orig.clone()
        result = vmap(vmap(vmap(op)))(x)
        self.assertTrue(result is x)
        self.assertEqual(result, x_orig.acos())

    def test_inplace_fallback_nary_same_levels(self):
        # NB: One day we will implement a batching rule for atan2_
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = Tensor.atan2_
        outplace_op = torch.atan2

        x = torch.randn(5, 7, 11)
        y = torch.randn(5, 7, 11)
        self._assert_uses_vmap_fallback((op,), (x, y))

        # Single vmap
        B0 = 5
        x_orig = torch.randn(7, 11, B0)
        x = x_orig.clone()
        y = torch.randn(B0, 7, 11)
        vmap(op, (2, 0))(x, y)
        self.assertEqual(x, outplace_op(x_orig, y.movedim(0, 2)))

        # Nested vmap
        B0, B1 = 5, 7
        x_orig = torch.randn(B1, 11, B0)
        x = x_orig.clone()
        y = torch.randn(B0, B1, 11)
        vmap(vmap(op), (2, 0))(x, y)
        self.assertEqual(x, outplace_op(x_orig, y.movedim([0, 1], [2, 0])))

        # big batch size (total 10000)
        B0, B1, B2 = 100, 10, 10
        x_orig = torch.randn(B0, B1, B2, 5)
        x = x_orig.clone()
        y = torch.randn(B0, B1, B2)
        result = vmap(vmap(vmap(op)))(x, y)
        self.assertEqual(x, outplace_op(x_orig, y.view(B0, B1, B2, 1)))

    def test_inplace_fallback_nary_different_levels(self):
        # NB: One day we will implement a batching rule for atan2_
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = Tensor.atan2_
        outplace_op = torch.atan2
        B0, B1, B2 = 2, 3, 5

        x = torch.rand(B0, 7)
        y = torch.rand(7)
        self._assert_uses_vmap_fallback((op, (0, None)), (x, y))

        # op(left, right): All of the levels in right are found in left
        x_orig = torch.rand(B0, 7)
        x = x_orig.clone()
        y = torch.rand(7)
        vmap(op, in_dims=(0, None))(x, y)
        self.assertEqual(x, outplace_op(x_orig, y))

        x_orig = torch.rand(B0, B1, 7)
        x = x_orig.clone()
        y = torch.rand(B0, 7)
        vmap(vmap(op, in_dims=(0, None)))(x, y)
        self.assertEqual(x, outplace_op(x_orig, y.view(B0, 1, 7)))

        # op(left, right): Some of the levels in right are not found in left
        msg = r"vmap: aten::atan2_\(self, \*extra_args\) is not possible"
        x = torch.rand(7)
        y = torch.rand(B0, 7)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(None, 0))(x, y)

        x = torch.rand(B1, 7)
        y = torch.rand(B0, 7)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 0))(x, y)

        x = torch.rand(B1, 7)
        y = torch.rand(7, B0)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 1))(x, y)

        x = torch.rand(B0, 7)
        y = torch.rand(B0, B1, 7)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(vmap(op, in_dims=(None, 0)))(x, y)

    def test_backward_unsupported_interaction(self):
        x = torch.randn(3, requires_grad=True)
        y = torch.randn(5)
        grad = torch.randn_like(x)
        err_msg = r"backward\(\) called inside torch.vmap"

        def backward_on_vmapped_tensor(x):
            x.sum().backward()

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(backward_on_vmapped_tensor)(x)

        def backward_with_vmapped_grad(x, grad):
            x.backward(grad)

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(backward_with_vmapped_grad)(x, grad)

        def completely_unrelated_backward(y):
            x.sum().backward()

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(completely_unrelated_backward)(y)

    def test_grad_unsupported_interaction(self):
        input_tensor = torch.randn(3, requires_grad=True)
        err_msg = "autograd.grad.* called inside torch.vmap"

        captured = torch.randn(3, requires_grad=True)

        def output_to_grad_is_vmapped(input_tensor):
            output = (captured * input_tensor).sum()
            return torch.autograd.grad([output], [captured])[0]

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(output_to_grad_is_vmapped)(input_tensor)

        output = (input_tensor**2).sum()

        def input_to_grad_is_vmapped(input_tensor):
            return torch.autograd.grad([output], [input_tensor])[0]

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(input_to_grad_is_vmapped)(input_tensor)

    def test_batched_gradient_basic(self):
        N = 3
        x = torch.randn(N, requires_grad=True)
        y = torch.randn(N)

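        # Compute the Jacobian of x * y w.r.t. x by vmapping a VJP over the rows
        # of an identity matrix; for the elementwise product this is diag(y).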
        def vjp_mul(v):
            return torch.autograd.grad([x * y], [x], grad_outputs=[v])[0]

        batched_v = torch.eye(N)
        jacobian = vmap(vjp_mul)(batched_v)
        self.assertEqual(jacobian, torch.diagflat(y))

    def test_functools_partial(self):
        x = torch.randn(3)
        y = torch.randn(2, 3)
        result = vmap(functools.partial(torch.mul, x))(y)
        self.assertEqual(result, x * y)

    def test_nn_module(self):
        tensor = torch.randn(2, 3)
        model = torch.nn.Linear(3, 3, bias=False)
        result = vmap(model)(tensor)
        self.assertEqual(result, model(tensor))

    def test_fallback_with_undefined_grad(self):
        B0 = 7
        x = torch.randn(2, 3, 4, 5, requires_grad=True)
        weight = torch.randn(3, 3, 1, 1)
        v = torch.randn(B0, 2, 3, 4, 5)

        def get_vjp(v):
            result = torch.nn.functional.conv2d(x, weight)
            (grad_x,) = torch.autograd.grad(result, x, v)
            return grad_x

        # Runs vmap(get_vjp)(v), which should not error out.
        # The backward formula for convolution returns an undefined
        # Tensor for grad_bias because the original bias does not exist.
        #
        # In the future we'll probably add a batching rule for convolution
        # backward. When this happens, we should modify this test to use a
        # different op (and/or create and use a dummy operator) to avoid bitrot.
        self._assert_uses_vmap_fallback([get_vjp], [v])


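# Selects the i-th entry of each input along its batch dim; inputs whose bdim
# is None are not mapped over and are passed through unchanged.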
def slice_inputs(inputs, bdims, i):
    result = []
    for inp, bdim in zip(inputs, bdims):
        if bdim is None:
            result.append(inp)
        else:
            result.append(inp.select(bdim, i))
    return tuple(result)


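# Reference implementation of vmap: apply `op` to each slice of the inputs along
# their in_dims and stack the per-slice results along out_dims. The tests below
# use this as the ground truth that real vmap outputs are compared against.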
def reference_vmap(op, inputs, in_dims=0, out_dims=0):
    if isinstance(in_dims, int):
        in_dims = (in_dims,) * len(inputs)
    bdim_sizes = [inp.size(dim) for inp, dim in zip(inputs, in_dims) if dim is not None]
    assert all(bdim_size == bdim_sizes[0] for bdim_size in bdim_sizes)
    bdim_size = bdim_sizes[0]
    results = tuple(op(*slice_inputs(inputs, in_dims, i)) for i in range(bdim_size))

    assert len(results) > 0
    op_has_single_return = not isinstance(results[0], tuple)
    if op_has_single_return:
        assert all(isinstance(result, torch.Tensor) for result in results)
        if isinstance(out_dims, int):
            out_dims = (out_dims,) * 1
        return torch.stack(results, dim=out_dims[0])

    assert all(isinstance(result, tuple) for result in results)
    num_returns = len(results[0])
    assert all(len(result) == num_returns for result in results)
    if isinstance(out_dims, int):
        out_dims = (out_dims,) * num_returns
    return tuple(
        torch.stack(result_shards, out_dim)
        for result_shards, out_dim in zip(zip(*results), out_dims)
    )


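# Factory helpers for test inputs. `randp1` yields values in [1, 2) so that ops
# like log, reciprocal, and rsqrt receive well-behaved, strictly positive inputs.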
class TensorFactory:
    @staticmethod
    def rand(size, device="cpu", dtype=torch.float):
        return torch.rand(size, device=device, dtype=dtype)

    @staticmethod
    def randn(size, device="cpu", dtype=torch.float):
        return torch.randn(size, device=device, dtype=dtype)

    @staticmethod
    def randp1(size, device="cpu", dtype=torch.float):
        return torch.rand(size, device=device, dtype=dtype) + 1


# Tests vmap(op, in_dims, out_dims)(*inputs) by comparing the output to a
# (slow) sequential map+stack fallback.
#
# check_view: Test if the first returned output is a view of the first input
# check_propagates_grad: Test if the operation propagates gradients.
def _vmap_test(
    self,
    op,
    inputs,
    in_dims=0,
    out_dims=0,
    check_view=False,
    check_propagates_grad=True,
):
    result = vmap(op, in_dims, out_dims)(*inputs)
    reference_result = reference_vmap(op, inputs, in_dims, out_dims)
    self.assertEqual(result, reference_result)
    op_has_single_return = not isinstance(result, tuple)

    if check_view:
        result_as_tuple = (result,) if op_has_single_return else result
        for output in result_as_tuple:
            input0_base = inputs[0] if inputs[0]._base is None else inputs[0]._base
            self.assertTrue(
                output._base is input0_base,
                msg="result was not a view of the first input!",
            )

    if not check_propagates_grad:
        return
    # Assuming input[0] is a floating-point tensor. Check if the vmap
    # operation propagates the requires_grad flag to the zeroth output.
    # Some vmap operators are implemented in a way that assumes that
    # they are composite with respect to autograd. If the operator ever is
    # changed to not be composite with respect to autograd, then the
    # following check should fail.
    inputs_clone = list(inputs)
    inputs_clone[0] = inputs[0].clone().requires_grad_()
    result = vmap(op, in_dims, out_dims)(*inputs_clone)
    result_as_tuple = (result,) if op_has_single_return else result
    self.assertTrue(result[0].requires_grad)


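# Tests marked with @allowVmapFallbackUsage may hit the slow vmap fallback
# without tripping the no-fallback check that TestVmapBaseLegacy installs below.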
def should_allow_vmap_fallback_usage(fn):
    return getattr(fn, "_allow_vmap_fallback_usage", False)


def allowVmapFallbackUsage(fn):
    fn._allow_vmap_fallback_usage = True
    return fn


# All tests of TestVmapBaseLegacy check that the slow vmap fallback is never invoked.
# This is so that we can incrementally add batching rules for operators to
# replace the slow vmap fallback path for said operators. To skip this check,
# please use the allowVmapFallbackUsage decorator.
#
# NB: Don't add tests to TestVmapBaseLegacy directly, unless you want them to run
# on every subclass of TestVmapBaseLegacy. Add them to e.g. TestVmapOperators.
#
# NB: TestVmapBaseLegacy is a nested class. This prevents test runners from picking
# it up and running it.
class Namespace:
    class TestVmapBaseLegacy(TestCase):
        def __init__(self, method_name="runTest"):
            super().__init__(method_name)

            test_method = getattr(self, method_name, None)
            if test_method is None:
                return

            if not should_allow_vmap_fallback_usage(test_method):
                setattr(
                    self,
                    method_name,
                    self._wrap_method_with_vmap_fallback_check(test_method),
                )

        def _wrap_method_with_vmap_fallback_check(self, method):
            msg = (
                "Expected the test to not invoke the vmap fallback path, i.e., "
                "all of the operators being tested in this test should have batching "
                "rules implemented. If you are intentionally testing something to "
                "do with the fallback path, use allowVmapFallbackUsage. Otherwise, "
                "please make sure that batching rules are implemented for the "
                "operator(s) being tested."
            )

            @functools.wraps(method)
            def wrapper(self, *args, **kwargs):
                with warnings.catch_warnings(record=True) as wa:
                    warnings.simplefilter("always")
                    with EnableVmapFallbackWarnings():
                        method(*args, **kwargs)
                    for captured_warning in wa:
                        self.assertNotRegex(
                            str(captured_warning.message), FALLBACK_REGEX, msg
                        )

            return types.MethodType(wrapper, self)

        @allowVmapFallbackUsage
        def test_vmap_fallback_check_ok(self):
            # One day we'll implement a batching rule for torch.var_mean.
            # When that happens, please change the example to use an
            # operator that doesn't have a batching rule implemented.
            op_using_fallback = torch.var_mean
            vmap(op_using_fallback)(torch.rand(3))

        def test_vmap_fallback_check(self):
            @self._wrap_method_with_vmap_fallback_check
            def no_fallback(self):
                pass

            # One day we'll implement a batching rule for torch.var_mean.
            # When that happens, please change the example to use an
            # operator that doesn't have a batching rule implemented.
            op_using_fallback = torch.var_mean

            @self._wrap_method_with_vmap_fallback_check
            def uses_fallback(self):
                vmap(op_using_fallback)(torch.rand(3))

            no_fallback(self)

            with self.assertRaises(AssertionError):
                uses_fallback(self)


class TestVmapOperatorsLegacy(Namespace.TestVmapBaseLegacy):
    def _vmap_test(self, *args, **kwargs):
        return _vmap_test(self, *args, **kwargs)

    def _vmap_view_test(self, *args, **kwargs):
        self._vmap_test(*args, **kwargs, check_view=True)

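    # Runs `op` on tensors produced by `getter` under single and doubly nested
    # vmap, with a few different in_dims / out_dims placements of the batch dims.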
    def _test_unary(self, op, getter, device, *args, **kwargs):
        test = functools.partial(self._vmap_test, *args, **kwargs)
        B0, B1 = 7, 11

        # Single vmap, various in_dims / out_dims
        test(op, [getter([B0, 3], device)])
        test(op, [getter([2, 5, B0, 3], device)], in_dims=2)
        test(op, [getter([2, 5, B0, 3], device)], in_dims=2, out_dims=2)

        # Doubly nested vmap
        test(vmap(op), [getter([B0, B1], device)])
        test(vmap(op), [getter([B1, 2, 5, B0, 3], device)], in_dims=2)
        test(
            vmap(op, in_dims=2),
            [getter([2, 5, B0, B1, 3], device)],
            in_dims=2,
            out_dims=2,
        )

    def test_unary_pointwise_ops(self):
        cases = [
            (torch.abs, TensorFactory.randn),
            (torch.acos, TensorFactory.rand),
            (torch.asin, TensorFactory.rand),
            (torch.atan, TensorFactory.rand),
            (torch.ceil, TensorFactory.randn),
            (torch.cos, TensorFactory.rand),
            (torch.cosh, TensorFactory.rand),
            (torch.digamma, TensorFactory.rand),
            (torch.exp, TensorFactory.randn),
            (torch.expm1, TensorFactory.randn),
            (torch.floor, TensorFactory.randn),
            (torch.frac, TensorFactory.randn),
            (torch.lgamma, TensorFactory.rand),
            (torch.log, TensorFactory.randp1),
            (torch.log10, TensorFactory.randp1),
            (torch.log1p, TensorFactory.randp1),
            (torch.log2, TensorFactory.randp1),
            (torch.neg, TensorFactory.randn),
            (torch.reciprocal, TensorFactory.randp1),
            (torch.relu, TensorFactory.randn),
            (torch.round, TensorFactory.randn),
            (torch.rsqrt, TensorFactory.randp1),
            (torch.sigmoid, TensorFactory.randn),
            (torch.sign, TensorFactory.randn),
            (torch.sin, TensorFactory.rand),
            (torch.sinh, TensorFactory.rand),
            (torch.sqrt, TensorFactory.rand),
            (torch.tan, TensorFactory.rand),
            (torch.tanh, TensorFactory.rand),
            (torch.trunc, TensorFactory.randn),
        ]
        for op, getter in cases:
            self._test_unary(op, getter, "cpu")

    def test_clone(self):
        # Some basic tests
        self._test_unary(lambda x: x.clone(), TensorFactory.randn, "cpu")
        self._test_unary(
            lambda x: x.clone(memory_format=torch.preserve_format),
            TensorFactory.randn,
            "cpu",
        )
        self._test_unary(
            lambda x: x.clone(memory_format=torch.contiguous_format),
            TensorFactory.randn,
            "cpu",
        )

        # Test that the per-examples are contiguous when using torch.contiguous_format
        def clone_contiguous(x):
            return x.clone(memory_format=torch.contiguous_format)

        B0, B1 = 3, 5
        x = torch.randn(2, B0, 7)
        y = vmap(clone_contiguous, in_dims=1, out_dims=1)(x)
        self.assertTrue(y.movedim(1, 0).is_contiguous())
        self.assertTrue(y[:, 0, :].is_contiguous())

        x = torch.randn(2, B0, 7, B1)
        y = vmap(vmap(clone_contiguous, in_dims=2), in_dims=1)(x)
        self.assertTrue(y.is_contiguous())
        self.assertTrue(y[0][0].is_contiguous())

        msg = r"only supported with memory_format torch.preserve_format or torch.contiguous_format"
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(lambda x: x.clone(memory_format=torch.channels_last))(torch.randn(B0))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(lambda x: x.clone(memory_format=torch.channels_last_3d))(
                torch.randn(B0)
            )

    def test_binary_pointwise_ops(self):
        def get_number(getter):
            return getter([]).item()

        def make_case(op, input_getter=TensorFactory.randn):
            return (op, input_getter)

        cases = [
            # Basic arithmetic
            make_case(torch.add),
            make_case(lambda x, y: x + y),
            make_case(torch.sub),
            make_case(lambda x, y: x - y),
            make_case(torch.mul),
            make_case(lambda x, y: x * y),
            make_case(torch.div, input_getter=TensorFactory.randp1),
            make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1),
            make_case(torch.pow, input_getter=TensorFactory.randp1),
            make_case(lambda x, y: x**y, input_getter=TensorFactory.randp1),
        ]
        test = self._vmap_test

        for op, getter in cases:
            device = "cpu"
            B0, B1 = 7, 11

            # Single vmap: op(Tensor, Tensor)
            test(op, (getter([B0, 3], device), getter([B0, 3], device)))
            test(op, (getter([B0], device), getter([B0, 2, 3], device)))
            test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1))
            test(
                op,
                (getter([B0], device), getter([2, B0, 3], device)),
                in_dims=(0, 1),
                out_dims=1,
            )
            test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
            test(
                op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(None, 0)
            )

            # Nested vmap: op(Tensor, Tensor)
            test(
                vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device))
            )
            test(
                vmap(op, in_dims=(None, 0)),
                (getter([B0, 2, 3], device), getter([B1, 3], device)),
                in_dims=(0, None),
            )

            # Python number overload: op(Tensor, Number) (and vice-versa)
            number = get_number(getter)
            self._test_unary(lambda t: op(t, number), getter, device)
            number = get_number(getter)
            self._test_unary(lambda t: op(number, t), getter, device)

            # Type promotion: op(Logical Scalar Tensor, Logical Scalar Tensor)
            test(op, (getter([B0], device), getter([B0], device, dtype=torch.double)))
            test(op, (getter([B0], device, dtype=torch.double), getter([B0], device)))
            test(op, (getter([B0], device), getter([B0], device)))

            # Type promotion: op(Tensor, Logical Scalar Tensor) (and vice-versa)
            test(op, (getter([B0, 2], device), getter([B0], device, torch.double)))
            test(op, (getter([B0], device, torch.double), getter([B0, 2], device)))

            if not torch.cuda.is_available():
                continue

            # TODO(rzou): fix the following
            # # Test cross-device scalars
            # number = get_number(getter)
            # self._test_unary(lambda t: op(t, number), getter, device='cuda')
            # self._test_unary(lambda t: op(number, t), getter, device='cuda')
            # self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda')

    def test_as_strided(self):
        def _test(sizes, strides, offset, tensor, lambd):
            result = vmap(lambda t: t.as_strided(sizes, strides, offset))(tensor)
            expected = vmap(lambd)(tensor)
            self.assertTrue(result._base is expected._base)
            self.assertEqual(result, expected)

        # single vmap test
        B0 = 5
        tensors = [
            # contiguous
            torch.randn(B0, 2, 3),
            # non-contiguous
            torch.randn(B0, 3, 2).transpose(1, 2),
            # non-zero storage offset
            torch.randn(2, B0, 2, 3)[1],
            # non-contiguous strides, zero storage offset
            torch.randn(B0, 2, 4, 3, 7)[:, :, 0, :, 0],
            # non-contiguous strides, non-zero storage offset
            torch.randn(B0, 2, 4, 3, 7)[:, :, 2, :, 1],
        ]

        for x in tensors:
            S0, S1 = x.stride()[1:]
            offset = x.storage_offset()

            # Broadcast
            _test(
                [5, 5, 2, 3], [0, 0, S0, S1], offset, x, lambda x: x.expand(5, 5, 2, 3)
            )
            # transpose
            _test([3, 2], [S1, S0], offset, x, lambda x: x.transpose(0, 1))
            # select
            _test([2], [S0], offset + S1, x, lambda x: x[:, 1])

        # Nested vmap test
        B1 = 7
        x = torch.randn(B1, B0, 2, 3)
        S0, S1 = x.stride()[2:]
        result = vmap(
            vmap(lambda t: t.as_strided([5, 5, 2, 3], [0, 0, S0, S1])), in_dims=1
        )(x)
        expected = vmap(vmap(lambda t: t.expand(5, 5, 2, 3)), in_dims=1)(x)
        self.assertTrue(result._base is expected._base)
        self.assertEqual(result, expected)

        # Check that mal-formatted size/strides doesn't crash
        with self.assertRaisesRegex(
            RuntimeError, "size and stride must have the same length"
        ):
            x = torch.randn(B0, 2, 3).transpose(0, 1)
            vmap(lambda x: x.as_strided([1, 1, 1], [1, 1]))(x)

        # Sanity check #1: we require the batch dims to be at the front of the
        # tensor (in memory layout).
        msg = "batch dims being vmapped over are at the front of the tensor"
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(2, B0, 3).transpose(0, 1)
            vmap(lambda x: x.as_strided([2, 3], [B0 * 3, 1]))(x)
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(B0, 2, 3, B1).movedim(3, 1)
            vmap(vmap(lambda x: x.as_strided([2, 3], [B1 * 3, B1])))(x)

        # All the Sanity check #2{a,b,c} cases check that
        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
        # doesn't index memory that is out of bounds of xs[i]. This condition
        # is important to the correctness of the as_strided batching rule
        # (see NOTE: [When will the as_strided_batching_rule fail?])

        # Sanity check #2a: The maximum indexable location of
        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
        # is less than or equal to the maximum indexable location of xs[i].
        msg = "This is not supported inside of vmap"
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(B0, 3)
            vmap(lambda x: x.as_strided([3], [1], 1))(x)
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(B0, 3, 5)
            vmap(lambda x: x.as_strided([4, 4], [4, 1], 0))(x)
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(B0, B1, 3, 5)
            vmap(vmap(lambda x: x.as_strided([4, 4], [4, 1], 0)))(x)

        # Sanity check #2b: The min indexable location of
        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
        # is greater than or equal to the min indexable location of xs[i].
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(2, B0, 3)[1]
            vmap(lambda x: x.as_strided([3], [1], B0 * 3 - 1))(x)

        # Sanity check #2c:
        # xs[i] is a zero-dim tensor, but
        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
        # is not
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(B0, 0, 3)
            vmap(lambda x: x.as_strided([3], [1]))(x)

    def test_bmm(self):
        op = torch.bmm
        test = self._vmap_test
        B0, B1 = 7, 11

        # shape mismatch
        msg = "Shape mismatch"
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(0, None))(torch.randn(B0, 3, 3, 2), torch.randn(2, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2))

        # left arg is vmapped
        test(op, (torch.rand(B0, 2, 3, 5), torch.rand(2, 5, 3)), in_dims=(0, None))
        test(
            vmap(op, in_dims=(0, None)),
            (torch.rand(B1, B0, 2, 3, 5), torch.rand(2, 5, 3)),
            in_dims=(1, None),
        )

        # right arg is vmapped
        test(op, (torch.rand(2, 5, 3), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))
        test(
            vmap(op, in_dims=(None, 0)),
            (torch.rand(2, 5, 3), torch.rand(B1, B0, 2, 3, 5)),
            in_dims=(None, 1),
        )

        # both args are vmapped
        test(op, (torch.rand(B0, 2, 3, 5), torch.rand(B0, 2, 5, 3)))
        test(
            vmap(op),
            (torch.rand(B1, B0, 2, 3, 5), torch.rand(B0, B1, 2, 5, 3)),
            in_dims=(1, 0),
        )
        test(
            vmap(op, in_dims=(0, None)),
            (torch.rand(B1, 2, 3, 5), torch.rand(B0, 2, 5, 3)),
            in_dims=(None, 0),
        )

    def test_cat(self):
        test = self._vmap_test
        B0, B1 = 5, 7

        # Quick hack b/c vmap can't accept a list of tensors as an argument
        def get_op(dim):
1350            def op(*tensors):
1351                return torch.cat(tensors, dim=dim)
1352
1353            return op
1354
1355        test(get_op(0), (torch.rand(B0, 2), torch.rand(B0, 3)))
1356        test(get_op(0), (torch.rand(2), torch.rand(B0, 3)), in_dims=(None, 0))
1357        test(get_op(0), (torch.rand(2, 17), torch.rand(3, 17, B0)), in_dims=(None, 2))
1358        test(get_op(-1), (torch.rand(17, 2), torch.rand(17, 3, B0)), in_dims=(None, 2))
1359        test(
1360            vmap(get_op(0), in_dims=(0, None)),
1361            (torch.rand(B1, 2), torch.rand(B0, 3)),
1362            in_dims=(None, 0),
1363        )
1364        test(
1365            vmap(get_op(0), in_dims=(0, 0)),
1366            (torch.rand(B1, 2), torch.rand(B0, B1, 3)),
1367            in_dims=(None, 0),
1368        )
1369
1370    def test_conj(self):
1371        op = torch.conj
1372
1373        def run_test(dtype):
1374            def get(shape):
1375                return torch.randn(shape, dtype=dtype)
1376
1377            B0, B1 = 7, 11
1378            test = self._vmap_test
1379
1380            # Single vmap, various in_dims / out_dims
1381            test(op, [get([B0, 3])])
1382            test(op, [get([2, 5, B0, 3])], in_dims=2)
1383            test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2)
1384
1385            # Doubly nested vmap
1386            test(vmap(op), [get([B0, B1])])
1387            test(vmap(op), [get([B1, 2, 5, B0, 3])], in_dims=2)
1388            test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], in_dims=2, out_dims=2)
1389
1390        # correctness tests
1391        run_test(torch.float)
1392        run_test(torch.cfloat)
1393
1394        # check that torch.conj on a non-complex tensor returns the same tensor
1395        real_tensor = torch.randn(3)
1396        result = vmap(op)(real_tensor)
1397        self.assertEqual(result.data_ptr(), real_tensor.data_ptr())
1398
1399    def test_contiguous(self):
1400        op = Tensor.contiguous
1401
1402        self._test_unary(op, TensorFactory.randn, "cpu")
1403
1404        # check that contiguous returns the original tensor if the per-example
1405        # tensors are already contiguous
1406        B0 = 3
1407        x = torch.randn(B0, 2, 5, 7)
1408        x = x.movedim(0, 2)
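        # Note: x was allocated with the batch dim in front, so even though movedim
        # makes the batch dim logical dim 2, each per-example (2, 5, 7) slice is
        # still contiguous in memory.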
1409        result = vmap(Tensor.contiguous, in_dims=2, out_dims=2)(x)
1410        self.assertTrue(result is x)
1411
1412        msg = "NYI: querying is_contiguous inside of vmap for memory_format"
1413        tensor = torch.randn(B0, 3)
1414        with self.assertRaisesRegex(RuntimeError, msg):
1415            vmap(functools.partial(op, memory_format=torch.channels_last))(tensor)
1416        with self.assertRaisesRegex(RuntimeError, msg):
1417            vmap(functools.partial(op, memory_format=torch.channels_last_3d))(tensor)
1418
1419    def test_stride(self):
1420        B0 = 3
1421
1422        x = torch.randn(B0, 2, 5, 7)
1423
1424        def foo(x):
1425            assert x.stride() == (7 * 5, 7, 1)
1426            return x
1427
1428        vmap(foo)(x)
1429
1430        x = torch.randn(2, B0, 5, 7).movedim(1, 0)
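        # The storage here is laid out as (2, B0, 5, 7), so after movedim each
        # per-example (2, 5, 7) slice keeps the physical strides (B0 * 5 * 7, 7, 1)
        # rather than the contiguous (5 * 7, 7, 1).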
1431
1432        def bar(x):
1433            assert x.stride() == (7 * 5 * B0, 7, 1)
1434            return x
1435
1436        vmap(bar)(x)
1437
1438    def test_chunk(self):
1439        test = self._vmap_view_test
1440        op = torch.chunk
1441        B0, B1, B2 = 7, 11, 13
1442
1443        # tests for torch.chunk(self, chunks: int, dim)
1444        test(op, (torch.rand(B0, 2, 1024), 15, -1), in_dims=(0, None, None))
1445        test(op, (torch.rand(2, B0, 1024), 9, 1), in_dims=(1, None, None))
1446        test(
1447            vmap(op, in_dims=(0, None, None)),
1448            (torch.rand(B1, 1023, B0, 5), 4, 0),
1449            in_dims=(2, None, None),
1450        )
1451        test(
1452            vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
1453            (torch.rand(B1, 2, B0, 64, B2),),
1454            in_dims=2,
1455        )
1456
1457    def test_clamp(self):
1458        clamp_cases = (
1459            (lambda t: t.clamp(min=-0.5), TensorFactory.randn),
1460            (lambda t: t.clamp(max=0.5), TensorFactory.randn),
1461            (lambda t: t.clamp(min=-0.5, max=0.5), TensorFactory.randn),
1462            (lambda t: t.clamp_min(min=-0.5), TensorFactory.randn),
1463            (lambda t: t.clamp_max(max=0.5), TensorFactory.randn),
1464        )
1465        for op, getter in clamp_cases:
1466            self._test_unary(op, getter, "cpu")
1467
1468    def test_comparison_ops(self):
1469        test = functools.partial(self._vmap_test, check_propagates_grad=False)
1470
1471        getter = TensorFactory.randn
1472        B0, B1 = 7, 11
1473
1474        ops = (
1475            torch.eq,
1476            lambda x, y: x == y,
1477            torch.gt,
1478            lambda x, y: x > y,
1479            torch.ge,
1480            lambda x, y: x >= y,
1481            torch.le,
1482            lambda x, y: x <= y,
1483            torch.lt,
1484            lambda x, y: x < y,
1485            torch.ne,
1486            lambda x, y: x != y,
1487        )
1488
1489        for op in ops:
1490            # Single vmap: op(Tensor, Tensor)
1491            test(op, (getter([B0, 3]), getter([B0, 3])))
1492            test(op, (getter([B0]), getter([B0, 2, 3])))
1493            test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1))
1494            test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1), out_dims=1)
1495            test(op, (getter([B0]), getter([2, 3])), in_dims=(0, None))
1496            test(op, (getter([2, 3]), getter([B0, 3])), in_dims=(0, None))
1497
1498            # Nested vmap: op(Tensor, Tensor)
1499            test(vmap(op), (getter([B0, B1, 2, 3]), getter([B0, B1, 3])))
1500            test(
1501                vmap(op, in_dims=(None, 0)),
1502                (getter([B0, 2, 3]), getter([B1, 3])),
1503                in_dims=(0, None),
1504            )
1505
1506            # test Python numbers as inputs
1507            number = getter([]).item()
1508            self._test_unary(
1509                lambda t: op(t, number), getter, "cpu", check_propagates_grad=False
1510            )
1511
1512    def test_diagonal(self):
1513        tensor = torch.randn(3, 5, 7, 11, 13)
1514        test = self._vmap_view_test
1515        op = torch.diagonal
1516        test(op, (tensor, 1, 0, 1), in_dims=(0, None, None, None))
1517        test(op, (tensor, 0, 2, -1), in_dims=(0, None, None, None))
1518        test(op, (tensor, 2, 1, 2), in_dims=(1, None, None, None))
1519        test(op, (tensor, 0, -2, -1), in_dims=(1, None, None, None), out_dims=1)
1520        test(vmap(lambda t: op(t, 0, 0, -1)), (tensor,), in_dims=1, out_dims=1)
1521        test(
1522            vmap(vmap(lambda t: op(t, 0, 0, 1), in_dims=1), in_dims=3),
1523            (tensor,),
1524            in_dims=1,
1525            out_dims=1,
1526        )
1527
1528    def test_dot(self):
1529        op = torch.dot
1530        test = self._vmap_test
1531        B0, B1 = 7, 11
1532
1533        # shape mismatch
1534        msg = "Shape mismatch"
1535        with self.assertRaisesRegex(RuntimeError, msg):
1536            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
1537        with self.assertRaisesRegex(RuntimeError, msg):
1538            vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
1539        with self.assertRaisesRegex(RuntimeError, msg):
1540            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2))
1541
1542        # left arg is vmapped
1543        test(op, (torch.rand(B0, 5), torch.rand(5)), in_dims=(0, None))
1544        test(
1545            vmap(op, in_dims=(0, None)),
1546            (torch.rand(B1, B0, 5), torch.rand(5)),
1547            in_dims=(1, None),
1548        )
1549
1550        # right arg is vmapped
1551        test(op, (torch.rand(5), torch.rand(B0, 5)), in_dims=(None, 0))
1552        test(
1553            vmap(op, in_dims=(None, 0)),
1554            (torch.rand(5), torch.rand(B1, B0, 5)),
1555            in_dims=(None, 1),
1556        )
1557
1558        # both args are vmapped
1559        test(op, (torch.rand(B0, 5), torch.rand(B0, 5)))
1560        test(vmap(op), (torch.rand(B1, B0, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0))
1561        test(
1562            vmap(op, in_dims=(0, None)),
1563            (torch.rand(B1, 5), torch.rand(B0, 5)),
1564            in_dims=(None, 0),
1565        )
1566
1567    def test_expand_as(self):
1568        op = torch.Tensor.expand_as
1569        test = self._vmap_view_test
1570        B0, B1, B2 = 7, 11, 13
1571        test(op, (torch.rand(B0, 1, 5), torch.rand(B0, 2, 3, 5)))
1572        test(op, (torch.rand(B0, 1, 5), torch.rand(2, 3, 5)), in_dims=(0, None))
1573        test(op, (torch.rand(1, 5), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))
1574        test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B0, B1, 2, 3, 5)))
1575        test(
1576            vmap(op),
1577            (torch.rand(B0, B1, 1, 5), torch.rand(B1, B0, 2, 3, 5)),
1578            in_dims=(0, 1),
1579        )
1580        test(vmap(op), (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None))
1581        test(vmap(vmap(op)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5)))
1582
1583    def test_fill_and_zero_inplace(self):
1584        test = functools.partial(self._vmap_test, check_propagates_grad=False)
1585        B0, B1 = 7, 11
1586        ops = (
1587            lambda t: t.fill_(0.1),
1588            lambda t: t.fill_(torch.tensor(0.2)),
1589            lambda t: t.zero_(),
1590        )
1591
1592        for op in ops:
1593            # Single vmap, various in_dims / out_dims
1594            test(op, [TensorFactory.randn([B0, 3])])
1595            test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2)
1596            test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2, out_dims=2)
1597
1598            # Doubly nested vmap
1599            test(vmap(op), [TensorFactory.randn([B0, B1])])
1600            test(vmap(op), [TensorFactory.randn([B1, 2, 5, B0, 3])], in_dims=2)
1601            test(
1602                vmap(op, in_dims=2),
1603                [TensorFactory.randn([2, 5, B0, B1, 3])],
1604                in_dims=2,
1605                out_dims=2,
1606            )
1607
1608        # test fill_ when the fill value is a batched tensor
1609        B0, B1 = 3, 5
1610        test(Tensor.fill_, [TensorFactory.randn([B0, B1]), TensorFactory.randn(B0)])
1611
1612        with self.assertRaisesRegex(
1613            RuntimeError, r"output with shape .+ doesn't match the broadcast shape"
1614        ):
1615            # A RuntimeError is thrown when the tensor being written to isn't being vmapped over
1616            vmap(Tensor.fill_, (None, 0))(
1617                TensorFactory.randn([B0, B1]), TensorFactory.randn([B0])
1618            )
1619
1620    def _test_complex_views(self, op, dtypes):
1621        test = self._vmap_view_test
1622
1623        def run_test(op, dtype):
1624            def get(shape):
1625                return torch.randn(shape, dtype=dtype)
1626
1627            B0, B1 = 7, 11
1628
1629            # Single vmap, various in_dims / out_dims
1630            test(op, [get([B0, 3])])
1631            test(op, [get([3, B0])], in_dims=1)
1632            test(op, [get([2, 5, B0, 3])], in_dims=2)
1633            test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2)
1634
1635            # Doubly nested vmap
1636            test(vmap(op), [get([B0, B1])])
1637            test(vmap(op), [get([B1, 2, 5, 3, B0])], in_dims=4)
1638            test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], in_dims=2, out_dims=2)
1639
1640        for dtype in dtypes:
1641            run_test(op, dtype)
1642
1643    def test_real(self):
1644        self._test_complex_views(torch.real, dtypes=[torch.cfloat, torch.cdouble])
1645
1646    def test_imag(self):
1647        self._test_complex_views(torch.imag, dtypes=[torch.cfloat, torch.cdouble])
1648
1649    def test_view_as_real(self):
1650        self._test_complex_views(
1651            torch.view_as_real, dtypes=[torch.cfloat, torch.cdouble]
1652        )
1653
1654    def test_view_as_complex(self):
1655        def run_test(dtype):
1656            def get(shape):
1657                return torch.randn(shape, dtype=dtype)
1658
1659            op = torch.view_as_complex
1660            test = self._vmap_view_test
1661            B0, B1 = 7, 11
1662
1663            # Single vmap, various in_dims / out_dims
1664            test(op, [get([B0, 3, 2])])
1665            test(op, [get([2, 5, B0, 3, 2])], in_dims=2)
1666            test(op, [get([2, 5, B0, 3, 2])], in_dims=2, out_dims=2)
1667
1668            # Doubly nested vmap
1669            test(vmap(op), [get([B0, B1, 2])])
1670            test(vmap(op), [get([B1, 2, 5, B0, 3, 2])], in_dims=2)
1671            test(
1672                vmap(op, in_dims=2), [get([2, 5, B0, B1, 3, 2])], in_dims=2, out_dims=2
1673            )
1674
1675            # Interesting case #1: Batch dim directly before dim of size 2
1676            test(op, [get([3, B0, 2])], in_dims=1)
1677            test(vmap(op, in_dims=1), [get([3, B1, B0, 2])], in_dims=2)
1678
1679            # Interesting case #2: Batch dim at end of tensor, success cases
1680            # view_as_complex requires that the dim with size 2 have stride 1
1681            # in order for the view to function properly
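            # For example, get([B0, 2]).transpose(0, 1) has logical shape (2, B0),
            # but its size-2 dim still has stride 1 in storage, so each per-example
            # length-2 vector can be viewed as a single complex number.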
1682            test(op, [get([B0, 2]).transpose(0, 1)], in_dims=1)
1683            test(vmap(op, in_dims=1), [get([B0, B1, 2]).movedim(1, 2)])
1684            test(vmap(op, in_dims=2), [get([B0, 3, B1, 2]).movedim(2, 3)])
1685
1686            # Interesting case #3: Batch dim at end of tensor, failure cases
1687            msg = "Tensor must have a last dimension with stride 1"
1688            with self.assertRaisesRegex(RuntimeError, msg):
1689                vmap(op, in_dims=1)(get([2, B0]))
1690            with self.assertRaisesRegex(RuntimeError, msg):
1691                vmap(vmap(op, in_dims=1), in_dims=1)(get([2, B0, B1]))
1692
1693            # Invalid input: the per-example tensor has no dimensions
1694            msg = "Input tensor must have one or more dimensions"
1695            with self.assertRaisesRegex(RuntimeError, msg):
1696                vmap(op)(get([B0]))
1697            with self.assertRaisesRegex(RuntimeError, msg):
1698                vmap(vmap(op))(get([B0, B1]))
1699
1700            # Invalid input: Batch dim has size 2, but the logical last dim does
1701            # not have size 2
1702            msg = "Tensor must have a last dimension of size 2"
1703            with self.assertRaisesRegex(RuntimeError, msg):
1704                vmap(op, in_dims=1)(get([3, 2]))
1705
1706        for dtype in [torch.float, torch.double]:
1707            run_test(dtype)
1708
1709    def test_is_complex(self):
1710        ctensor = torch.randn(3, dtype=torch.cfloat)
1711        tensor = torch.randn(3)
1712
1713        def foo(x):
1714            if x.is_complex():
1715                return torch.tensor(1)
1716            else:
1717                return torch.tensor(0)
1718
1719        self.assertEqual(vmap(foo)(ctensor), torch.tensor([1, 1, 1]))
1720        self.assertEqual(vmap(foo)(tensor), torch.tensor([0, 0, 0]))
1721
1722    def test_is_floating_point(self):
1723        float_tensor = torch.tensor([1.0, 2.0, 3.0])
1724        long_tensor = torch.tensor([1, 2, 3])
1725
1726        def foo(x):
1727            if x.is_floating_point():
1728                return torch.tensor(1)
1729            else:
1730                return torch.tensor(0)
1731
1732        self.assertEqual(vmap(foo)(float_tensor), torch.tensor([1, 1, 1]))
1733        self.assertEqual(vmap(foo)(long_tensor), torch.tensor([0, 0, 0]))
1734
1735    def test_is_contiguous(self):
1736        def foo(x):
1737            if x.is_contiguous():
1738                return torch.tensor(1.0)
1739            else:
1740                return torch.tensor(0.0)
1741
1742        B0, B1 = 3, 5
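        # Inside vmap, is_contiguous is answered for each per-example slice, so the
        # answer depends on the physical layout: slices are contiguous only when the
        # batch dims sit outside the per-example dims in memory, as the cases below check.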
1743
1744        # Single batch dim
1745        contig = torch.randn(B0, 2, 7)
1746        self.assertEqual(vmap(foo)(contig), torch.ones(B0))
1747
1748        noncontig = torch.randn(2, B0, 7)
1749        self.assertEqual(vmap(foo, in_dims=1)(noncontig), torch.zeros(B0))
1750
1751        noncontig = torch.randn(2, B0, 7).movedim(1, 0)
1752        self.assertEqual(vmap(foo)(noncontig), torch.zeros(B0))
1753
1754        noncontig = torch.randn(2, 7, B0)
1755        self.assertEqual(vmap(foo, in_dims=2)(noncontig), torch.zeros(B0))
1756
1757        # Multiple batch dims
1758        contig = torch.randn(B0, B1, 3)
1759        self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1))
1760
1761        contig = torch.randn(B1, B0, 3)
1762        self.assertEqual(vmap(vmap(foo), in_dims=1)(contig), torch.ones(B0, B1))
1763
1764        contig = torch.randn(B1, B0, 3).movedim(0, 1)
1765        self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1))
1766
1767        noncontig = torch.randn(B0, 3, B1)
1768        self.assertEqual(vmap(vmap(foo, in_dims=1))(noncontig), torch.zeros(B0, B1))
1769
1770        # is_contiguous on an empty tensor is True
1771        def bar(x):
1772            assert x.is_contiguous()
1773            return x
1774
1775        vmap(bar)(torch.randn(B0, 0, 3))
1776        vmap(bar, in_dims=1)(torch.randn(0, B0, 3))
1777        vmap(bar)(torch.randn(B0, 0, 3).mT)
1778
1779        # is_contiguous with other memory formats
1780        def baz(x, memory_format):
1781            x.is_contiguous(memory_format=memory_format)
1782            return x
1783
1784        msg = "NYI: querying is_contiguous inside of vmap for memory_format"
1785        tensor = torch.randn(B0, 2, 7, 3)
1786        with self.assertRaisesRegex(RuntimeError, msg):
1787            vmap(functools.partial(baz, memory_format=torch.channels_last))(tensor)
1788        with self.assertRaisesRegex(RuntimeError, msg):
1789            vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor)
1790
1791    def test_movedim(self):
1792        op = torch.movedim
1793        test = self._vmap_view_test
1794        B0, B1, B2 = 7, 11, 13
1795
1796        # movedim(tensor, int, int) variant
1797        test(op, (torch.rand(B0, 2, 5), 0, 1), in_dims=(0, None, None))
1798        test(op, (torch.rand(2, B0, 5), 0, 1), in_dims=(1, None, None))
1799        test(
1800            vmap(op, in_dims=(0, None, None)),
1801            (torch.rand(B1, 2, B0, 5), 0, 1),
1802            in_dims=(2, None, None),
1803        )
1804        test(
1805            vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)),
1806            (torch.rand(B1, 2, B0, 5, B2), 0, 1),
1807            in_dims=(2, None, None),
1808        )
1809
1810        # movedim(tensor, intlist, intlist) variant
1811        test(op, (torch.rand(B0, 2, 3, 5), [1, 0], [0, 2]), in_dims=(0, None, None))
1812        test(op, (torch.rand(2, 3, B0, 5), [1, 0], [0, 2]), in_dims=(1, None, None))
1813        test(
1814            vmap(op, in_dims=(0, None, None)),
1815            (torch.rand(B1, 2, B0, 5), [0, 1], [1, 0]),
1816            in_dims=(2, None, None),
1817        )
1818        test(
1819            vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)),
1820            (torch.rand(B1, 2, B0, 5, B2), [0, 1], [1, 0]),
1821            in_dims=(2, None, None),
1822        )
1823
1824    def test_mm(self):
1825        op = torch.mm
1826        test = self._vmap_test
1827        B0, B1 = 7, 11
1828
1829        # shape mismatch
1830        msg = "Shape mismatch"
1831        with self.assertRaisesRegex(RuntimeError, msg):
1832            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
1833        with self.assertRaisesRegex(RuntimeError, msg):
1834            vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
1835        with self.assertRaisesRegex(RuntimeError, msg):
1836            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2))
1837
1838        # left arg is vmapped
1839        test(op, (torch.rand(B0, 2, 5), torch.rand(5, 2)), in_dims=(0, None))
1840        test(
1841            vmap(op, in_dims=(0, None)),
1842            (torch.rand(B1, B0, 2, 5), torch.rand(5, 2)),
1843            in_dims=(1, None),
1844        )
1845
1846        # right arg is vmapped
1847        test(op, (torch.rand(2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0))
1848        test(
1849            vmap(op, in_dims=(None, 0)),
1850            (torch.rand(2, 5), torch.rand(B1, B0, 5, 2)),
1851            in_dims=(None, 1),
1852        )
1853
1854        # both args are vmapped
1855        test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5, 2)))
1856        test(
1857            vmap(op),
1858            (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5, 2)),
1859            in_dims=(1, 0),
1860        )
1861        test(
1862            vmap(op, in_dims=(0, None)),
1863            (torch.rand(B1, 2, 5), torch.rand(B0, 5, 2)),
1864            in_dims=(None, 0),
1865        )
1866
1867    def test_mv(self):
1868        op = torch.mv
1869        test = self._vmap_test
1870        B0, B1 = 7, 11
1871
1872        # shape mismatch
1873        msg = "Shape mismatch"
1874        with self.assertRaisesRegex(RuntimeError, msg):
1875            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
1876        with self.assertRaisesRegex(RuntimeError, msg):
1877            vmap(op, in_dims=(0, None))(torch.randn(B0, 2, 2), torch.randn(2, 2))
1878        with self.assertRaisesRegex(RuntimeError, msg):
1879            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2))
1880
1881        # left arg is vmapped
1882        test(op, (torch.rand(B0, 2, 5), torch.rand(5)), in_dims=(0, None))
1883        test(
1884            vmap(op, in_dims=(0, None)),
1885            (torch.rand(B1, B0, 2, 5), torch.rand(5)),
1886            in_dims=(1, None),
1887        )
1888
1889        # right arg is vmapped
1890        test(op, (torch.rand(2, 5), torch.rand(B0, 5)), in_dims=(None, 0))
1891        test(
1892            vmap(op, in_dims=(None, 0)),
1893            (torch.rand(2, 5), torch.rand(B1, B0, 5)),
1894            in_dims=(None, 1),
1895        )
1896
1897        # both args are vmapped
1898        test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5)))
1899        test(
1900            vmap(op), (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0)
1901        )
1902        test(
1903            vmap(op, in_dims=(0, None)),
1904            (torch.rand(B1, 2, 5), torch.rand(B0, 5)),
1905            in_dims=(None, 0),
1906        )
1907
1908    def test_narrow(self):
1909        op = torch.narrow
1910        test = self._vmap_view_test
1911        B0, B1, B2 = 7, 11, 13
1912
1913        test(op, (torch.rand(B0, 2, 5), -1, 1, 3), in_dims=(0, None, None, None))
1914        test(op, (torch.rand(2, B0, 5), 1, 1, 3), in_dims=(1, None, None, None))
1915        test(
1916            vmap(op, in_dims=(0, None, None, None)),
1917            (torch.rand(B1, 2, B0, 5), 1, 0, 0),
1918            in_dims=(2, None, None, None),
1919        )
1920        test(
1921            vmap(
1922                vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)
1923            ),
1924            (torch.rand(B1, 2, B0, 5, B2), -1, 2, 3),
1925            in_dims=(2, None, None, None),
1926        )
1927
1928    def test_new_empty(self):
1929        # new_empty returns uninitialized memory, so we just check that the shape of
1930        # the output tensor is what we expect and that the vmap fallback isn't used.
1931        op = Tensor.new_empty
1932
1933        B0, B1 = 7, 11
1934
1935        result = vmap(lambda x: op(x, [2, 3]))(torch.randn(B0))
1936        self.assertEqual(result.shape, [B0, 2, 3])
1937
1938        result = vmap(lambda x: op(x, []))(torch.randn(B0))
1939        self.assertEqual(result.shape, [B0])
1940
1941        result = vmap(vmap(lambda x: op(x, [2, 3])))(torch.randn(B0, B1))
1942        self.assertEqual(result.shape, [B0, B1, 2, 3])
1943
1944    def test_new_empty_strided(self):
1945        # new_empty_strided returns uninitialized memory, so we just check the shape
1946        # and strides of the output and that the vmap fallback isn't used
1947        B0, B1 = 7, 11
1948
1949        def _test_single_vmap(size, stride, B0):
1950            x = torch.randn(B0)
1951            result = vmap(lambda x: x.new_empty_strided(size, stride))(x)
1952            S = torch.empty_strided(size, stride).storage().size()
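            # S is the storage size (in elements) needed by one per-example tensor;
            # the examples are presumably laid out back-to-back, so the expected
            # batch stride below is S.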
1953            self.assertEqual(result.shape, [B0] + size)
1954            self.assertEqual(result.stride(), [S] + stride)
1955
1956        def _test_double_vmap(size, stride, B0, B1):
1957            x = torch.randn(B0, B1)
1958            result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)))(x)
1959            S = torch.empty_strided(size, stride).storage().size()
1960            self.assertEqual(result.shape, [B0, B1] + size)
1961            self.assertEqual(result.stride(), [B1 * S, S] + stride)
1962
1963            x = torch.randn(B1, B0)
1964            result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)), in_dims=1)(
1965                x
1966            )
1967            S = x.new_empty_strided(size, stride).storage().size()
1968            self.assertEqual(result.shape, [B0, B1] + size)
1969            self.assertEqual(result.stride(), [B1 * S, S] + stride)
1970
1971        # contiguous case
1972        _test_single_vmap([2, 3, 5], [3 * 5, 5, 1], B0)
1973        _test_double_vmap([2, 3, 5], [3 * 5, 5, 1], B0, B1)
1974
1975        # expanded
1976        _test_single_vmap([2, 3, 5], [0, 5, 1], B0)
1977        _test_double_vmap([2, 3, 5], [0, 5, 1], B0, B1)
1978
1979        # some of these cases are pretty strange; we're just verifying that if
1980        # empty_strided accepts them, then BatchedTensor.new_empty_strided
1981        # does as well
1982        for shape in [[2, 3, 4], [0, 2, 0]]:
1983            for strides in [[12, 4, 1], [2, 4, 6], [0, 0, 0]]:
1984                _test_single_vmap(shape, strides, B0)
1985                _test_double_vmap(shape, strides, B0, B1)
1986
1987    def test_new_zeros(self):
1988        op = Tensor.new_zeros
1989        test = functools.partial(self._vmap_test, check_propagates_grad=False)
1990        B0, B1 = 7, 11
1991
1992        test(lambda x: op(x, 2, 3), (torch.rand(B0),))
1993        test(lambda x: op(x, []), (torch.rand(B0),))
1994        test(vmap(lambda x: op(x, 3, 5)), (torch.rand(B0, B1),))
1995
1996    def test_select(self):
1997        op = torch.select
1998        test = self._vmap_view_test
1999        B0, B1, B2 = 7, 11, 13
2000        test(op, (torch.rand(B0, 2, 5), 0, 0), in_dims=(0, None, None))
2001        test(op, (torch.rand(2, B0, 5), 1, 1), in_dims=(1, None, None))
2002        test(vmap(lambda t: op(t, 1, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
2003        test(
2004            vmap(vmap(lambda t: op(t, 1, 1), in_dims=1)),
2005            (torch.rand(B1, 2, B0, B2, 5),),
2006            in_dims=2,
2007        )
2008
2009    def test_stack(self):
2010        test = self._vmap_test
2011        B0, B1 = 5, 7
2012
2013        # Quick hack b/c vmap can't accept a list of tensors as an argument
2014        def get_op(dim):
2015            def op(*tensors):
2016                return torch.stack(tensors, dim=dim)
2017
2018            return op
2019
2020        test(get_op(0), (torch.rand(B0, 3), torch.rand(B0, 3)))
2021        test(get_op(0), (torch.rand(3), torch.rand(B0, 3)), in_dims=(None, 0))
2022        test(get_op(0), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
2023        test(get_op(-1), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
2024        test(
2025            vmap(get_op(0), in_dims=(0, None)),
2026            (torch.rand(B1, 2), torch.rand(B0, 2)),
2027            in_dims=(None, 0),
2028        )
2029        test(
2030            vmap(get_op(0), in_dims=(0, 0)),
2031            (torch.rand(B1, 2), torch.rand(B0, B1, 2)),
2032            in_dims=(None, 0),
2033        )
2034
2035    def test_slice(self):
2036        test = self._vmap_view_test
2037        B0, B1, B2 = 7, 11, 13
2038        test(lambda t: t[0:1], (torch.rand(B0, 3, 5),))
2039        test(lambda t: t[:, 1:3], (torch.rand(3, 5, B0),), in_dims=2)
2040        test(
2041            vmap(lambda t: t[:, 0:1], in_dims=2), (torch.rand(3, 5, B0, B1),), in_dims=2
2042        )
2043        test(
2044            vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2),
2045            (torch.rand(3, 5, B0, B1, B2),),
2046            in_dims=2,
2047        )
2048
2049    def test_squeeze(self):
2050        test = self._vmap_view_test
2051        op = torch.squeeze
2052        B0, B1 = 1, 11
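        # B0 is 1 here, presumably to check that squeeze under vmap does not
        # squeeze away the batch dim even though it has size 1.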
2053        test(op, (torch.rand(B0),))
2054        test(op, (torch.rand(B0, 3, 5),))
2055        test(op, (torch.rand(1, B0, 5),), in_dims=1)
2056        test(op, (torch.rand(B0, 0, 1, 5, 1),))
2057        test(op, (torch.rand(B0, 1, 1, 1, 1),))
2058        test(vmap(op), (torch.rand(B0, B1, 1),))
2059        test(vmap(op), (torch.rand(B1, 1, B0),), in_dims=2)
2060
2061    def test_sum_dim(self):
2062        test = self._vmap_test
2063        B0, B1 = 5, 7
2064
2065        # Single vmap, various in_dims / out_dims
2066        test(lambda x: x.sum(()), [torch.randn([B0])])
2067        test(lambda x: x.sum(()), [torch.randn([B0, 2])])
2068        test(lambda x: x.sum(0), [torch.randn([B0])])
2069        test(lambda x: x.sum(-1), [torch.randn([B0])])
2070        test(lambda x: x.sum(0), [torch.randn([B0, 3])])
2071        test(lambda x: x.sum(-1), [torch.randn([2, 5, B0, 3])], in_dims=2)
2072        test(lambda x: x.sum(2), [torch.randn([2, 5, B0, 3])], in_dims=2, out_dims=2)
2073
2074        # Doubly nested vmap
2075        test(vmap(lambda x: x.sum(())), [torch.randn([B0, B1])])
2076        test(vmap(lambda x: x.sum(0)), [torch.randn([B0, B1])])
2077        test(vmap(lambda x: x.sum(-1)), [torch.randn([B0, B1])])
2078        test(vmap(lambda x: x.sum(-2)), [torch.randn([B1, 2, 5, B0, 3])], in_dims=2)
2079        test(
2080            vmap(lambda x: x.sum(2), in_dims=2),
2081            [torch.randn([2, 5, B0, B1, 3])],
2082            in_dims=2,
2083            out_dims=2,
2084        )
2085
2086    def test_reshape(self):
2087        test = self._vmap_test
2088        B0, B1, B2 = 7, 11, 13
2089        op = torch.reshape
2090        test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None), check_view=True)
2091        test(
2092            op, (torch.rand(2, B0, 5), [1, 1, 10]), in_dims=(1, None), check_view=False
2093        )
2094        test(
2095            vmap(lambda t: t.reshape([-1])),
2096            (torch.rand(B0, B1, 2, 5),),
2097            check_view=True,
2098        )
2099        test(
2100            vmap(vmap(lambda t: t.reshape([-1]), in_dims=2), in_dims=1),
2101            (torch.rand(3, B1, 2, B2, 5, B0),),
2102            in_dims=5,
2103            check_view=False,
2104        )
2105
2106    def test_reshape_as(self):
2107        test = self._vmap_test
2108        B0, B1, B2 = 7, 11, 13
2109        op = torch.Tensor.reshape_as
2110        test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)), check_view=True)
2111        test(
2112            op,
2113            (torch.rand(2 * 5), torch.rand(B0, 2, 5)),
2114            in_dims=(None, 0),
2115            check_view=True,
2116        )
2117        test(
2118            op,
2119            (torch.rand(B0, 2 * 5), torch.rand(2, 5)),
2120            in_dims=(0, None),
2121            check_view=True,
2122        )
2123
2124        test(
2125            op,
2126            (torch.rand(2, B0, 5), torch.rand(1, 1, 10)),
2127            in_dims=(1, None),
2128            check_view=False,
2129        )
2130
2131        test(
2132            vmap(op),
2133            (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)),
2134            check_view=True,
2135        )
2136        test(
2137            vmap(vmap(op, in_dims=(2, None)), in_dims=(1, None)),
2138            (torch.rand(3, B1, 2, B2, 5, B0), torch.rand(B0, 3 * 2 * 5)),
2139            in_dims=(5, 0),
2140            check_view=False,
2141        )
2142
2143    def test_result_type(self):
2144        def scalar_tensor_with_dtype(op):
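            # torch.result_type returns a dtype rather than a Tensor; wrap the
            # result in a scalar tensor of that dtype so _vmap_test can compare
            # outputs as usual.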
2145            def wrapped(*args, **kwargs):
2146                dtype = op(*args, **kwargs)
2147                return torch.ones([], dtype=dtype)
2148
2149            return wrapped
2150
2151        test = self._vmap_test
2152        op = scalar_tensor_with_dtype(torch.result_type)
2153
2154        B0 = 2
2155
2156        test(
2157            op,
2158            (torch.randn(B0), torch.randn(B0, dtype=torch.float64)),
2159            check_propagates_grad=False,
2160        )
2161        test(
2162            op,
2163            (torch.randn(B0), torch.randint(10, [B0], dtype=torch.int64)),
2164            check_propagates_grad=False,
2165        )
2166
2167        test(lambda x: op(x, 1), (torch.randn(B0),), check_propagates_grad=False)
2168        test(lambda x: op(x, 1.6), (torch.randn(B0),), check_propagates_grad=False)
2169
2170        test(
2171            lambda x: op(x, torch.tensor(1)),
2172            (torch.randn(B0),),
2173            check_propagates_grad=False,
2174        )
2175        test(
2176            lambda x: op(x, torch.tensor(1.6, dtype=torch.double)),
2177            (torch.randn(B0),),
2178            check_propagates_grad=False,
2179        )
2180
2181        test(
2182            op,
2183            (torch.randn(B0, 2), torch.randn(B0, 2, dtype=torch.float64)),
2184            check_propagates_grad=False,
2185        )
2186        test(
2187            op,
2188            (torch.randn(B0, 2), torch.randint(10, [B0, 2], dtype=torch.int64)),
2189            check_propagates_grad=False,
2190        )
2191
2192        test(lambda x: op(x, 1), (torch.randn(B0, 2),), check_propagates_grad=False)
2193        test(lambda x: op(x, 1.6), (torch.randn(B0, 2),), check_propagates_grad=False)
2194
2195        test(
2196            lambda x: op(x, torch.tensor(1)),
2197            (torch.randn(B0, 2),),
2198            check_propagates_grad=False,
2199        )
2200        test(
2201            lambda x: op(x, torch.tensor(1.6, dtype=torch.double)),
2202            (torch.randn(B0, 2),),
2203            check_propagates_grad=False,
2204        )
2205
2206        test(
2207            op,
2208            (torch.randn(B0, 2), torch.randn(B0, dtype=torch.float64)),
2209            check_propagates_grad=False,
2210        )
2211        test(
2212            op,
2213            (torch.randn(B0, 2), torch.randint(10, [B0], dtype=torch.int64)),
2214            check_propagates_grad=False,
2215        )
2216
2217    @skipIfTorchDynamo("too slow")
2218    def test_tensor_split(self):
2219        test = self._vmap_view_test
2220        op = torch.tensor_split
2221        B0, B1, B2 = 7, 11, 13
2222
2223        # tests for torch.tensor_split(self, indices_or_sections: int, dim)
2224        test(op, (torch.rand(B0, 2, 1024), 5, -1), in_dims=(0, None, None))
2225        test(op, (torch.rand(2, B0, 1024), 150, 1), in_dims=(1, None, None))
2226        test(
2227            vmap(op, in_dims=(0, None, None)),
2228            (torch.rand(B1, 1023, B0, 5), 256, 0),
2229            in_dims=(2, None, None),
2230        )
2231        test(
2232            vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
2233            (torch.rand(B1, 2, B0, 64, B2),),
2234            in_dims=2,
2235        )
2236
2237        # tests for torch.tensor_split(self, indices_or_sections: List[int], dim)
2238        test(
2239            op,
2240            (torch.rand(B0, 2, 1024), [50, 100, 378, 890], -1),
2241            in_dims=(0, None, None),
2242        )
2243        test(
2244            op,
2245            (torch.rand(2, B0, 1024), [50, 100, 212, 345, 0, 378, 890], 1),
2246            in_dims=(1, None, None),
2247        )
2248        test(
2249            vmap(op, in_dims=(0, None, None)),
2250            (torch.rand(B1, 1023, B0, 5), [50, 100, 212, 345, 0, 378, 890], 0),
2251            in_dims=(2, None, None),
2252        )
2253        test(
2254            vmap(vmap(lambda t: op(t, [4, 8, 9, 34, 29], 1), in_dims=2)),
2255            (torch.rand(B1, 2, B0, 64, B2),),
2256            in_dims=2,
2257        )
2258
2259    def test_split(self):
2260        test = self._vmap_view_test
2261        op = torch.split
2262        B0, B1, B2 = 7, 11, 13
2263
2264        # tests for torch.split(self, split_size: int, dim)
2265        test(op, (torch.rand(B0, 2, 1024), 101, -1), in_dims=(0, None, None))
2266        test(op, (torch.rand(2, B0, 1024), 130, 1), in_dims=(1, None, None))
2267        test(
2268            vmap(op, in_dims=(0, None, None)),
2269            (torch.rand(B1, 1023, B0, 5), 256, 0),
2270            in_dims=(2, None, None),
2271        )
2272        test(
2273            vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
2274            (torch.rand(B1, 2, B0, 64, B2),),
2275            in_dims=2,
2276        )
2277
2278        # tests for torch.split(self, split_size: List[int], dim)
2279        test(op, (torch.rand(B0, 2, 1024), [1, 1020, 3], -1), in_dims=(0, None, None))
2280        test(
2281            op, (torch.rand(2, B0, 1024), [100] * 10 + [24], 1), in_dims=(1, None, None)
2282        )
2283        test(
2284            vmap(op, in_dims=(0, None, None)),
2285            (torch.rand(B1, 1023, B0, 5), [256] * 3 + [255], 0),
2286            in_dims=(2, None, None),
2287        )
2288        test(
2289            vmap(vmap(lambda t: op(t, [4] * 8 + [8] * 4, 1), in_dims=2)),
2290            (torch.rand(B1, 2, B0, 64, B2),),
2291            in_dims=2,
2292        )
2293
2294    def test_trace(self):
2295        op = torch.trace
2296        test = self._vmap_test
2297        B0, B1, B2 = 7, 11, 13
2298
2299        test(op, (torch.rand(B0, 2, 5),))
2300        test(op, (torch.rand(2, B0, 5),), in_dims=1)
2301        test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
2302        test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)
2303
2304    def test_transpose(self):
2305        op = torch.transpose
2306        test = self._vmap_view_test
2307
2308        B0, B1, B2 = 7, 11, 13
2309        test(lambda x: op(x, 0, 1), (torch.rand(B0, 2, 5),))
2310        test(lambda x: op(x, -1, -2), (torch.rand(B0, 2, 5),))
2311        test(lambda x: op(x, 3, 1), (torch.rand(B0, 2, 5, 4, 6),))
2312        test(lambda x: op(x, 1, 0), (torch.rand(2, B0, 5),), in_dims=1)
2313        test(vmap(lambda x: op(x, 0, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
2314        test(
2315            vmap(vmap(lambda x: op(x, 0, 1), in_dims=2)),
2316            (torch.rand(B1, 2, B0, 5, B2),),
2317            in_dims=2,
2318        )
2319
2320        # Special case: scalar tensor
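        # (per-example tensors are 0-dimensional here, and transposing dims 0/-1 of
        # a 0-dim tensor is a no-op, so vmap should return the input unchanged)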
2321        for dim1, dim2 in itertools.product([0, -1], [0, -1]):
2322            x = torch.rand(B0)
2323            result = vmap(lambda x: op(x, dim1, dim2))(x)
2324            self.assertTrue(result is x)
2325
2326    def test_t(self):
2327        op = torch.t
2328        test = self._vmap_view_test
2329        B0, B1, B2 = 7, 11, 13
2330        test(op, (torch.rand(B0, 2, 5),))
2331        test(op, (torch.rand(2, B0, 5),), in_dims=1)
2332        test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
2333        test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)
2334
2335    def test_T_numpy(self):
2336        def op(t):
2337            return t.T
2338
2339        test = self._vmap_view_test
2340        B0, B1, B2 = 7, 11, 13
2341        test(op, (torch.rand(B0, 2, 3, 5),))
2342        test(op, (torch.rand(2, B0, 3, 5),), in_dims=1)
2343        test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
2344        test(vmap(op), (torch.rand(B1, 2, B0, 3, 5),), in_dims=2)
2345        test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 3, B2, 5),), in_dims=2)
2346
2347    def test_to(self):
2348        test = self._vmap_test
2349        B0, B1 = 7, 11
2350
2351        test(lambda t: t.to("cpu"), (torch.rand(B0),))
2352        test(lambda t: t.to(torch.double), (torch.rand(B0),))
2353        test(
2354            lambda t, o: t.to(o), (torch.rand(B0), torch.randn(B0, dtype=torch.float64))
2355        )
2356        test(
2357            lambda t, o: t.to(o),
2358            (torch.rand(B0), torch.randn(B0, dtype=torch.float64)),
2359            in_dims=(0, None),
2360        )
2361        test(vmap(lambda t: t.to(torch.double)), (torch.rand(B0, B1, 3),))
2362
2363        # also test some casting methods
2364        test(lambda t: t.double(), (torch.rand(B0),))
2365        test(lambda t: t.float(), (torch.rand(B0),))
2366        test(lambda t: t.int(), (torch.rand(B0),), check_propagates_grad=False)
2367        test(lambda t: t.long(), (torch.rand(B0),), check_propagates_grad=False)
2368
2369    def test_unfold(self):
2370        op = torch.Tensor.unfold
2371        test = self._vmap_view_test
2372        B0, B1, B2 = 3, 2, 5
2373
2374        test(op, (torch.rand(B0, 7, 11), 0, 2, 1), in_dims=(0, None, None, None))
2375        test(op, (torch.rand(7, B0, 11), 1, 4, 2), in_dims=(1, None, None, None))
2376        test(
2377            vmap(op, in_dims=(0, None, None, None)),
2378            (torch.rand(B1, 7, B0, 11), 1, 5, 1),
2379            in_dims=(2, None, None, None),
2380        )
2381        test(
2382            vmap(
2383                vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)
2384            ),
2385            (torch.rand(B1, 7, B0, 11, B2), -1, 2, 4),
2386            in_dims=(2, None, None, None),
2387        )
2388
2389    def test_unbind(self):
2390        test = self._vmap_view_test
2391        op = torch.unbind
2392        B0, B1, B2 = 7, 11, 13
2393
2394        test(op, (torch.rand(B0, 2, 1024), -1), in_dims=(0, None))
2395        test(op, (torch.rand(B0, 2, 0),))
2396        test(op, (torch.rand(2, B0, 7), 0), in_dims=(1, None))
2397        test(
2398            vmap(op, in_dims=(0, None)),
2399            (torch.rand(B1, 1023, B0, 5), 1),
2400            in_dims=(2, None),
2401        )
2402        test(
2403            vmap(vmap(lambda t: op(t, dim=1), in_dims=2)),
2404            (torch.rand(B1, 2, B0, 32, B2),),
2405            in_dims=2,
2406        )
2407
2408    def test_view(self):
2409        test = self._vmap_view_test
2410        B0, B1, B2 = 7, 11, 13
2411        op = torch.Tensor.view
2412
2413        # We should error out if the view would produce an incorrect result
2414        with self.assertRaises(RuntimeError):
2415            vmap(op, in_dims=(1, None))(torch.rand(2, B0, 5), [10])
2416
2417        test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None))
2418        test(op, (torch.rand(B0, 4, 5), [1, 2, 1, 10]), in_dims=(0, None))
2419        test(vmap(lambda t: t.view([-1])), (torch.rand(B0, B1, 2, 5, 3),))
2420        test(
2421            vmap(vmap(lambda t: t.reshape([-1])), in_dims=1),
2422            (torch.rand(B2, B0, B1, 3, 2, 5),),
2423            in_dims=1,
2424        )
2425
2426    def test_view_as(self):
2427        test = self._vmap_view_test
2428        B0, B1, B2 = 7, 11, 13
2429        op = torch.Tensor.view_as
2430
2431        # We should error out if the view would produce an incorrect result
2432        with self.assertRaises(RuntimeError):
2433            vmap(op, in_dims=(1, 0))(torch.rand(2, B0, 5), torch.rand(B0, 10))
2434
2435        test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)))
2436        test(op, (torch.rand(2 * 5), torch.rand(B0, 2, 5)), in_dims=(None, 0))
2437        test(op, (torch.rand(B0, 2 * 5), torch.rand(2, 5)), in_dims=(0, None))
2438
2439        test(op, (torch.rand(B0, 4, 5), torch.rand(2, 1, 1, 10)), in_dims=(0, None))
2440
2441        test(vmap(op), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)))
2442        test(
2443            vmap(vmap(op, in_dims=(0, None)), in_dims=(0, None)),
2444            (torch.rand(B1, B2, B0, 3, 2, 5), torch.rand(B0, 3 * 2 * 5)),
2445            in_dims=(2, 0),
2446        )
2447
2448    def test_no_random_op_support(self):
2449        B0 = 2
2450
2451        captured = torch.rand(3)
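        # `captured` is closed over by the lambdas below instead of being passed
        # through vmap; random ops touching it should be rejected under vmap too.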
2452
2453        random_ops = [
2454            # out-of-place on BatchedTensor
2455            (torch.bernoulli, (torch.rand(B0, 1),)),
2456            (lambda t: torch.bernoulli(t, p=0.5), (torch.rand(B0, 1),)),
2457            (lambda t: torch.multinomial(t, 2), (torch.rand(B0, 3),)),
2458            (torch.normal, (torch.randn(B0, 1), torch.randn(B0, 1))),
2459            (lambda t: torch.normal(t, 1.0), (torch.randn(B0, 1),)),
2460            (lambda t: torch.normal(0.0, t), (torch.randn(B0, 1),)),
2461            (torch.poisson, (torch.rand(B0, 1),)),
2462            (torch.rand_like, (torch.rand(B0, 1),)),
2463            (torch.randn_like, (torch.rand(B0, 1),)),
2464            (lambda t: torch.randint_like(t, 2), (torch.rand(B0, 1),)),
2465            (lambda t: torch.randint_like(t, 0, 2), (torch.rand(B0, 1),)),
2466            # out-of-place on captured tensor
2467            (lambda t: torch.bernoulli(captured), (torch.rand(B0),)),
2468            (lambda t: torch.bernoulli(captured, p=0.5), (torch.rand(B0),)),
2469            (lambda t: torch.multinomial(captured, 2), (torch.rand(B0),)),
2470            (lambda t: torch.normal(captured, captured), (torch.randn(B0),)),
2471            (lambda t: torch.normal(captured, 1.0), (torch.randn(B0),)),
2472            (lambda t: torch.normal(0.0, captured), (torch.randn(B0),)),
2473            (lambda t: torch.poisson(captured), (torch.rand(B0),)),
2474            (lambda t: torch.rand_like(captured), (torch.rand(B0),)),
2475            (lambda t: torch.randn_like(captured), (torch.rand(B0),)),
2476            (lambda t: torch.randint_like(captured, 2), (torch.rand(B0),)),
2477            (lambda t: torch.randint_like(captured, 0, 2), (torch.rand(B0),)),
2478            # in-place on BatchedTensor
2479            (lambda t: t.bernoulli_(), (torch.randn(B0, 1),)),
2480            (lambda t: t.cauchy_(), (torch.randn(B0, 1),)),
2481            (lambda t: t.exponential_(), (torch.randn(B0, 1),)),
2482            (lambda t: t.geometric_(0.5), (torch.randn(B0, 1),)),
2483            (lambda t: t.log_normal_(), (torch.randn(B0, 1),)),
2484            (lambda t: t.normal_(), (torch.randn(B0, 1),)),
2485            (lambda t: t.random_(), (torch.randn(B0, 1),)),
2486            (lambda t: t.random_(0, 2), (torch.randn(B0, 1),)),
2487            (lambda t: t.random_(2), (torch.randn(B0, 1),)),
2488            (lambda t: t.uniform_(), (torch.randn(B0, 1),)),
2489            # in-place on captured tensor
2490            (lambda t: captured.bernoulli_(), (torch.randn(B0),)),
2491            (lambda t: captured.cauchy_(), (torch.randn(B0),)),
2492            (lambda t: captured.exponential_(), (torch.randn(B0),)),
2493            (lambda t: captured.geometric_(0.5), (torch.randn(B0),)),
2494            (lambda t: captured.log_normal_(), (torch.randn(B0),)),
2495            (lambda t: captured.normal_(), (torch.randn(B0),)),
2496            (lambda t: captured.random_(), (torch.randn(B0),)),
2497            (lambda t: captured.random_(0, 2), (torch.randn(B0),)),
2498            (lambda t: captured.random_(2), (torch.randn(B0),)),
2499            (lambda t: captured.uniform_(), (torch.randn(B0),)),
2500            # factory functions
2501            (lambda t: torch.rand(1), (torch.randn(B0),)),
2502            (lambda t: torch.randn(1), (torch.randn(B0),)),
2503            (lambda t: torch.randint(5, [1]), (torch.randn(B0),)),
2504            (lambda t: torch.randperm(5), (torch.randn(B0),)),
2505        ]
2506        for op, args in random_ops:
2507            with self.assertRaisesRegex(
2508                RuntimeError, "vmap: We do not yet support calling random operations"
2509            ):
2510                vmap(op)(*args)
2511
2512
2513def construct_v(output, batch_size):
2514    return torch.randn(
2515        batch_size, *output.shape, dtype=output.dtype, device=output.device
2516    )
2517
2518
2519def as_tuple(x):
2520    if isinstance(x, tuple):
2521        return x
2522    elif isinstance(x, list):
2523        return tuple(x)
2524    else:
2525        return (x,)
2526
2527
2528def differentiable(args):
2529    return tuple(
2530        arg
2531        for arg in as_tuple(args)
2532        if isinstance(arg, torch.Tensor) and arg.requires_grad
2533    )
2534
2535
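# Like torch.rand, but with values clamped away from zero so that ops such as
# log, div, and log1p used in the tests below have well-behaved gradients.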
2536def _get_rand_no_zeros(*args, **kwargs):
2537    requires_grad = kwargs.get("requires_grad", False)
2538    kwargs_without_requires_grad = kwargs.copy()
2539    kwargs_without_requires_grad["requires_grad"] = False
2540    result = torch.rand(*args, **kwargs_without_requires_grad)
2541    return result.clamp_min_(0.1).requires_grad_(requires_grad)
2542
2543
2544class TestVmapBatchedGradientLegacy(Namespace.TestVmapBaseLegacy):
2545    def _vmap_test(self, *args, **kwargs):
2546        return _vmap_test(self, *args, **kwargs)
2547
2548    # Tests batched gradient computation of outputs = op(*args, **kwargs)
2549    # by comparing it to a sequential map+stack fallback.
2550    #
2551    # output_process_fn: a function that maps the outputs to the part
2552    #       that should be differentiated.
2553    # batch_size: the batch dim size for the batched grad
2554    def _batched_grad_test(
2555        self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3
2556    ):
2557        if kwargs is None:
2558            kwargs = {}
2559        outputs = op(*args, **kwargs)
2560        outputs = differentiable(output_process_fn(outputs))
2561        batched_vectors = tuple(construct_v(out, batch_size) for out in outputs)
2562
2563        def vector_jacobian_product(*vectors):
2564            return torch.autograd.grad(
2565                outputs, differentiable(args), vectors, retain_graph=True
2566            )
2567
2568        self._vmap_test(
2569            vector_jacobian_product, batched_vectors, check_propagates_grad=False
2570        )
2571
2572    # Tests batched second-order gradient computation of outputs = op(*args, **kwargs)
2573    # by comparing it to a sequential map+stack fallback.
2574    #
2575    # output_process_fn: a function that maps the outputs to the part
2576    #       that should be differentiated.
2577    # batch_size: the batch dim size for the batched grad
2578    #
2579    # NB: we only test computing batched gradients in the second gradient
2580    # computation. One specific use case that does this is computing the Hessian
2581    # matrix of a scalar-valued function; this is useful in Bayesian logistic
2582    # regression.
2583    # It might be useful to have a test that computes batched first gradients and
2584    # then uses those to compute batched second gradients in the future.
2585    def _batched_grad_grad_test(
2586        self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3
2587    ):
2588        if kwargs is None:
2589            kwargs = {}
2590        outputs = op(*args, **kwargs)
2591        outputs = differentiable(output_process_fn(outputs))
2592        ones = tuple(torch.ones_like(out) for out in outputs)
2593        # Same thing as summing together all of the outputs and calling .backward()
2594        first_grads = torch.autograd.grad(
2595            outputs, differentiable(args), ones, create_graph=True
2596        )
2597        first_grads = differentiable(first_grads)
2598        self.assertNotEqual(
2599            len(first_grads), 0, "None of the first grads depend on the input!"
2600        )
2601
2602        batched_vectors = tuple(construct_v(grad, batch_size) for grad in first_grads)
2603
2604        def vector_hessian_product(*vectors):
2605            outputs = torch.autograd.grad(
2606                first_grads,
2607                differentiable(args),
2608                vectors,
2609                retain_graph=True,
2610                allow_unused=True,
2611            )
2612            outputs = tuple(out for out in outputs if out is not None)
2613            assert len(outputs) > 0
2614            return outputs
2615
2616        self._vmap_test(
2617            vector_hessian_product, batched_vectors, check_propagates_grad=False
2618        )
2619
2620    def _test_arithmetic(self, op, device, test_grad_grad=True):
2621        x = torch.randn(2, 3, requires_grad=True, device=device)
2622        y = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
2623        scalar = 3.14
2624        self._batched_grad_test(op, (x, y))
2625        self._batched_grad_test(op, (scalar, y))
2626        self._batched_grad_test(op, (x, scalar))
2627
2628        if test_grad_grad:
2629            self._batched_grad_grad_test(op, (x, y))
2630
2631    def test_add(self, device):
2632        self._test_arithmetic(torch.add, device, test_grad_grad=False)
2633        self._test_arithmetic(lambda x, y: x + y, device, test_grad_grad=False)
2634
2635    def test_sub(self, device):
2636        self._test_arithmetic(torch.sub, device, test_grad_grad=False)
2637        self._test_arithmetic(lambda x, y: x - y, device, test_grad_grad=False)
2638
2639    def test_mul(self, device):
2640        self._test_arithmetic(torch.mul, device)
2641        self._test_arithmetic(lambda x, y: x * y, device)
2642
2643    def test_div(self, device):
2644        self._test_arithmetic(torch.div, device)
2645        self._test_arithmetic(lambda x, y: x / y, device)
2646
2647    @allowVmapFallbackUsage
2648    def test_binary_cross_entropy(self, device):
2649        x = torch.sigmoid(torch.randn(3, 2, device=device, requires_grad=True))
2650        target = torch.rand(3, 2, device=device)
2651
2652        op = functools.partial(F.binary_cross_entropy, target=target)
2653
2654        self._batched_grad_test(op, (x,), {})
2655        self._batched_grad_grad_test(op, (x,), {})
2656
2657    def test_expand(self, device):
2658        x = torch.randn(2, 3, device=device, requires_grad=True)
2659
2660        def op(x):
2661            return x.expand(5, 5, 2, 3)
2662
2663        self._batched_grad_test(op, (x,))
2664
2665    @allowVmapFallbackUsage
2666    def test_index(self, device):
2667        x = torch.randn(2, 3, requires_grad=True, device=device)
2668        index = torch.tensor([[0, 0], [1, 1]], device=device)
2669
2670        def op(x):
2671            y = x * x
2672            return y[index]
2673
2674        self._batched_grad_test(op, (x,))
2675        self._batched_grad_grad_test(op, (x,))
2676
2677    def test_lgamma(self, device):
2678        x = torch.randn(2, 3, requires_grad=True, device=device)
2679        self._batched_grad_test(Tensor.lgamma, (x,))
2680        self._batched_grad_grad_test(Tensor.lgamma, (x,))
2681
2682    def test_log(self, device):
2683        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
2684        self._batched_grad_test(torch.log, (x,))
2685        self._batched_grad_grad_test(torch.log, (x,))
2686
2687    def test_logsumexp(self, device):
2688        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
2689
2690        def op(x):
2691            return torch.logsumexp(x, -1)
2692
2693        self._batched_grad_test(op, (x,))
2694        self._batched_grad_grad_test(op, (x,))
2695
2696    def test_log1p(self, device):
2697        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
2698        self._batched_grad_test(torch.log1p, (x,))
2699        self._batched_grad_grad_test(torch.log1p, (x,))
2700
2701    @allowVmapFallbackUsage
2702    def test_max(self, device):
2703        x = torch.randn(2, 3, requires_grad=True, device=device)
2704        self._batched_grad_test(torch.max, (x,))
2705
2706    @allowVmapFallbackUsage
2707    def test_median(self, device):
2708        x = torch.randn(2, 3, requires_grad=True, device=device)
2709        self._batched_grad_test(torch.median, (x,))
2710
2711    @allowVmapFallbackUsage
2712    def test_min(self, device):
2713        x = torch.randn(2, 3, requires_grad=True, device=device)
2714        self._batched_grad_test(torch.min, (x,))
2715
2716    def test_permute(self, device):
2717        x = torch.randn(2, 3, 5, requires_grad=True, device=device)
2718
2719        def op(x):
2720            return x.permute(2, 0, 1)
2721
2722        self._batched_grad_test(op, (x,))
2723
2724    def test_reshape(self, device):
2725        x = torch.randn(2, 3, 5, requires_grad=True, device=device)
2726
2727        def op(x):
2728            return x.reshape([2 * 3, 5])
2729
2730        self._batched_grad_test(op, (x,))
2731
2732    def test_sigmoid(self, device):
2733        x = torch.randn(2, 3, requires_grad=True, device=device)
2734        self._batched_grad_test(Tensor.sigmoid, (x,))
2735        self._batched_grad_grad_test(Tensor.sigmoid, (x,))
2736
2737    def test_stack(self, device):
2738        x = torch.randn(2, 3, device=device, requires_grad=True)
2739        y = torch.randn(2, 3, device=device, requires_grad=True)
2740
2741        def op(x, y):
2742            return torch.stack([x, y])
2743
2744        self._batched_grad_test(op, (x, y))
2745
2746    def test_select(self, device):
2747        x = torch.randn(2, 3, device=device, requires_grad=True)
2748        self._batched_grad_test(lambda x: x[1], (x,))
2749        self._batched_grad_test(lambda x: x.select(1, 2), (x,))
2750        self._batched_grad_test(lambda x: x.select(-1, 0), (x,))
2751
2752    def test_slice(self, device):
2753        x = torch.randn(2, 3, 5, device=device, requires_grad=True)
2754        self._batched_grad_test(lambda x: x[0:1], (x,))
2755        self._batched_grad_test(lambda x: x[:, 1:3], (x,))
2756        self._batched_grad_test(lambda x: x[..., 1:3], (x,))
2757
2758    def test_trace(self, device):
2759        x = torch.randn(2, 3, device=device, requires_grad=True)
2760        self._batched_grad_test(Tensor.trace, (x,))
2761
2762    def test_threshold(self, device):
2763        x = torch.randn(2, 3, device=device, requires_grad=True)
2764        self._batched_grad_test(lambda x: F.threshold(x, 0.5, 0.0), (x,))
2765
2766    @allowVmapFallbackUsage
2767    def test_inplace_on_view(self, device):
2768        leaf = torch.randn(4, 5, requires_grad=True)
2769
2770        def func(leaf):
2771            # Make sure the function is non-trivially twice differentiable
2772            base = leaf * leaf
2773            view = base[0]
2774            view.cos_()
2775            return view
2776
2777        self._batched_grad_test(func, (leaf,), {})
2778        self._batched_grad_grad_test(func, (leaf,), {})
2779
2780    @allowVmapFallbackUsage
2781    def test_inplace_manyview(self, device):
2782        leaf = torch.randn(4, 4, 5, requires_grad=True)
2783
2784        def func(leaf):
2785            # Make sure the function is non-trivially twice differentiable
2786            base = leaf * leaf
2787            view = base.transpose(0, 2)
2788            view = view[1]
2789            view = view.diagonal()
2790            view = view[::2]
2791            view.cos_()
2792            return view
2793
2794        self._batched_grad_test(func, (leaf,), {})
2795        self._batched_grad_grad_test(func, (leaf,), {})
2796
2797    def test_diagonal(self, device):
2798        x = torch.randn(4, 5, device=device, requires_grad=True)
2799        self._batched_grad_test(lambda x: x.diagonal(1, 0, 1), (x,))
2800
2801        x = torch.randn(3, 4, 5, device=device, requires_grad=True)
2802        self._batched_grad_test(lambda x: x.diagonal(0, -1, -2), (x,))
2803
2804    @allowVmapFallbackUsage
2805    def test_unrelated_output(self, device):
2806        B0 = 3
2807        x = torch.randn([], requires_grad=True)
2808        y = torch.randn([], requires_grad=True)
2809        gy = torch.randn(B0, requires_grad=True)
2810
2811        def vjp(v):
2812            (res,) = torch.autograd.grad(y, x, v, allow_unused=True)
2813            return torch.zeros_like(x) if res is None else res
2814
2815        result = vmap(vjp)(gy)
2816        self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))
2817
2818    @allowVmapFallbackUsage
2819    def test_unrelated_output_multiple_grad(self, device):
2820        B0 = 3
2821        x = torch.randn([], requires_grad=True)
2822        y = torch.randn([], requires_grad=True)
2823        gy = torch.randn(B0, requires_grad=True)
2824
2825        def vjp(v):
2826            (res,) = torch.autograd.grad(y, x, v, allow_unused=True)
2827            return torch.zeros_like(x) if res is None else res
2828
2829        _ = vjp(gy[0])
2830        result = vmap(vjp)(gy)
2831        self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))
2832
2833
2834instantiate_device_type_tests(TestVmapBatchedGradientLegacy, globals(), None)
2835
2836if __name__ == "__main__":
2837    run_tests()
2838