#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/LegacyVmapTransforms.h>
#include <c10/util/irange.h>

using namespace at;

namespace {

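// These tests exercise the legacy vmap machinery directly: addBatchDim and
// makeBatched wrap a Tensor in a batched wrapper (see LegacyBatchedTensorImpl.h)
// whose batch dims are hidden from the logical sizes(), and the Vmap*Transform
// helpers map logical views back onto the underlying physical tensor.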
TEST(VmapTest, TestBatchedTensor) {
  {
    // NOLINTNEXTLINE(bugprone-argument-comment)
    Tensor x = addBatchDim(ones({2, 3, 4}), /*lvl=*/1, /*dim=*/1);
    std::vector<int64_t> expected_size = {2, 4};
    ASSERT_EQ(x.sizes(), expected_size);
    ASSERT_EQ(x.dim(), 2);
    ASSERT_EQ(x.numel(), 8);
    ASSERT_EQ(x.is_contiguous(), false);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(x.storage(), c10::Error);
    ASSERT_EQ(x.storage_offset(), 0);
  }
  {
    // Test multiple batch dims
    // NOLINTNEXTLINE(bugprone-argument-comment)
    Tensor x = addBatchDim(ones({2, 3, 4}), /*lvl=*/1, /*dim=*/1);
    // NOLINTNEXTLINE(bugprone-argument-comment)
    x = addBatchDim(x, /*lvl=*/2, /*dim=*/1);
    std::vector<int64_t> expected_size = {2};
    ASSERT_EQ(x.sizes(), expected_size);
    ASSERT_EQ(x.dim(), 1);
    ASSERT_EQ(x.numel(), 2);
  }
  {
    // Test vmap tensor dimensionality limit

    // Should not throw
    std::vector<int64_t> sizes(kVmapMaxTensorDims, 1);
    // NOLINTNEXTLINE(bugprone-argument-comment)
    Tensor x = addBatchDim(ones(sizes), /*lvl=*/1, /*dim=*/1);

    // Should throw
    std::vector<int64_t> too_many_sizes(kVmapMaxTensorDims + 1, 1);
    auto big_dim_tensor = ones(too_many_sizes);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto,bugprone-argument-comment)
    ASSERT_THROW(addBatchDim(big_dim_tensor, /*lvl=*/1, /*dim=*/1), c10::Error);
  }
  {
    // Create a "scalar" BatchedTensor. Should not crash.
    Tensor tensor = addBatchDim(ones({3}), /*lvl*/1, /*dim*/0);
  }
}

// returns {{lvl=0,dim=0}, {lvl=1,dim=1}, ..., {lvl=kVmapNumLevels-1,dim=kVmapNumLevels-1}}
static BatchDims maxBatchDimsAtFront() {
  BatchDims result;
  for (const auto lvl : c10::irange(kVmapNumLevels)) {
    result.emplace_back(lvl, /*dim=*/lvl);
  }
  return result;
}

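// Levels must be strictly less than kVmapNumLevels, and a single BatchedTensor
// can carry at most kVmapNumLevels batch dims; anything beyond that should make
// makeBatched throw.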
TEST(VmapTest, TestBatchedTensorMaxLevel) {
  {
    // Should not throw
    auto tensor = ones({2, 3, 4});
    makeBatched(tensor, {{/*lvl*/kVmapNumLevels - 1, /*dim*/0}});
  }
  {
    auto tensor = ones({2, 3, 4});
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(
        makeBatched(tensor, {{/*lvl*/kVmapNumLevels, /*dim*/0}}),
        c10::Error);
  }
  {
    auto tensor = ones({2, 3, 4});
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(
        makeBatched(tensor, {{/*lvl*/kVmapNumLevels + 5, /*dim*/0}}),
        c10::Error);
  }
  {
    // create a BatchedTensor with kVmapNumLevels levels.
    // Should not throw
    auto tensor = ones(std::vector<int64_t>(kVmapNumLevels, 1));
    makeBatched(tensor, maxBatchDimsAtFront());
  }
  {
    // create a BatchedTensor with kVmapNumLevels+1 levels.
    auto tensor = ones(std::vector<int64_t>(kVmapNumLevels + 1, 1));
    auto batch_dims = maxBatchDimsAtFront();
    batch_dims.emplace_back(/*lvl*/kVmapNumLevels, /*dim*/kVmapNumLevels);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(makeBatched(tensor, batch_dims), c10::Error);
  }
}

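// actualDim(d) maps a logical dimension of the batched view to the
// corresponding dimension of the underlying physical tensor, skipping over the
// batch dims; negative dims wrap around unless wrap_dim is false.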
TEST(VmapTest, TestBatchedTensorActualDim) {
  {
    // No batch dims
    Tensor tensor = makeBatched(ones({2, 3, 5, 7}), {});
    auto* batched = maybeGetBatchedImpl(tensor);
    ASSERT_EQ(batched->actualDim(0), 0);
    ASSERT_EQ(batched->actualDim(1), 1);
    ASSERT_EQ(batched->actualDim(3), 3);

    // Test wrap around
    ASSERT_EQ(batched->actualDim(-1), 3);
    ASSERT_EQ(batched->actualDim(-4), 0);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(batched->actualDim(-5), c10::Error);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(batched->actualDim(4), c10::Error);

    // test wrap_dim = False
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(batched->actualDim(-1, /*wrap_dim*/false), c10::Error);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(batched->actualDim(-4, /*wrap_dim*/false), c10::Error);
  }
  {
    // Single batch dim at front
    Tensor tensor = makeBatched(ones({2, 3, 5, 7}), {{/*lvl*/1, /*dim*/0}});
    auto* batched = maybeGetBatchedImpl(tensor);
    ASSERT_EQ(batched->actualDim(0), 1);
    ASSERT_EQ(batched->actualDim(2), 3);
    ASSERT_EQ(batched->actualDim(-1), 3);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(batched->actualDim(3), c10::Error);
  }
  {
    // Single batch dim in middle
    Tensor tensor = makeBatched(ones({2, 3, 5, 7}), {{/*lvl*/1, /*dim*/1}});
    auto* batched = maybeGetBatchedImpl(tensor);
    ASSERT_EQ(batched->actualDim(0), 0);
    ASSERT_EQ(batched->actualDim(1), 2);
    ASSERT_EQ(batched->actualDim(2), 3);
  }
  {
    // Single batch dim at end
    Tensor tensor = makeBatched(ones({2, 3, 5, 7}), {{/*lvl*/1, /*dim*/3}});
    auto* batched = maybeGetBatchedImpl(tensor);
    ASSERT_EQ(batched->actualDim(0), 0);
    ASSERT_EQ(batched->actualDim(2), 2);
    ASSERT_EQ(batched->actualDim(-1), 2);
  }
  {
    // Multiple (2) batch dims at front
    Tensor tensor = makeBatched(
        ones({2, 3, 5, 7}),
        {{/*lvl*/1, /*dim*/0}, {/*lvl*/2, /*dim*/1}});
    auto* batched = maybeGetBatchedImpl(tensor);
    ASSERT_EQ(batched->actualDim(0), 2);
    ASSERT_EQ(batched->actualDim(1), 3);
  }
  {
    // Multiple (2) batch dims, misc places
    Tensor tensor = makeBatched(
        ones({2, 3, 5, 7}),
        {{/*lvl*/1, /*dim*/1}, {/*lvl*/2, /*dim*/3}});
    auto* batched = maybeGetBatchedImpl(tensor);
    ASSERT_EQ(batched->actualDim(0), 0);
    ASSERT_EQ(batched->actualDim(1), 2);
    ASSERT_EQ(batched->actualDim(-1), 2);
    ASSERT_EQ(batched->actualDim(-2), 0);
  }
  {
    // ActualDim on kVmapMaxTensorDims sized underlying tensor
    auto tensor = ones({});
    for (C10_UNUSED const auto i : c10::irange(kVmapMaxTensorDims)) {
      tensor = tensor.unsqueeze(0);
    }
    ASSERT_EQ(tensor.dim(), kVmapMaxTensorDims);

    auto batched = addBatchDim(tensor, /*lvl*/1, /*dim*/0);
    auto* batched_impl = maybeGetBatchedImpl(batched);
    ASSERT_EQ(
        batched_impl->actualDim(kVmapMaxTensorDims - 2),
        kVmapMaxTensorDims - 1);
    ASSERT_EQ(
        batched_impl->actualDim(-1),
        kVmapMaxTensorDims - 1);
  }
}
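
// MultiBatchVmapTransform::logicalToPhysical moves all batch dims of a
// BatchedTensor to the front of the physical tensor, returning a view that
// shares the input's storage, and throws on plain (non-batched) inputs.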
TEST(VmapTest, TestMultiBatchVmapTransform) {
  {
    // Input is regular Tensor
    auto tensor = ones({2, 3, 5});
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(MultiBatchVmapTransform::logicalToPhysical(tensor), c10::Error);
  }
  {
    // Input is BatchedTensor, Batch dims are already at the front
    auto tensor = ones({2, 3, 5});
    BatchDims bdims = {{/*lvl*/1, /*dim*/0}, {/*lvl*/3, /*dim*/1}};
    auto batched = makeBatched(tensor, bdims);

    auto result = MultiBatchVmapTransform::logicalToPhysical(batched);
    ASSERT_TRUE(result.tensor().is_same(tensor));
  }
  {
    // Single batch dim, not at front
    auto tensor = ones({2, 3, 5});
    BatchDims bdims = {{/*lvl*/1, /*dim*/1}};
    auto batched = makeBatched(tensor, bdims);

    auto result = MultiBatchVmapTransform::logicalToPhysical(batched);
    ASSERT_EQ(result.tensor().data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(result.tensor(), tensor.permute({1, 0, 2})));
  }
  {
    // Multiple batch dims, not at front.
    auto tensor = ones({2, 3, 5});
    BatchDims bdims = {{/*lvl*/1, /*dim*/1}, {/*lvl*/2, /*dim*/2}, {/*lvl*/3, /*dim*/0}};
    auto batched = makeBatched(tensor, bdims);

    auto result = MultiBatchVmapTransform::logicalToPhysical(batched);
    ASSERT_EQ(result.tensor().data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(result.tensor(), tensor.permute({1, 2, 0})));
  }
  {
    // Edge case: kVmapNumLevels levels; batch dims are already at front.

    // sizes=[2, 1, 3, 1, 1, 7, 1, 1, 1, 1, ...]
    auto sizes = std::vector<int64_t>(kVmapNumLevels, 1);
    sizes[0] = 2;
    sizes[2] = 3;
    sizes[5] = 7;

    // bdims = {{lvl=0,dim=0}, {lvl=1,dim=1}, ..., {lvl=63,dim=63}}
    auto batch_dims = maxBatchDimsAtFront();
    auto tensor = ones(sizes);

    auto batched = makeBatched(tensor, batch_dims);
    auto result = MultiBatchVmapTransform::logicalToPhysical(batched);
    ASSERT_TRUE(result.tensor().is_same(tensor));
  }
  {
    // Edge case: kVmapNumLevels levels; batch dims are not at front

    // sizes=[1, 3, 2, 1, 1, 7, 1, 1, 1, 1, ..., 1, 1, 5]
    auto sizes = std::vector<int64_t>(kVmapNumLevels, 1);
    sizes[1] = 3;
    sizes[2] = 2;
    sizes[5] = 7;
    sizes[kVmapNumLevels - 1] = 5;

    // The goal is to permute sizes such that the final sizes are:
    // [2, 3, 5, 7, 1, 1, 1, 1, 1, ...]
    auto expected_result_sizes = std::vector<int64_t>(kVmapNumLevels, 1);
    expected_result_sizes[0] = 2;
    expected_result_sizes[1] = 3;
    expected_result_sizes[2] = 5;
    expected_result_sizes[3] = 7;

    // bdims = {{0, 2}, {1, 1}, {2, 63}, {3, 5}, {4, 0}, {5, 3}, {6, 4},
    //          {7, 6}, {8, 7}, {9, 8}, ..., {63, 62}}
    BatchDims batch_dims = {
      {0, 2}, {1, 1}, {2, kVmapNumLevels - 1}, {3, 5}, {4, 0}, {5, 3}, {6, 4}
    };
    for (const auto level : c10::irange(7, kVmapNumLevels)) {
      batch_dims.emplace_back(level, /*dim=*/level - 1);
    }
    auto tensor = ones(sizes);

    auto batched = makeBatched(tensor, batch_dims);
    auto result = MultiBatchVmapTransform::logicalToPhysical(batched);
    ASSERT_EQ(result.tensor().data_ptr(), tensor.data_ptr());
    ASSERT_EQ(result.tensor().sizes(), expected_result_sizes);
  }
}
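
// VmapPhysicalView pairs a physical tensor with a bitset of vmap levels;
// e.g. 1 | 4 encodes levels {0, 2}, so the 5-dim tensor below has two leading
// batch dims and three logical dims.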
TEST(VmapTest, TestVmapPhysicalViewGetPhysicalDim) {
  VmapPhysicalView physical_view(ones({2, 3, 4, 5, 6}), 1 | 4);

  // Positive dims
  ASSERT_EQ(physical_view.getPhysicalDim(0), 2);
  ASSERT_EQ(physical_view.getPhysicalDim(1), 3);
  ASSERT_EQ(physical_view.getPhysicalDim(2), 4);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_THROW(physical_view.getPhysicalDim(3), c10::Error);

  // Negative dims (testing wrap dim behavior)
  ASSERT_EQ(physical_view.getPhysicalDim(-1), 4);
  ASSERT_EQ(physical_view.getPhysicalDim(-2), 3);
  ASSERT_EQ(physical_view.getPhysicalDim(-3), 2);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_THROW(physical_view.getPhysicalDim(-4), c10::Error);
}
TEST(VmapTest, TestVmapPhysicalViewGetPhysicalDims) {
  VmapPhysicalView physical_view(ones({2, 3, 4, 5, 6}), 2 | 8 | 16);

  ASSERT_EQ(
      physical_view.getPhysicalDims({0, 1, -1, -2}),
      VmapDimVector({3, 4, 4, 3}));

  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_THROW(physical_view.getPhysicalDims({2, 0}), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_THROW(physical_view.getPhysicalDims({0, -3}), c10::Error);
}

static void checkBatchDimsEqual(BatchDimsRef bdims, BatchDimsRef expected_bdims) {
  ASSERT_EQ(bdims.size(), expected_bdims.size());
  for (const auto idx : c10::irange(bdims.size())) {
    ASSERT_EQ(bdims[idx].dim(), expected_bdims[idx].dim());
    ASSERT_EQ(bdims[idx].level(), expected_bdims[idx].level());
  }
}

TEST(VmapTest, TestVmapPhysicalViewNewLogicalFromPhysical) {
  {
    // Simple case: single level
    VmapPhysicalView physical_view(ones({2, 3, 4}), /*levels = {2}*/4);
    Tensor physical = ones({2, 6, 7});

    auto result = physical_view.getPhysicalToLogicalMap().apply(physical);
    auto* batched = maybeGetBatchedImpl(result);
    ASSERT_TRUE(batched != nullptr);
    ASSERT_TRUE(batched->value().is_same(physical));
    checkBatchDimsEqual(batched->bdims(), {{2, 0}});
  }
  {
    // Multiple levels
    VmapPhysicalView physical_view(ones({2, 3, 4, 5, 6}), /*levels = {1, 3, 4}*/2 | 8 | 16);
    Tensor physical = ones({2, 3, 4, 7});

    auto result = physical_view.getPhysicalToLogicalMap().apply(physical);
    auto* batched = maybeGetBatchedImpl(result);
    ASSERT_TRUE(batched != nullptr);
    ASSERT_TRUE(batched->value().is_same(physical));
    checkBatchDimsEqual(batched->bdims(), {{1, 0}, {3, 1}, {4, 2}});
  }
  {
    // Logical shape is [].
    VmapPhysicalView physical_view(ones({2}), /*levels = {2}*/4);
    Tensor physical = ones({2});

    auto result = physical_view.getPhysicalToLogicalMap().apply(physical);
    auto* batched = maybeGetBatchedImpl(result);
    ASSERT_TRUE(batched != nullptr);
    ASSERT_TRUE(batched->value().is_same(physical));
    checkBatchDimsEqual(batched->bdims(), {{2, 0}});
  }
}

// Basic test for BatchedTensor::sum.
// NB: We don't need to write tests in C++ for batching rules if we can test them
// in Python via the vmap API. These are here to bootstrap that process.
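// The sum batching rule reduces over the physical dims that correspond to the
// requested logical dims (batch dims are never reduced), as checked against
// plain x.sum(...) below.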
TEST(VmapTest, TestBatchedTensorSum) {
  {
    // Simple: single batch dim, single reduce dim
    Tensor x = at::randn({2, 3, 5, 7});

    Tensor batched_x = makeBatched(x, {{/*lvl*/1, /*dim*/0}});
    Tensor batched_out = batched_x.sum(0);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();

    ASSERT_TRUE(at::allclose(out, x.sum(1)));
  }
  {
    // single batch dim, -1 reduce dim handling
    Tensor x = at::randn({2, 3});

    Tensor batched_x = makeBatched(x, {{/*lvl*/1, /*dim*/1}});
    Tensor batched_out = batched_x.sum(-1);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();

    ASSERT_TRUE(at::allclose(out, x.sum(0)));
  }
  {
    // single batch dim, multiple reduce dim
    Tensor x = at::randn({2, 3, 5, 7});

    Tensor batched_x = makeBatched(x, {{/*lvl*/1, /*dim*/1}});
    Tensor batched_out = batched_x.sum(std::vector<int64_t>{0, 1});
    const auto& out = maybeGetBatchedImpl(batched_out)->value();

    ASSERT_TRUE(at::allclose(out, x.sum(std::vector<int64_t>{0, 2})));
  }
  {
    // multiple batch dim, multiple reduce dim
    Tensor x = at::randn({2, 3, 5, 7});

    Tensor batched_x = makeBatched(x, {{/*lvl*/1, /*dim*/0}, {/*lvl*/2, /*dim*/1}});
    Tensor batched_out = batched_x.sum(std::vector<int64_t>{0, 1});
    const auto& out = maybeGetBatchedImpl(batched_out)->value();

    ASSERT_TRUE(at::allclose(out, x.sum(std::vector<int64_t>{2, 3})));
  }
}

static void checkBroadcastingVmapTransform(TensorList inputs, TensorList expected_outputs) {
  auto outputs = BroadcastingVmapTransform::logicalToPhysical(inputs);
  ASSERT_EQ(outputs.size(), expected_outputs.size());
  for (const auto idx : c10::irange(outputs.size())) {
    const auto& output = outputs[idx].tensor();
    ASSERT_EQ(output.data_ptr(), expected_outputs[idx].data_ptr());
    ASSERT_TRUE(at::allclose(output, expected_outputs[idx]));
  }
}

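// BroadcastingVmapTransform::logicalToPhysical aligns its inputs for a
// broadcasting op: batch dims are moved to the front, missing batch dims are
// inserted as size-1 dims, and the logical "example" dims are padded with
// leading 1s so that all outputs end up with the same rank.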
TEST(VmapTest, TestBroadcastingVmapTransformBatchedBatched) {
  {
    // Check that batch dims get moved to the front
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({2, B0, 3, B1});
    Tensor y = at::randn({B1, 2, 3, B0});
    Tensor batched_x = makeBatched(x, {{0, 1}, {1, 3}});
    Tensor batched_y = makeBatched(y, {{0, 3}, {1, 0}});

    checkBroadcastingVmapTransform(
        {batched_x, batched_y},
        {x.permute({1, 3, 0, 2}), y.permute({3, 0, 1, 2})});
  }
  {
    // Check that batch dims become aligned (i.e. extra 1 dims get added)
    int64_t B0 = 5, B1 = 7, B2 = 9;
    Tensor x = at::randn({B0, B2, 2, 3});
    Tensor y = at::randn({B0, B1, 2, 3});
    Tensor batched_x = makeBatched(x, {{0, 0}, {2, 1}});
    Tensor batched_y = makeBatched(y, {{0, 0}, {1, 1}});

    checkBroadcastingVmapTransform(
        {batched_x, batched_y},
        {x.unsqueeze(1), y.unsqueeze(2)});
  }
  {
    // Check that the "example" gets padded with extra dims of size 1.
    int64_t B0 = 5;
    Tensor x = at::randn({B0, 3});
    Tensor y = at::randn({B0, 2, 3});
    Tensor batched_x = makeBatched(x, {{0, 0}});
    Tensor batched_y = makeBatched(y, {{0, 0}});

    checkBroadcastingVmapTransform(
        {batched_x, batched_y},
        {x.unsqueeze(1), y});
  }
  {
    // Check batch dims get moved to front, batch dims get aligned,
    // and the example gets padded correctly.
    int64_t B0 = 5, B1 = 7, B2 = 11, B3 = 13;
    Tensor x = at::randn({2, B0, 3, B2});
    Tensor y = at::randn({B3, 3, B1});
    Tensor batched_x = makeBatched(x, {{0, 1}, {2, 3}});
    Tensor batched_y = makeBatched(y, {{1, 2}, {3, 0}});

    checkBroadcastingVmapTransform(
        {batched_x, batched_y},
        {
          x.permute({1, 3, 0, 2}).view({B0, 1, B2, 1, 2, 3}),
          y.permute({2, 0, 1}).view({1, B1, 1, B3, 1, 3}),
        });
  }
  {
    // Edge case: BatchedTensor "scalar" handling
    int64_t B0 = 5, B2 = 11;
    Tensor x = at::randn({B0});
    Tensor y = at::randn({B0, B2});
    Tensor batched_x = makeBatched(x, {{0, 0}});
    Tensor batched_y = makeBatched(y, {{0, 0}, {1, 1}});

    checkBroadcastingVmapTransform({batched_x, batched_y}, {x.view({B0, 1}), y});
    checkBroadcastingVmapTransform({batched_y, batched_x}, {y, x.view({B0, 1})});
  }
  {
    // Edge case: Only one tensor is a "batchedtensor scalar"
    int64_t B0 = 5, B2 = 11;
    Tensor x = at::randn({B0});
    Tensor y = at::randn({B0, B2, 2});
    Tensor batched_x = makeBatched(x, {{0, 0}});
    Tensor batched_y = makeBatched(y, {{0, 0}, {1, 1}});

    checkBroadcastingVmapTransform({batched_x, batched_y}, {x.view({B0, 1, 1}), y});
    checkBroadcastingVmapTransform({batched_y, batched_x}, {y, x.view({B0, 1, 1})});
  }
}

TEST(VmapTest, TestBroadcastingVmapTransformBatchedUnbatched) {
  {
    // Check same example size
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({2, B0, 3, B1});
    Tensor y = at::randn({2, 3});
    Tensor batched_x = makeBatched(x, {{0, 1}, {1, 3}});

    checkBroadcastingVmapTransform(
        {batched_x, y},
        {x.permute({1, 3, 0, 2}), y.view({1, 1, 2, 3})});
    checkBroadcastingVmapTransform(
        {y, batched_x},
        {y.view({1, 1, 2, 3}), x.permute({1, 3, 0, 2})});
  }
  {
    // BatchedTensor has higher example dim than non-batched-tensor
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({B0, B1, 2, 3});
    Tensor y = at::randn({3});
    Tensor batched_x = makeBatched(x, {{0, 0}, {1, 1}});

    checkBroadcastingVmapTransform(
        {batched_x, y}, {x, y.view({1, 1, 1, 3})});
    checkBroadcastingVmapTransform(
        {y, batched_x}, {y.view({1, 1, 1, 3}), x});
  }
  {
    // BatchedTensor has lower example dim than non-batched-tensor
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({B0, B1, 3});
    Tensor y = at::randn({2, 3});
    Tensor batched_x = makeBatched(x, {{0, 0}, {1, 1}});

    checkBroadcastingVmapTransform(
        {batched_x, y}, {x.view({B0, B1, 1, 3}), y.view({1, 1, 2, 3})});
    checkBroadcastingVmapTransform(
        {y, batched_x}, {y.view({1, 1, 2, 3}), x.view({B0, B1, 1, 3})});
  }
  {
    // Scalar handling
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({B0, B1});
    Tensor y = at::randn({});
    Tensor batched_x = makeBatched(x, {{0, 0}, {1, 1}});

    checkBroadcastingVmapTransform({batched_x, y}, {x, y.view({1, 1})});
    checkBroadcastingVmapTransform({y, batched_x}, {y.view({1, 1}), x});
  }
}

TEST(VmapTest, TestBroadcastingVmapTransformMaxLevels) {
  {
    // inputs have all 64 levels
    auto x = randn(std::vector<int64_t>(kVmapNumLevels, 1));
    auto y = randn(std::vector<int64_t>(kVmapNumLevels, 1));
    auto batched_x = makeBatched(x, maxBatchDimsAtFront());
    auto batched_y = makeBatched(y, maxBatchDimsAtFront());

    checkBroadcastingVmapTransform({batched_x, batched_y}, {x, y});
  }
  {
    // inputs don't have all 64 levels, but results do.
    int64_t split = 19;
    auto x = randn(std::vector<int64_t>(split, 1));
    auto y = randn(std::vector<int64_t>(kVmapNumLevels - split, 1));

    auto tmp = maxBatchDimsAtFront();
    BatchDims x_bdims(tmp.begin(), tmp.begin() + split);

    // Construct y_bdims.
    int64_t dim = 0;
    auto y_bdims_vector = fmap(
        ArrayRef<BatchDim>(tmp.begin() + split, tmp.end()),
        [&](const BatchDim& bdim) -> BatchDim {
          return { bdim.level(), dim++ };
        });
    BatchDims y_bdims(y_bdims_vector.begin(), y_bdims_vector.end());

    auto batched_x = makeBatched(x, x_bdims);
    auto batched_y = makeBatched(y, y_bdims);

    auto expected_size = std::vector<int64_t>(kVmapNumLevels, 1);
    checkBroadcastingVmapTransform(
        {batched_x, batched_y},
        {x.view(expected_size), y.view(expected_size)});
  }
}

// Basic test for BatchedTensor::mul.
TEST(VmapTest, TestBatchedTensorMul) {
  {
    // batched * batched
    Tensor x = at::randn({2, 3});
    Tensor y = at::randn({2, 3});

    Tensor Bx = addBatchDim(x, /*lvl*/1, /*dim*/0);
    Tensor By = addBatchDim(y, /*lvl*/1, /*dim*/0);
    Tensor Bout = Bx * By;

    const auto& out = maybeGetBatchedImpl(Bout)->value();
    std::vector<int64_t> expected_size = {2, 3};
    ASSERT_EQ(out.sizes(), expected_size);
    ASSERT_TRUE(at::allclose(out, x * y));
  }
  {
    // batched * unbatched
    Tensor x = at::randn({2, 3});
    Tensor y = at::randn({3});

    Tensor Bx = addBatchDim(x, /*lvl*/1, /*dim*/0);
    Tensor Bout = Bx * y;
    const auto& out = maybeGetBatchedImpl(Bout)->value();
    std::vector<int64_t> expected_size = {2, 3};
    ASSERT_EQ(out.sizes(), expected_size);
    ASSERT_TRUE(at::allclose(out, x * y));
  }
  {
    // batched (level 1) * batched (level 2)
    Tensor x = at::randn({2, 3});
    Tensor y = at::randn({5, 3});

    Tensor Bx = addBatchDim(x, /*lvl*/1, /*dim*/0);
    Tensor By = addBatchDim(y, /*lvl*/2, /*dim*/0);
    Tensor Bout = Bx * By;

    // The result is a BatchedTensor carrying both batch dims (levels 1 and 2);
    // its value() holds both batch sizes plus the example dim.
    const auto& out = maybeGetBatchedImpl(Bout)->value();
    std::vector<int64_t> expected_size = {2, 5, 3};
    ASSERT_EQ(out.sizes(), expected_size);
    ASSERT_TRUE(at::allclose(out, x.unsqueeze(1) * y));
  }
  {
    // batched (level 2, 3, 4) * batched (level 3, 1, 2)
    Tensor x = at::randn({3, 5, 7});
    Tensor y = at::randn({5, 2, 3});

    // Each BatchDim is constructed in {level, dim} format.
    Tensor Bx = makeBatched(x, {{2, 0}, {3, 1}, {4, 2}});
    Tensor By = makeBatched(y, {{1, 1}, {2, 2}, {3, 0}});
    Tensor Bout = Bx * By;

    const auto& out = maybeGetBatchedImpl(Bout)->value();

    // The batching rule aligns dimensions in the order of their `level`.
    // It just happens that the sizes were chosen in the same order as the levels.
    std::vector<int64_t> expected_size = {2, 3, 5, 7};
    ASSERT_EQ(out.sizes(), expected_size);
    ASSERT_TRUE(at::allclose(out, x * y.permute({1, 2, 0}).unsqueeze(3)));
  }
}

// test for BatchedTensor::size(int).
TEST(VmapTest, TestBatchedTensorSize) {
  {
    // Single batch dim at front
    Tensor x = at::randn({3, 5, 7});
    Tensor Bx = makeBatched(x, {{0, 0}});

    ASSERT_EQ(Bx.size(0), 5);
    ASSERT_EQ(Bx.size(1), 7);
    ASSERT_EQ(Bx.size(-1), 7);
    ASSERT_EQ(Bx.size(-2), 5);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(Bx.size(2), c10::Error);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(Bx.size(-3), c10::Error);
  }
  {
    // multiple batch dims not at front
    Tensor x = at::randn({2, 3, 5, 7, 11});
    Tensor Bx = makeBatched(x, {{0, 3}, {1, 1}});

    ASSERT_EQ(Bx.size(0), 2);
    ASSERT_EQ(Bx.size(1), 5);
    ASSERT_EQ(Bx.size(2), 11);
    ASSERT_EQ(Bx.size(-1), 11);
    ASSERT_EQ(Bx.size(-2), 5);
    ASSERT_EQ(Bx.size(-3), 2);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(Bx.size(3), c10::Error);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(Bx.size(-4), c10::Error);
  }
}

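// getPhysicalShape prepends the sizes of the view's batch dims to a requested
// logical shape.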
TEST(VmapTest, TestVmapPhysicalViewGetPhysicalShape) {
  {
    VmapPhysicalView physical_view(ones({2, 3, 4, 5, 6}), 1 | 4);
    ASSERT_EQ(physical_view.getPhysicalShape({}), VmapDimVector({2, 3}));
    ASSERT_EQ(physical_view.getPhysicalShape({7}), VmapDimVector({2, 3, 7}));
    ASSERT_EQ(physical_view.getPhysicalShape({7, 11, 13}), VmapDimVector({2, 3, 7, 11, 13}));
    ASSERT_EQ(physical_view.getPhysicalShape({7, 11, 13, 17}), VmapDimVector({2, 3, 7, 11, 13, 17}));
  }
  {
    VmapPhysicalView physical_view(ones({2, 3, 4, 5, 6}), 2);
    ASSERT_EQ(physical_view.getPhysicalShape({}), VmapDimVector({2}));
    ASSERT_EQ(physical_view.getPhysicalShape({7}), VmapDimVector({2, 7}));
  }
}

// Basic test for BatchedTensor::expand
TEST(VmapTest, TestBatchedTensorExpand) {
  {
    // Expand size is too small
    auto tensor = at::randn({2, 3, 5});
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(batched.expand({5}), c10::Error);
  }
  {
    // Expand size has the same dimensionality as the logical dim
    auto tensor = at::randn({2, 1, 5});
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});
    auto batched_out = batched.expand({3, 5});
    const auto& out = maybeGetBatchedImpl(batched_out)->value();

    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.expand({2, 3, 5})));
  }
  {
    // Expand size has the same dimensionality as the logical dim, incorrect expand size
    auto tensor = at::randn({2, 1, 5});
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(batched.expand({1, 25}), c10::Error);
  }
  {
    // Expand size has greater dimensionality than the logical dim
    auto tensor = at::randn({2, 3, 5});
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});
    auto batched_out = batched.expand({7, 3, 5});
    const auto& out = maybeGetBatchedImpl(batched_out)->value();

    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.view({2, 1, 3, 5}).expand({2, 7, 3, 5})));
  }
  {
    // Expand size has greater dimensionality than the logical dim, incorrect expand size
    auto tensor = at::randn({2, 3, 5});
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(batched.expand({7, 9, 5}), c10::Error);
  }
  {
    // Logical dim is 0, expand size has the same dimensionality as the logical dim
    auto tensor = at::randn({2, 3});
    auto batched = makeBatched(tensor, {{0, 0}, {1, 1}});
    auto batched_out = batched.expand(c10::IntArrayRef({}));
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor));
  }
  {
    // Logical dim is 0, expand size has greater dimensionality than the logical dim
    auto tensor = at::randn({2, 3});
    auto batched = makeBatched(tensor, {{0, 0}, {1, 1}});
    auto batched_out = batched.expand({5, 7});
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.view({2, 3, 1, 1}).expand({2, 3, 5, 7})));
  }
}
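
// The view ops below (unsqueeze, squeeze, transpose, permute) shift the
// requested dims past the batch dims and return views of the same storage, as
// the data_ptr() checks verify.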
// Basic test for BatchedTensor::unsqueeze
TEST(VmapTest, TestBatchedTensorUnsqueeze) {
  {
    // Basic test
    auto tensor = at::randn({2, 3, 5});  // NOLINT
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});

    auto batched_out = batched.unsqueeze(0);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.unsqueeze(1)));
  }
  {
    // Test with multiple levels
    auto tensor = at::randn({2, 3, 5});  // NOLINT
    auto batched = makeBatched(tensor, {{0, 0}, {1, 1}});

    auto batched_out = batched.unsqueeze(0);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.unsqueeze(2)));
  }
  {
    // Negative dim
    auto tensor = at::randn({2, 3, 5});  // NOLINT
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});

    auto batched_out = batched.unsqueeze(-1);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.unsqueeze(-1)));
  }
}
// Basic test for BatchedTensor::squeeze(dim)
TEST(VmapTest, TestBatchedTensorSqueeze) {
  {
    // Basic test
    auto tensor = at::randn({2, 1, 5});  // NOLINT
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});

    auto batched_out = batched.squeeze(0);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.squeeze(1)));
  }
  {
    // Test with multiple levels
    auto tensor = at::randn({2, 3, 1});  // NOLINT
    auto batched = makeBatched(tensor, {{0, 0}, {1, 1}});

    auto batched_out = batched.squeeze(0);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.squeeze(2)));
  }
  {
    // Negative dim
    auto tensor = at::randn({2, 3, 1});  // NOLINT
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});

    auto batched_out = batched.squeeze(-1);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.squeeze(-1)));
  }
}
// Basic test for BatchedTensor::transpose
TEST(VmapTest, TestBatchedTensorTranspose) {
  {
    // Basic test
    auto tensor = at::randn({2, 3, 5});  // NOLINT
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});

    auto batched_out = batched.transpose(0, 1);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.transpose(1, 2)));
  }
  {
    // Test with multiple levels
    auto tensor = at::randn({2, 3, 5, 7, 11});  // NOLINT
    auto batched = makeBatched(tensor, {{0, 0}, {1, 1}});

    auto batched_out = batched.transpose(0, 2);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.transpose(2, 4)));
  }
  {
    // Negative dims
    auto tensor = at::randn({2, 3, 5, 7});  // NOLINT
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});

    auto batched_out = batched.mT();
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.mT()));
  }
}

// Basic test for BatchedTensor::permute
TEST(VmapTest, TestBatchedTensorPermute) {
  {
    // Basic test
    auto tensor = at::randn({2, 3, 5});  // NOLINT
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});

    auto batched_out = batched.permute({1, 0});
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.permute({0, 2, 1})));
  }
  {
    // Test with multiple levels
    auto tensor = at::randn({2, 3, 5, 7, 11});  // NOLINT
    auto batched = makeBatched(tensor, {{0, 0}, {1, 1}});

    auto batched_out = batched.permute({2, 1, 0});
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.permute({0, 1, 4, 3, 2})));
  }
  {
    // Negative dims
    auto tensor = at::randn({2, 3, 5, 7});  // NOLINT
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});

    auto batched_out = batched.permute({-1, -2, -3});
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.permute({0, -1, -2, -3})));
  }
}

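// In contrast to the broadcasting transform above, MultiBatchVmapTransform
// also expands each output so that every batch dim is materialized in every
// result (note the .expand(...) calls in the expected outputs below).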
static void checkMultiBatchVmapTransform(TensorList inputs, TensorList expected_outputs) {
  auto outputs = MultiBatchVmapTransform::logicalToPhysical(inputs);
  ASSERT_EQ(outputs.size(), expected_outputs.size());
  for (const auto idx : c10::irange(outputs.size())) {
    const auto& output = outputs[idx].tensor();
    ASSERT_EQ(output.data_ptr(), expected_outputs[idx].data_ptr());
    ASSERT_EQ(output.sizes(), expected_outputs[idx].sizes());
    ASSERT_TRUE(at::allclose(output, expected_outputs[idx]));
  }
}

TEST(VmapTest, TestMultiBatchVmapTransformBatchedBatched) {
  {
    // Check that batch dims get moved to the front
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({2, B0, 3, B1});
    Tensor y = at::randn({B1, 2, 3, B0});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/1}, {/*lvl*/1, /*dim*/3}});
    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/3}, {/*lvl*/1, /*dim*/0}});

    checkMultiBatchVmapTransform(
        {batched_x, batched_y},
        {at::movedim(x, {1, 3}, {0, 1}), at::movedim(y, {0, 3}, {1, 0})});
  }
  {
    // Check that batch dims become broadcasted and are present in all returns
    int64_t B0 = 5, B1 = 7, B2 = 9;
    Tensor x = at::randn({B0, B2, 2, 3});
    Tensor y = at::randn({B0, B1, 2, 3});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}, {/*lvl*/2, /*dim*/1}});
    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/0}, {/*lvl*/1, /*dim*/1}});

    checkMultiBatchVmapTransform(
        {batched_x, batched_y},
        {x.unsqueeze(1).expand({B0, B1, B2, 2, 3}), y.unsqueeze(2).expand({B0, B1, B2, 2, 3})});
  }
  {
    // Check operation on tensors of different logical dims
    int64_t B0 = 5;
    Tensor x = at::randn({B0, 3});
    Tensor y = at::randn({B0, 2, 3});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}});
    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/0}});

    checkMultiBatchVmapTransform({batched_x, batched_y}, {x, y});
  }
  {
    // More complicated example with two tensors.
    int64_t B0 = 5, B1 = 7, B2 = 11, B3 = 13;
    Tensor x = at::randn({2, B0, 3, B2});
    Tensor y = at::randn({B3, 3, B1});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/1}, {/*lvl*/2, /*dim*/3}});
    Tensor batched_y = makeBatched(y, {{/*lvl*/1, /*dim*/2}, {/*lvl*/3, /*dim*/0}});

    checkMultiBatchVmapTransform(
        {batched_x, batched_y},
        {
          x.permute({1, 3, 0, 2}).view({B0, 1, B2, 1, 2, 3}).expand({B0, B1, B2, B3, 2, 3}),
          y.permute({2, 0, 1}).view({1, B1, 1, B3, 3}).expand({B0, B1, B2, B3, 3}),
        });
  }
  {
    // Edge case: BatchedTensor "scalar" handling
    int64_t B0 = 5, B2 = 11;
    Tensor x = at::randn({B0});
    Tensor y = at::randn({B0, B2});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}});
    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/0}, {/*lvl*/1, /*dim*/1}});

    checkMultiBatchVmapTransform({batched_x, batched_y}, {x.view({B0, 1}).expand({B0, B2}), y});
    checkMultiBatchVmapTransform({batched_y, batched_x}, {y, x.view({B0, 1}).expand({B0, B2})});
  }
  {
    // Edge case: Only one tensor is a "batchedtensor scalar"
    int64_t B0 = 5, B2 = 11;
    Tensor x = at::randn({B0});
    Tensor y = at::randn({B0, B2, 2});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}});
    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/0}, {/*lvl*/1, /*dim*/1}});

    checkMultiBatchVmapTransform({batched_x, batched_y}, {x.view({B0, 1}).expand({B0, B2}), y});
    checkMultiBatchVmapTransform({batched_y, batched_x}, {y, x.view({B0, 1}).expand({B0, B2})});
  }
}

TEST(VmapTest, TestMultiBatchVmapTransformBatchedUnbatched) {
  {
    // Check same example size
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({2, B0, 3, B1});
    Tensor y = at::randn({2, 3});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/1}, {/*lvl*/1, /*dim*/3}});

    checkMultiBatchVmapTransform(
        {batched_x, y},
        {at::movedim(x, {1, 3}, {0, 1}), y.view({1, 1, 2, 3}).expand({B0, B1, 2, 3})});
    checkMultiBatchVmapTransform(
        {y, batched_x},
        {y.view({1, 1, 2, 3}).expand({B0, B1, 2, 3}), at::movedim(x, {1, 3}, {0, 1})});
  }
  {
    // BatchedTensor has higher example dim than non-batched-tensor
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({B0, B1, 2, 3});
    Tensor y = at::randn({3});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}, {/*lvl*/1, /*dim*/1}});

    checkMultiBatchVmapTransform(
        {batched_x, y}, {x, y.view({1, 1, 3}).expand({B0, B1, 3})});
    checkMultiBatchVmapTransform(
        {y, batched_x}, {y.view({1, 1, 3}).expand({B0, B1, 3}), x});
  }
  {
    // BatchedTensor has lower example dim than non-batched-tensor
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({B0, B1, 3});
    Tensor y = at::randn({2, 3});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}, {/*lvl*/1, /*dim*/1}});

    checkMultiBatchVmapTransform(
        {batched_x, y}, {x.view({B0, B1, 3}), y.view({1, 1, 2, 3}).expand({B0, B1, 2, 3})});
    checkMultiBatchVmapTransform(
        {y, batched_x}, {y.view({1, 1, 2, 3}).expand({B0, B1, 2, 3}), x.view({B0, B1, 3})});
  }
  {
    // Scalar handling
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({B0, B1});
    Tensor y = at::randn({});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}, {/*lvl*/1, /*dim*/1}});

    checkMultiBatchVmapTransform({batched_x, y}, {x, y.view({1, 1}).expand({B0, B1})});
    checkMultiBatchVmapTransform({y, batched_x}, {y.view({1, 1}).expand({B0, B1}), x});
  }
}

TEST(VmapTest, TestMultiBatchVmapTransformMaxLevels) {
  {
    // inputs have all 64 levels
    auto x = randn(std::vector<int64_t>(kVmapNumLevels, 1));
    auto y = randn(std::vector<int64_t>(kVmapNumLevels, 1));
    auto batched_x = makeBatched(x, maxBatchDimsAtFront());
    auto batched_y = makeBatched(y, maxBatchDimsAtFront());

    checkMultiBatchVmapTransform({batched_x, batched_y}, {x, y});
  }
  {
    // inputs don't have all 64 levels, but results do.
    int64_t split = 19;
    auto x = randn(std::vector<int64_t>(split, 1));
    auto y = randn(std::vector<int64_t>(kVmapNumLevels - split, 1));

    auto tmp = maxBatchDimsAtFront();
    BatchDims x_bdims(tmp.begin(), tmp.begin() + split);

    // Construct y_bdims.
    int64_t dim = 0;
    auto y_bdims_vector = fmap(
        ArrayRef<BatchDim>(tmp.begin() + split, tmp.end()),
        [&](const BatchDim& bdim) -> BatchDim {
          return { bdim.level(), dim++ };
        });
    BatchDims y_bdims(y_bdims_vector.begin(), y_bdims_vector.end());

    auto batched_x = makeBatched(x, x_bdims);
    auto batched_y = makeBatched(y, y_bdims);

    auto expected_size = std::vector<int64_t>(kVmapNumLevels, 1);
    checkMultiBatchVmapTransform(
        {batched_x, batched_y},
        {x.view(expected_size), y.view(expected_size)});
  }
}

TEST(VmapTest, TestMultiBatchVmapTransformMultipleTensors) {
  // Test with three (all batched) tensors
  {
    int64_t B0 = 5, B1 = 7, B2 = 9;
    Tensor x = at::randn({2, B0, 3, B1});
    Tensor y = at::randn({B1, 4});
    Tensor z = at::randn({2, B2});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/1}, {/*lvl*/1, /*dim*/3}});
    Tensor batched_y = makeBatched(y, {{/*lvl*/1, /*dim*/0}});
    Tensor batched_z = makeBatched(z, {{/*lvl*/2, /*dim*/1}});

    checkMultiBatchVmapTransform(
        {batched_x, batched_y, batched_z},
        {
          at::movedim(x, {1, 3}, {0, 1}).view({B0, B1, 1, 2, 3}).expand({B0, B1, B2, 2, 3}),
          y.view({1, B1, 1, 4}).expand({B0, B1, B2, 4}),
          z.t().view({1, 1, B2, 2}).expand({B0, B1, B2, 2}),
        });
  }
  // Test with three tensors, some batched, some unbatched
  {
    int64_t B0 = 5, B1 = 7, B2 = 9;
    Tensor x = at::randn({2, 3});
    Tensor y = at::randn({4, B0});
    Tensor z = at::randn({B1, 2, B2});
    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/1}});
    Tensor batched_z = makeBatched(z, {{/*lvl*/1, /*dim*/0}, {/*lvl*/2, /*dim*/2}});

    checkMultiBatchVmapTransform(
        {x, batched_y, batched_z},
        {
          x.view({1, 1, 1, 2, 3}).expand({B0, B1, B2, 2, 3}),
          y.t().view({B0, 1, 1, 4}).expand({B0, B1, B2, 4}),
          z.permute({0, 2, 1}).view({1, B1, B2, 2}).expand({B0, B1, B2, 2}),
        });
  }
}


} // namespace