#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/LegacyVmapTransforms.h>
#include <c10/util/irange.h>

using namespace at;

namespace {

TEST(VmapTest, TestBatchedTensor) {
  {
    // NOLINTNEXTLINE(bugprone-argument-comment)
    Tensor x = addBatchDim(ones({2, 3, 4}), /*lvl=*/1, /*dim=*/1);
    std::vector<int64_t> expected_size = {2, 4};
    ASSERT_EQ(x.sizes(), expected_size);
    ASSERT_EQ(x.dim(), 2);
    ASSERT_EQ(x.numel(), 8);
    ASSERT_EQ(x.is_contiguous(), false);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(x.storage(), c10::Error);
    ASSERT_EQ(x.storage_offset(), 0);
  }
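  {
    // Sketch, not part of the original test: the logical view hides the batch
    // dim, but the wrapped physical tensor is untouched and can be recovered
    // through maybeGetBatchedImpl.
    // NOLINTNEXTLINE(bugprone-argument-comment)
    Tensor x = addBatchDim(ones({2, 3, 4}), /*lvl=*/1, /*dim=*/1);
    const auto& physical = maybeGetBatchedImpl(x)->value();
    std::vector<int64_t> physical_size = {2, 3, 4};
    ASSERT_EQ(physical.sizes(), physical_size);
  }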
  {
    // Test multiple batch dims
    // NOLINTNEXTLINE(bugprone-argument-comment)
    Tensor x = addBatchDim(ones({2, 3, 4}), /*lvl=*/1, /*dim=*/1);
    // NOLINTNEXTLINE(bugprone-argument-comment)
    x = addBatchDim(x, /*lvl=*/2, /*dim=*/1);
    std::vector<int64_t> expected_size = {2};
    ASSERT_EQ(x.sizes(), expected_size);
    ASSERT_EQ(x.dim(), 1);
    ASSERT_EQ(x.numel(), 2);
  }
  {
    // Test vmap tensor dimensionality limit

    // Should not throw
    std::vector<int64_t> sizes(kVmapMaxTensorDims, 1);
    // NOLINTNEXTLINE(bugprone-argument-comment)
    Tensor x = addBatchDim(ones(sizes), /*lvl=*/1, /*dim=*/1);

    // Should throw
    std::vector<int64_t> too_many_sizes(kVmapMaxTensorDims + 1, 1);
    auto big_dim_tensor = ones(too_many_sizes);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto,bugprone-argument-comment)
    ASSERT_THROW(addBatchDim(big_dim_tensor, /*lvl=*/1, /*dim=*/1), c10::Error);
  }
  {
    // Create a "scalar" BatchedTensor. Should not crash.
    Tensor tensor = addBatchDim(ones({3}), /*lvl*/1, /*dim*/0);
  }
}

// returns {{lvl=0,dim=0}, {lvl=1,dim=1}, ..., {lvl=kVmapNumLevels-1,dim=kVmapNumLevels-1}}
static BatchDims maxBatchDimsAtFront() {
  BatchDims result;
  for (const auto lvl : c10::irange(kVmapNumLevels)) {
    result.emplace_back(lvl, /*dim=*/lvl);
  }
  return result;
}
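
// Note: kVmapNumLevels is 64 in this legacy implementation (levels are tracked
// as bits of a 64-bit bitset), which is why the "all levels" tests below build
// tensors with 64 size-1 dims and why VmapPhysicalView takes its levels as a
// bitmask.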

TEST(VmapTest, TestBatchedTensorMaxLevel) {
  {
    // Should not throw
    auto tensor = ones({2, 3, 4});
    makeBatched(tensor, {{/*lvl*/kVmapNumLevels - 1, /*dim*/0}});
  }
  {
    auto tensor = ones({2, 3, 4});
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(
        makeBatched(tensor, {{/*lvl*/kVmapNumLevels, /*dim*/0}}),
        c10::Error);
  }
  {
    auto tensor = ones({2, 3, 4});
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(
        makeBatched(tensor, {{/*lvl*/kVmapNumLevels + 5, /*dim*/0}}),
        c10::Error);
  }
  {
    // Create a BatchedTensor with kVmapNumLevels levels.
    // Should not throw
    auto tensor = ones(std::vector<int64_t>(kVmapNumLevels, 1));
    makeBatched(tensor, maxBatchDimsAtFront());
  }
  {
    // Create a BatchedTensor with kVmapNumLevels + 1 levels.
    auto tensor = ones(std::vector<int64_t>(kVmapNumLevels + 1, 1));
    auto batch_dims = maxBatchDimsAtFront();
    batch_dims.emplace_back(/*lvl*/kVmapNumLevels, /*dim*/kVmapNumLevels);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(makeBatched(tensor, batch_dims), c10::Error);
  }
}

TEST(VmapTest, TestBatchedTensorActualDim) {
  {
    // No batch dims
    Tensor tensor = makeBatched(ones({2, 3, 5, 7}), {});
    auto* batched = maybeGetBatchedImpl(tensor);
    ASSERT_EQ(batched->actualDim(0), 0);
    ASSERT_EQ(batched->actualDim(1), 1);
    ASSERT_EQ(batched->actualDim(3), 3);

    // Test wrap around
    ASSERT_EQ(batched->actualDim(-1), 3);
    ASSERT_EQ(batched->actualDim(-4), 0);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(batched->actualDim(-5), c10::Error);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(batched->actualDim(4), c10::Error);

    // Test wrap_dim = false
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(batched->actualDim(-1, /*wrap_dim*/false), c10::Error);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(batched->actualDim(-4, /*wrap_dim*/false), c10::Error);
  }
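  {
    // Sketch, not part of the original test. Assumption: negative logical dims
    // wrap by adding the logical dim count (the usual PyTorch convention), so
    // actualDim(-k) agrees with actualDim(logical_dim - k).
    Tensor tensor = makeBatched(ones({2, 3, 5, 7}), {{/*lvl*/1, /*dim*/2}});
    auto* batched = maybeGetBatchedImpl(tensor);
    // One of the four physical dims is a batch dim, so logical_dim is 3.
    ASSERT_EQ(batched->actualDim(-1), batched->actualDim(2));
    ASSERT_EQ(batched->actualDim(-3), batched->actualDim(0));
  }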
  {
    // Single batch dim at front
    Tensor tensor = makeBatched(ones({2, 3, 5, 7}), {{/*lvl*/1, /*dim*/0}});
    auto* batched = maybeGetBatchedImpl(tensor);
    ASSERT_EQ(batched->actualDim(0), 1);
    ASSERT_EQ(batched->actualDim(2), 3);
    ASSERT_EQ(batched->actualDim(-1), 3);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(batched->actualDim(3), c10::Error);
  }
  {
    // Single batch dim in middle
    Tensor tensor = makeBatched(ones({2, 3, 5, 7}), {{/*lvl*/1, /*dim*/1}});
    auto* batched = maybeGetBatchedImpl(tensor);
    ASSERT_EQ(batched->actualDim(0), 0);
    ASSERT_EQ(batched->actualDim(1), 2);
    ASSERT_EQ(batched->actualDim(2), 3);
  }
  {
    // Single batch dim at end
    Tensor tensor = makeBatched(ones({2, 3, 5, 7}), {{/*lvl*/1, /*dim*/3}});
    auto* batched = maybeGetBatchedImpl(tensor);
    ASSERT_EQ(batched->actualDim(0), 0);
    ASSERT_EQ(batched->actualDim(2), 2);
    ASSERT_EQ(batched->actualDim(-1), 2);
  }
  {
    // Multiple (2) batch dims at front
    Tensor tensor = makeBatched(
        ones({2, 3, 5, 7}),
        {{/*lvl*/1, /*dim*/0}, {/*lvl*/2, /*dim*/1}});
    auto* batched = maybeGetBatchedImpl(tensor);
    ASSERT_EQ(batched->actualDim(0), 2);
    ASSERT_EQ(batched->actualDim(1), 3);
  }
  {
    // Multiple (2) batch dims, misc places
    Tensor tensor = makeBatched(
        ones({2, 3, 5, 7}),
        {{/*lvl*/1, /*dim*/1}, {/*lvl*/2, /*dim*/3}});
    auto* batched = maybeGetBatchedImpl(tensor);
    ASSERT_EQ(batched->actualDim(0), 0);
    ASSERT_EQ(batched->actualDim(1), 2);
    ASSERT_EQ(batched->actualDim(-1), 2);
    ASSERT_EQ(batched->actualDim(-2), 0);
  }
  {
    // actualDim on a kVmapMaxTensorDims-sized underlying tensor
    auto tensor = ones({});
    for (C10_UNUSED const auto i : c10::irange(kVmapMaxTensorDims)) {
      tensor = tensor.unsqueeze(0);
    }
    ASSERT_EQ(tensor.dim(), kVmapMaxTensorDims);

    auto batched = addBatchDim(tensor, /*lvl*/1, /*dim*/0);
    auto* batched_impl = maybeGetBatchedImpl(batched);
    ASSERT_EQ(
        batched_impl->actualDim(kVmapMaxTensorDims - 2),
        kVmapMaxTensorDims - 1);
    ASSERT_EQ(
        batched_impl->actualDim(-1),
        kVmapMaxTensorDims - 1);
  }
}

TEST(VmapTest, TestMultiBatchVmapTransform) {
  {
    // Input is a regular Tensor
    auto tensor = ones({2, 3, 5});
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(MultiBatchVmapTransform::logicalToPhysical(tensor), c10::Error);
  }
  {
    // Input is a BatchedTensor, batch dims are already at the front
    auto tensor = ones({2, 3, 5});
    BatchDims bdims = {{/*lvl*/1, /*dim*/0}, {/*lvl*/3, /*dim*/1}};
    auto batched = makeBatched(tensor, bdims);

    auto result = MultiBatchVmapTransform::logicalToPhysical(batched);
    ASSERT_TRUE(result.tensor().is_same(tensor));
  }
  {
    // Single batch dim, not at front
    auto tensor = ones({2, 3, 5});
    BatchDims bdims = {{/*lvl*/1, /*dim*/1}};
    auto batched = makeBatched(tensor, bdims);

    auto result = MultiBatchVmapTransform::logicalToPhysical(batched);
    ASSERT_EQ(result.tensor().data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(result.tensor(), tensor.permute({1, 0, 2})));
  }
  {
    // Multiple batch dims, not at front
    auto tensor = ones({2, 3, 5});
    BatchDims bdims = {{/*lvl*/1, /*dim*/1}, {/*lvl*/2, /*dim*/2}, {/*lvl*/3, /*dim*/0}};
    auto batched = makeBatched(tensor, bdims);

    auto result = MultiBatchVmapTransform::logicalToPhysical(batched);
    ASSERT_EQ(result.tensor().data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(result.tensor(), tensor.permute({1, 2, 0})));
  }
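  {
    // Sketch, not part of the original test: logicalToPhysical orders the
    // batch dims at the front of the physical view by ascending level, so
    // lvl 1 (size 3), lvl 2 (size 5), lvl 3 (size 2) become the leading sizes.
    auto tensor = ones({2, 3, 5});
    BatchDims bdims = {{/*lvl*/1, /*dim*/1}, {/*lvl*/2, /*dim*/2}, {/*lvl*/3, /*dim*/0}};
    auto batched = makeBatched(tensor, bdims);

    auto result = MultiBatchVmapTransform::logicalToPhysical(batched);
    std::vector<int64_t> expected_size = {3, 5, 2};
    ASSERT_EQ(result.tensor().sizes(), expected_size);
  }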
  {
    // Edge case: kVmapNumLevels levels; batch dims are already at front.

    // sizes = [2, 1, 3, 1, 1, 7, 1, 1, 1, 1, ...]
    auto sizes = std::vector<int64_t>(kVmapNumLevels, 1);
    sizes[0] = 2;
    sizes[2] = 3;
    sizes[5] = 7;

    // bdims = {{lvl=0,dim=0}, {lvl=1,dim=1}, ..., {lvl=63,dim=63}}
    auto batch_dims = maxBatchDimsAtFront();
    auto tensor = ones(sizes);

    auto batched = makeBatched(tensor, batch_dims);
    auto result = MultiBatchVmapTransform::logicalToPhysical(batched);
    ASSERT_TRUE(result.tensor().is_same(tensor));
  }
  {
    // Edge case: kVmapNumLevels levels; batch dims are not at front

    // sizes = [1, 3, 2, 1, 1, 7, 1, 1, 1, 1, ..., 1, 1, 5]
    auto sizes = std::vector<int64_t>(kVmapNumLevels, 1);
    sizes[1] = 3;
    sizes[2] = 2;
    sizes[5] = 7;
    sizes[kVmapNumLevels - 1] = 5;

    // The goal is to permute sizes such that the final sizes are:
    // [2, 3, 5, 7, 1, 1, 1, 1, 1, ...]
    auto expected_result_sizes = std::vector<int64_t>(kVmapNumLevels, 1);
    expected_result_sizes[0] = 2;
    expected_result_sizes[1] = 3;
    expected_result_sizes[2] = 5;
    expected_result_sizes[3] = 7;

    // bdims = {{0, 2}, {1, 1}, {2, 63}, {3, 5}, {4, 0}, {5, 3}, {6, 4},
    //          {7, 6}, {8, 7}, {9, 8}, ..., {63, 62}}
    BatchDims batch_dims = {
      {0, 2}, {1, 1}, {2, kVmapNumLevels - 1}, {3, 5}, {4, 0}, {5, 3}, {6, 4}
    };
    for (const auto level : c10::irange(7, kVmapNumLevels)) {
      batch_dims.emplace_back(level, /*dim=*/level - 1);
    }
    auto tensor = ones(sizes);

    auto batched = makeBatched(tensor, batch_dims);
    auto result = MultiBatchVmapTransform::logicalToPhysical(batched);
    ASSERT_EQ(result.tensor().data_ptr(), tensor.data_ptr());
    ASSERT_EQ(result.tensor().sizes(), expected_result_sizes);
  }
}

TEST(VmapTest, TestVmapPhysicalViewGetPhysicalDim) {
  VmapPhysicalView physical_view(ones({2, 3, 4, 5, 6}), 1 | 4);

  // Positive dims
  ASSERT_EQ(physical_view.getPhysicalDim(0), 2);
  ASSERT_EQ(physical_view.getPhysicalDim(1), 3);
  ASSERT_EQ(physical_view.getPhysicalDim(2), 4);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_THROW(physical_view.getPhysicalDim(3), c10::Error);

  // Negative dims (testing wrap dim behavior)
  ASSERT_EQ(physical_view.getPhysicalDim(-1), 4);
  ASSERT_EQ(physical_view.getPhysicalDim(-2), 3);
  ASSERT_EQ(physical_view.getPhysicalDim(-3), 2);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_THROW(physical_view.getPhysicalDim(-4), c10::Error);
}
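
// Note on the second constructor argument above: it is a bitmask over vmap
// levels (bit i <-> level i), so 1 | 4 selects levels {0, 2}. Those two levels
// occupy the front of the physical tensor, which is why logical dim i maps to
// physical dim i + 2 in the assertions above.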

TEST(VmapTest, TestVmapPhysicalViewGetPhysicalDims) {
  VmapPhysicalView physical_view(ones({2, 3, 4, 5, 6}), 2 | 8 | 16);

  ASSERT_EQ(
      physical_view.getPhysicalDims({0, 1, -1, -2}),
      VmapDimVector({3, 4, 4, 3}));

  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_THROW(physical_view.getPhysicalDims({2, 0}), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_THROW(physical_view.getPhysicalDims({0, -3}), c10::Error);
}

static void checkBatchDimsEqual(BatchDimsRef bdims, BatchDimsRef expected_bdims) {
  ASSERT_EQ(bdims.size(), expected_bdims.size());
  for (const auto idx : c10::irange(bdims.size())) {
    ASSERT_EQ(bdims[idx].dim(), expected_bdims[idx].dim());
    ASSERT_EQ(bdims[idx].level(), expected_bdims[idx].level());
  }
}

TEST(VmapTest, TestVmapPhysicalViewNewLogicalFromPhysical) {
  {
    // Simple case: single level
    VmapPhysicalView physical_view(ones({2, 3, 4}), /*levels = {2}*/4);
    Tensor physical = ones({2, 6, 7});

    auto result = physical_view.getPhysicalToLogicalMap().apply(physical);
    auto* batched = maybeGetBatchedImpl(result);
    ASSERT_TRUE(batched != nullptr);
    ASSERT_TRUE(batched->value().is_same(physical));
    checkBatchDimsEqual(batched->bdims(), {{2, 0}});
  }
  {
    // Multiple levels
    VmapPhysicalView physical_view(ones({2, 3, 4, 5, 6}), /*levels = {1, 3, 4}*/2 | 8 | 16);
    Tensor physical = ones({2, 3, 4, 7});

    auto result = physical_view.getPhysicalToLogicalMap().apply(physical);
    auto* batched = maybeGetBatchedImpl(result);
    ASSERT_TRUE(batched != nullptr);
    ASSERT_TRUE(batched->value().is_same(physical));
    checkBatchDimsEqual(batched->bdims(), {{1, 0}, {3, 1}, {4, 2}});
  }
  {
    // The logical shape is [].
    VmapPhysicalView physical_view(ones({2}), /*levels = {2}*/4);
    Tensor physical = ones({2});

    auto result = physical_view.getPhysicalToLogicalMap().apply(physical);
    auto* batched = maybeGetBatchedImpl(result);
    ASSERT_TRUE(batched != nullptr);
    ASSERT_TRUE(batched->value().is_same(physical));
    checkBatchDimsEqual(batched->bdims(), {{2, 0}});
  }
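  {
    // Sketch, not part of the original test: the map creates one batch dim per
    // level of the view, in ascending level order, starting at the front of
    // the wrapped physical tensor.
    VmapPhysicalView physical_view(ones({2, 3, 4, 5}), /*levels = {0, 5}*/1 | 32);
    Tensor physical = ones({2, 3, 11});

    auto result = physical_view.getPhysicalToLogicalMap().apply(physical);
    auto* batched = maybeGetBatchedImpl(result);
    ASSERT_TRUE(batched != nullptr);
    ASSERT_TRUE(batched->value().is_same(physical));
    checkBatchDimsEqual(batched->bdims(), {{0, 0}, {5, 1}});
  }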
}

// Basic test for BatchedTensor::sum.
// NB: We don't need to write tests in C++ for batching rules if we can test them
// in Python via the vmap API. These are here to bootstrap that process.
TEST(VmapTest, TestBatchedTensorSum) {
  {
    // Simple: single batch dim, single reduce dim
    Tensor x = at::randn({2, 3, 5, 7});

    Tensor batched_x = makeBatched(x, {{/*lvl*/1, /*dim*/0}});
    Tensor batched_out = batched_x.sum(0);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();

    ASSERT_TRUE(at::allclose(out, x.sum(1)));
  }
  {
    // Single batch dim, -1 reduce dim handling
    Tensor x = at::randn({2, 3});

    Tensor batched_x = makeBatched(x, {{/*lvl*/1, /*dim*/1}});
    Tensor batched_out = batched_x.sum(-1);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();

    ASSERT_TRUE(at::allclose(out, x.sum(0)));
  }
  {
    // Single batch dim, multiple reduce dims
    Tensor x = at::randn({2, 3, 5, 7});

    Tensor batched_x = makeBatched(x, {{/*lvl*/1, /*dim*/1}});
    Tensor batched_out = batched_x.sum(std::vector<int64_t>{0, 1});
    const auto& out = maybeGetBatchedImpl(batched_out)->value();

    ASSERT_TRUE(at::allclose(out, x.sum(std::vector<int64_t>{0, 2})));
  }
  {
    // Multiple batch dims, multiple reduce dims
    Tensor x = at::randn({2, 3, 5, 7});

    Tensor batched_x = makeBatched(x, {{/*lvl*/1, /*dim*/0}, {/*lvl*/2, /*dim*/1}});
    Tensor batched_out = batched_x.sum(std::vector<int64_t>{0, 1});
    const auto& out = maybeGetBatchedImpl(batched_out)->value();

    ASSERT_TRUE(at::allclose(out, x.sum(std::vector<int64_t>{2, 3})));
  }
}

static void checkBroadcastingVmapTransform(TensorList inputs, TensorList expected_outputs) {
  auto outputs = BroadcastingVmapTransform::logicalToPhysical(inputs);
  ASSERT_EQ(outputs.size(), expected_outputs.size());
  for (const auto idx : c10::irange(outputs.size())) {
    const auto& output = outputs[idx].tensor();
    ASSERT_EQ(output.data_ptr(), expected_outputs[idx].data_ptr());
    ASSERT_TRUE(at::allclose(output, expected_outputs[idx]));
  }
}

TEST(VmapTest, TestBroadcastingVmapTransformBatchedBatched) {
  {
    // Check that batch dims get moved to the front
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({2, B0, 3, B1});
    Tensor y = at::randn({B1, 2, 3, B0});
    Tensor batched_x = makeBatched(x, {{0, 1}, {1, 3}});
    Tensor batched_y = makeBatched(y, {{0, 3}, {1, 0}});

    checkBroadcastingVmapTransform(
        {batched_x, batched_y},
        {x.permute({1, 3, 0, 2}), y.permute({3, 0, 1, 2})});
  }
  {
    // Check that batch dims become aligned (i.e. extra 1 dims get added)
    int64_t B0 = 5, B1 = 7, B2 = 9;
    Tensor x = at::randn({B0, B2, 2, 3});
    Tensor y = at::randn({B0, B1, 2, 3});
    Tensor batched_x = makeBatched(x, {{0, 0}, {2, 1}});
    Tensor batched_y = makeBatched(y, {{0, 0}, {1, 1}});

    checkBroadcastingVmapTransform(
        {batched_x, batched_y},
        {x.unsqueeze(1), y.unsqueeze(2)});
  }
  {
    // Check that the "example" gets padded with extra dims of size 1.
    int64_t B0 = 5;
    Tensor x = at::randn({B0, 3});
    Tensor y = at::randn({B0, 2, 3});
    Tensor batched_x = makeBatched(x, {{0, 0}});
    Tensor batched_y = makeBatched(y, {{0, 0}});

    checkBroadcastingVmapTransform(
        {batched_x, batched_y},
        {x.unsqueeze(1), y});
  }
  {
    // Check batch dims get moved to front, batch dims get aligned,
    // and the example gets padded correctly.
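    // Concretely (a sketch of the expected layout, matching the reference
    // expressions below): the four batch levels occupy the four leading dims,
    // with size-1 slots where a level is absent from an input, and each
    // example is right-aligned and padded with 1s. Note the transform only
    // aligns; it does not expand.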
    int64_t B0 = 5, B1 = 7, B2 = 11, B3 = 13;
    Tensor x = at::randn({2, B0, 3, B2});
    Tensor y = at::randn({B3, 3, B1});
    Tensor batched_x = makeBatched(x, {{0, 1}, {2, 3}});
    Tensor batched_y = makeBatched(y, {{1, 2}, {3, 0}});

    checkBroadcastingVmapTransform(
        {batched_x, batched_y},
        {
          x.permute({1, 3, 0, 2}).view({B0, 1, B2, 1, 2, 3}),
          y.permute({2, 0, 1}).view({1, B1, 1, B3, 1, 3}),
        });
  }
  {
    // Edge case: BatchedTensor "scalar" handling
    int64_t B0 = 5, B2 = 11;
    Tensor x = at::randn({B0});
    Tensor y = at::randn({B0, B2});
    Tensor batched_x = makeBatched(x, {{0, 0}});
    Tensor batched_y = makeBatched(y, {{0, 0}, {1, 1}});

    checkBroadcastingVmapTransform({batched_x, batched_y}, {x.view({B0, 1}), y});
    checkBroadcastingVmapTransform({batched_y, batched_x}, {y, x.view({B0, 1})});
  }
  {
    // Edge case: only one tensor is a BatchedTensor "scalar"
    int64_t B0 = 5, B2 = 11;
    Tensor x = at::randn({B0});
    Tensor y = at::randn({B0, B2, 2});
    Tensor batched_x = makeBatched(x, {{0, 0}});
    Tensor batched_y = makeBatched(y, {{0, 0}, {1, 1}});

    checkBroadcastingVmapTransform({batched_x, batched_y}, {x.view({B0, 1, 1}), y});
    checkBroadcastingVmapTransform({batched_y, batched_x}, {y, x.view({B0, 1, 1})});
  }
}

TEST(VmapTest, TestBroadcastingVmapTransformBatchedUnbatched) {
  {
    // Check same example size
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({2, B0, 3, B1});
    Tensor y = at::randn({2, 3});
    Tensor batched_x = makeBatched(x, {{0, 1}, {1, 3}});

    checkBroadcastingVmapTransform(
        {batched_x, y},
        {x.permute({1, 3, 0, 2}), y.view({1, 1, 2, 3})});
    checkBroadcastingVmapTransform(
        {y, batched_x},
        {y.view({1, 1, 2, 3}), x.permute({1, 3, 0, 2})});
  }
  {
    // BatchedTensor has higher example dim than non-batched-tensor
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({B0, B1, 2, 3});
    Tensor y = at::randn({3});
    Tensor batched_x = makeBatched(x, {{0, 0}, {1, 1}});

    checkBroadcastingVmapTransform(
        {batched_x, y}, {x, y.view({1, 1, 1, 3})});
    checkBroadcastingVmapTransform(
        {y, batched_x}, {y.view({1, 1, 1, 3}), x});
  }
  {
    // BatchedTensor has lower example dim than non-batched-tensor
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({B0, B1, 3});
    Tensor y = at::randn({2, 3});
    Tensor batched_x = makeBatched(x, {{0, 0}, {1, 1}});

    checkBroadcastingVmapTransform(
        {batched_x, y}, {x.view({B0, B1, 1, 3}), y.view({1, 1, 2, 3})});
    checkBroadcastingVmapTransform(
        {y, batched_x}, {y.view({1, 1, 2, 3}), x.view({B0, B1, 1, 3})});
  }
  {
    // Scalar handling
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({B0, B1});
    Tensor y = at::randn({});
    Tensor batched_x = makeBatched(x, {{0, 0}, {1, 1}});

    checkBroadcastingVmapTransform({batched_x, y}, {x, y.view({1, 1})});
    checkBroadcastingVmapTransform({y, batched_x}, {y.view({1, 1}), x});
  }
}

TEST(VmapTest, TestBroadcastingVmapTransformMaxLevels) {
  {
    // Inputs have all 64 levels
    auto x = randn(std::vector<int64_t>(kVmapNumLevels, 1));
    auto y = randn(std::vector<int64_t>(kVmapNumLevels, 1));
    auto batched_x = makeBatched(x, maxBatchDimsAtFront());
    auto batched_y = makeBatched(y, maxBatchDimsAtFront());

    checkBroadcastingVmapTransform({batched_x, batched_y}, {x, y});
  }
  {
    // Inputs don't have all 64 levels, but the results do.
    int64_t split = 19;
    auto x = randn(std::vector<int64_t>(split, 1));
    auto y = randn(std::vector<int64_t>(kVmapNumLevels - split, 1));

    auto tmp = maxBatchDimsAtFront();
    BatchDims x_bdims(tmp.begin(), tmp.begin() + split);

    // Construct y_bdims.
    int64_t dim = 0;
    auto y_bdims_vector = fmap(
        ArrayRef<BatchDim>(tmp.begin() + split, tmp.end()),
        [&](const BatchDim& bdim) -> BatchDim {
          return { bdim.level(), dim++ };
        });
    BatchDims y_bdims(y_bdims_vector.begin(), y_bdims_vector.end());

    auto batched_x = makeBatched(x, x_bdims);
    auto batched_y = makeBatched(y, y_bdims);

    auto expected_size = std::vector<int64_t>(kVmapNumLevels, 1);
    checkBroadcastingVmapTransform(
        {batched_x, batched_y},
        {x.view(expected_size), y.view(expected_size)});
  }
}

// Basic test for BatchedTensor::mul.
TEST(VmapTest, TestBatchedTensorMul) {
  {
    // batched * batched
    Tensor x = at::randn({2, 3});
    Tensor y = at::randn({2, 3});

    Tensor Bx = addBatchDim(x, /*lvl*/1, /*dim*/0);
    Tensor By = addBatchDim(y, /*lvl*/1, /*dim*/0);
    Tensor Bout = Bx * By;

    const auto& out = maybeGetBatchedImpl(Bout)->value();
    std::vector<int64_t> expected_size = {2, 3};
    ASSERT_EQ(out.sizes(), expected_size);
    ASSERT_TRUE(at::allclose(out, x * y));
  }
  {
    // batched * unbatched
    Tensor x = at::randn({2, 3});
    Tensor y = at::randn({3});

    Tensor Bx = addBatchDim(x, /*lvl*/1, /*dim*/0);
    Tensor Bout = Bx * y;
    const auto& out = maybeGetBatchedImpl(Bout)->value();
    std::vector<int64_t> expected_size = {2, 3};
    ASSERT_EQ(out.sizes(), expected_size);
    ASSERT_TRUE(at::allclose(out, x * y));
  }
  {
    // batched (level 1) * batched (level 2)
    Tensor x = at::randn({2, 3});
    Tensor y = at::randn({5, 3});

    Tensor Bx = addBatchDim(x, /*lvl*/1, /*dim*/0);
    Tensor By = addBatchDim(y, /*lvl*/2, /*dim*/0);
    Tensor Bout = Bx * By;

    // The result is a BatchedTensor carrying both batch levels; unwrapping it
    // gives the full physical result.
    const auto& out = maybeGetBatchedImpl(Bout)->value();
    std::vector<int64_t> expected_size = {2, 5, 3};
    ASSERT_EQ(out.sizes(), expected_size);
    ASSERT_TRUE(at::allclose(out, x.unsqueeze(1) * y));
  }
  {
    // batched (level 2, 3, 4) * batched (level 3, 1, 2)
    Tensor x = at::randn({3, 5, 7});
    Tensor y = at::randn({5, 2, 3});

    // Each BatchDim is constructed in {level, dim} format.
    Tensor Bx = makeBatched(x, {{2, 0}, {3, 1}, {4, 2}});
    Tensor By = makeBatched(y, {{1, 1}, {2, 2}, {3, 0}});
    Tensor Bout = Bx * By;

    const auto& out = maybeGetBatchedImpl(Bout)->value();

    // The batching rule aligns dimensions in the order of their `level`.
    // It just happens that we chose sizes to be in the same order as the levels.
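    // Concretely (a sketch of the reasoning, not an extra check): level 1 has
    // size 2, level 2 size 3, level 3 size 5, level 4 size 7, so the physical
    // result is {2, 3, 5, 7}. x lacks level 1 and y lacks level 4, which is
    // why the reference expression below permutes y into level order and
    // unsqueezes a slot for level 4, letting broadcasting handle level 1 of x.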
    std::vector<int64_t> expected_size = {2, 3, 5, 7};
    ASSERT_EQ(out.sizes(), expected_size);
    ASSERT_TRUE(at::allclose(out, x * y.permute({1, 2, 0}).unsqueeze(3)));
  }
}

// Basic test for BatchedTensor::size(int).
TEST(VmapTest, TestBatchedTensorSize) {
  {
    // Single batch dim at front
    Tensor x = at::randn({3, 5, 7});
    Tensor Bx = makeBatched(x, {{0, 0}});

    ASSERT_EQ(Bx.size(0), 5);
    ASSERT_EQ(Bx.size(1), 7);
    ASSERT_EQ(Bx.size(-1), 7);
    ASSERT_EQ(Bx.size(-2), 5);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(Bx.size(2), c10::Error);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(Bx.size(-3), c10::Error);
  }
  {
    // Multiple batch dims, not at front
    Tensor x = at::randn({2, 3, 5, 7, 11});
    Tensor Bx = makeBatched(x, {{0, 3}, {1, 1}});
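    // Physical dims 1 and 3 are batch dims here, so the logical sizes are the
    // remaining physical sizes {2, 5, 11}, in order.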

    ASSERT_EQ(Bx.size(0), 2);
    ASSERT_EQ(Bx.size(1), 5);
    ASSERT_EQ(Bx.size(2), 11);
    ASSERT_EQ(Bx.size(-1), 11);
    ASSERT_EQ(Bx.size(-2), 5);
    ASSERT_EQ(Bx.size(-3), 2);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(Bx.size(3), c10::Error);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(Bx.size(-4), c10::Error);
  }
}

TEST(VmapTest, TestVmapPhysicalViewGetPhysicalShape) {
  {
    VmapPhysicalView physical_view(ones({2, 3, 4, 5, 6}), 1 | 4);
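    // The view has levels {0, 2} (bitmask 1 | 4); getPhysicalShape prepends one
    // batch size per level (here the leading physical sizes 2 and 3) to the
    // requested logical shape.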
    ASSERT_EQ(physical_view.getPhysicalShape({}), VmapDimVector({2, 3}));
    ASSERT_EQ(physical_view.getPhysicalShape({7}), VmapDimVector({2, 3, 7}));
    ASSERT_EQ(physical_view.getPhysicalShape({7, 11, 13}), VmapDimVector({2, 3, 7, 11, 13}));
    ASSERT_EQ(physical_view.getPhysicalShape({7, 11, 13, 17}), VmapDimVector({2, 3, 7, 11, 13, 17}));
  }
  {
    VmapPhysicalView physical_view(ones({2, 3, 4, 5, 6}), 2);
    ASSERT_EQ(physical_view.getPhysicalShape({}), VmapDimVector({2}));
    ASSERT_EQ(physical_view.getPhysicalShape({7}), VmapDimVector({2, 7}));
  }
}

// Basic test for BatchedTensor::expand
TEST(VmapTest, TestBatchedTensorExpand) {
  {
    // Expand size is too small
    auto tensor = at::randn({2, 3, 5});
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(batched.expand({5}), c10::Error);
  }
  {
    // Expand size has same dimensionality as the logical dim
    auto tensor = at::randn({2, 1, 5});
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});
    auto batched_out = batched.expand({3, 5});
    const auto& out = maybeGetBatchedImpl(batched_out)->value();

    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.expand({2, 3, 5})));
  }
  {
    // Expand size has same dimensionality as the logical dim, incorrect expand size
    auto tensor = at::randn({2, 1, 5});
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(batched.expand({1, 25}), c10::Error);
  }
  {
    // Expand size has greater dimensionality than the logical dim
    auto tensor = at::randn({2, 3, 5});
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});
    auto batched_out = batched.expand({7, 3, 5});
    const auto& out = maybeGetBatchedImpl(batched_out)->value();

    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.view({2, 1, 3, 5}).expand({2, 7, 3, 5})));
  }
  {
    // Expand size has greater dimensionality than the logical dim, incorrect expand size
    auto tensor = at::randn({2, 3, 5});
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(batched.expand({7, 9, 5}), c10::Error);
  }
  {
    // Logical dim is 0, expand size has same dimensionality as the logical dim
    auto tensor = at::randn({2, 3});
    auto batched = makeBatched(tensor, {{0, 0}, {1, 1}});
    auto batched_out = batched.expand(c10::IntArrayRef({}));
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor));
  }
  {
    // Logical dim is 0, expand size has greater dimensionality than the logical dim
    auto tensor = at::randn({2, 3});
    auto batched = makeBatched(tensor, {{0, 0}, {1, 1}});
    auto batched_out = batched.expand({5, 7});
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.view({2, 3, 1, 1}).expand({2, 3, 5, 7})));
  }
}

// Basic test for BatchedTensor::unsqueeze
TEST(VmapTest, TestBatchedTensorUnsqueeze) {
  {
    // Basic test
    auto tensor = at::randn({2, 3, 5});  // NOLINT
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});

    auto batched_out = batched.unsqueeze(0);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.unsqueeze(1)));
  }
  {
    // Test with multiple levels
    auto tensor = at::randn({2, 3, 5});  // NOLINT
    auto batched = makeBatched(tensor, {{0, 0}, {1, 1}});

    auto batched_out = batched.unsqueeze(0);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.unsqueeze(2)));
  }
  {
    // Negative dim
    auto tensor = at::randn({2, 3, 5});  // NOLINT
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});

    auto batched_out = batched.unsqueeze(-1);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.unsqueeze(-1)));
  }
}

// Basic test for BatchedTensor::squeeze(dim)
TEST(VmapTest, TestBatchedTensorSqueeze) {
  {
    // Basic test
    auto tensor = at::randn({2, 1, 5});  // NOLINT
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});

    auto batched_out = batched.squeeze(0);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.squeeze(1)));
  }
  {
    // Test with multiple levels
    auto tensor = at::randn({2, 3, 1});  // NOLINT
    auto batched = makeBatched(tensor, {{0, 0}, {1, 1}});

    auto batched_out = batched.squeeze(0);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.squeeze(2)));
  }
  {
    // Negative dim
    auto tensor = at::randn({2, 3, 1});  // NOLINT
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});

    auto batched_out = batched.squeeze(-1);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.squeeze(-1)));
  }
}

// Basic test for BatchedTensor::transpose
TEST(VmapTest, TestBatchedTensorTranspose) {
  {
    // Basic test
    auto tensor = at::randn({2, 3, 5});  // NOLINT
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});

    auto batched_out = batched.transpose(0, 1);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.transpose(1, 2)));
  }
  {
    // Test with multiple levels
    auto tensor = at::randn({2, 3, 5, 7, 11});  // NOLINT
    auto batched = makeBatched(tensor, {{0, 0}, {1, 1}});

    auto batched_out = batched.transpose(0, 2);
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.transpose(2, 4)));
  }
  {
    // Negative dims
    auto tensor = at::randn({2, 3, 5, 7});  // NOLINT
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});

    auto batched_out = batched.mT();
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.mT()));
  }
}

// Basic test for BatchedTensor::permute
TEST(VmapTest, TestBatchedTensorPermute) {
  {
    // Basic test
    auto tensor = at::randn({2, 3, 5});  // NOLINT
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});

    auto batched_out = batched.permute({1, 0});
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.permute({0, 2, 1})));
  }
  {
    // Test with multiple levels
    auto tensor = at::randn({2, 3, 5, 7, 11});  // NOLINT
    auto batched = makeBatched(tensor, {{0, 0}, {1, 1}});

    auto batched_out = batched.permute({2, 1, 0});
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.permute({0, 1, 4, 3, 2})));
  }
  {
    // Negative dims
    auto tensor = at::randn({2, 3, 5, 7});  // NOLINT
    auto batched = makeBatched(tensor, {{/*lvl*/0, /*dim*/0}});

    auto batched_out = batched.permute({-1, -2, -3});
    const auto& out = maybeGetBatchedImpl(batched_out)->value();
    ASSERT_EQ(out.data_ptr(), tensor.data_ptr());
    ASSERT_TRUE(at::allclose(out, tensor.permute({0, -1, -2, -3})));
  }
}

static void checkMultiBatchVmapTransform(TensorList inputs, TensorList expected_outputs) {
  auto outputs = MultiBatchVmapTransform::logicalToPhysical(inputs);
  ASSERT_EQ(outputs.size(), expected_outputs.size());
  for (const auto idx : c10::irange(outputs.size())) {
    const auto& output = outputs[idx].tensor();
    ASSERT_EQ(output.data_ptr(), expected_outputs[idx].data_ptr());
    ASSERT_EQ(output.sizes(), expected_outputs[idx].sizes());
    ASSERT_TRUE(at::allclose(output, expected_outputs[idx]));
  }
}
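
// Unlike BroadcastingVmapTransform above, MultiBatchVmapTransform expands the
// batch dims of every input to the full per-level batch shape (note the
// .expand(...) calls in the expected outputs below) while leaving the example
// dims unaligned.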
TEST(VmapTest, TestMultiBatchVmapTransformBatchedBatched) {
  {
    // Check that batch dims get moved to the front
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({2, B0, 3, B1});
    Tensor y = at::randn({B1, 2, 3, B0});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/1}, {/*lvl*/1, /*dim*/3}});
    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/3}, {/*lvl*/1, /*dim*/0}});

    checkMultiBatchVmapTransform(
        {batched_x, batched_y},
        {at::movedim(x, {1, 3}, {0, 1}), at::movedim(y, {0, 3}, {1, 0})});
  }
  {
    // Check that batch dims get broadcast and are present in all returns
    int64_t B0 = 5, B1 = 7, B2 = 9;
    Tensor x = at::randn({B0, B2, 2, 3});
    Tensor y = at::randn({B0, B1, 2, 3});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}, {/*lvl*/2, /*dim*/1}});
    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/0}, {/*lvl*/1, /*dim*/1}});

    checkMultiBatchVmapTransform(
        {batched_x, batched_y},
        {x.unsqueeze(1).expand({B0, B1, B2, 2, 3}), y.unsqueeze(2).expand({B0, B1, B2, 2, 3})});
  }
  {
    // Check operation on tensors of different logical dims
    int64_t B0 = 5;
    Tensor x = at::randn({B0, 3});
    Tensor y = at::randn({B0, 2, 3});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}});
    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/0}});

    checkMultiBatchVmapTransform({batched_x, batched_y}, {x, y});
  }
  {
    // More complicated example with two tensors.
    int64_t B0 = 5, B1 = 7, B2 = 11, B3 = 13;
    Tensor x = at::randn({2, B0, 3, B2});
    Tensor y = at::randn({B3, 3, B1});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/1}, {/*lvl*/2, /*dim*/3}});
    Tensor batched_y = makeBatched(y, {{/*lvl*/1, /*dim*/2}, {/*lvl*/3, /*dim*/0}});

    checkMultiBatchVmapTransform(
        {batched_x, batched_y},
        {
          x.permute({1, 3, 0, 2}).view({B0, 1, B2, 1, 2, 3}).expand({B0, B1, B2, B3, 2, 3}),
          y.permute({2, 0, 1}).view({1, B1, 1, B3, 3}).expand({B0, B1, B2, B3, 3}),
        });
  }
  {
    // Edge case: BatchedTensor "scalar" handling
    int64_t B0 = 5, B2 = 11;
    Tensor x = at::randn({B0});
    Tensor y = at::randn({B0, B2});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}});
    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/0}, {/*lvl*/1, /*dim*/1}});

    checkMultiBatchVmapTransform({batched_x, batched_y}, {x.view({B0, 1}).expand({B0, B2}), y});
    checkMultiBatchVmapTransform({batched_y, batched_x}, {y, x.view({B0, 1}).expand({B0, B2})});
  }
  {
    // Edge case: only one tensor is a BatchedTensor "scalar"
    int64_t B0 = 5, B2 = 11;
    Tensor x = at::randn({B0});
    Tensor y = at::randn({B0, B2, 2});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}});
    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/0}, {/*lvl*/1, /*dim*/1}});

    checkMultiBatchVmapTransform({batched_x, batched_y}, {x.view({B0, 1}).expand({B0, B2}), y});
    checkMultiBatchVmapTransform({batched_y, batched_x}, {y, x.view({B0, 1}).expand({B0, B2})});
  }
}

TEST(VmapTest, TestMultiBatchVmapTransformBatchedUnbatched) {
  {
    // Check same example size
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({2, B0, 3, B1});
    Tensor y = at::randn({2, 3});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/1}, {/*lvl*/1, /*dim*/3}});

    checkMultiBatchVmapTransform(
        {batched_x, y},
        {at::movedim(x, {1, 3}, {0, 1}), y.view({1, 1, 2, 3}).expand({B0, B1, 2, 3})});
    checkMultiBatchVmapTransform(
        {y, batched_x},
        {y.view({1, 1, 2, 3}).expand({B0, B1, 2, 3}), at::movedim(x, {1, 3}, {0, 1})});
  }
  {
    // BatchedTensor has higher example dim than non-batched-tensor
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({B0, B1, 2, 3});
    Tensor y = at::randn({3});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}, {/*lvl*/1, /*dim*/1}});

    checkMultiBatchVmapTransform(
        {batched_x, y}, {x, y.view({1, 1, 3}).expand({B0, B1, 3})});
    checkMultiBatchVmapTransform(
        {y, batched_x}, {y.view({1, 1, 3}).expand({B0, B1, 3}), x});
  }
  {
    // BatchedTensor has lower example dim than non-batched-tensor
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({B0, B1, 3});
    Tensor y = at::randn({2, 3});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}, {/*lvl*/1, /*dim*/1}});

    checkMultiBatchVmapTransform(
        {batched_x, y}, {x.view({B0, B1, 3}), y.view({1, 1, 2, 3}).expand({B0, B1, 2, 3})});
    checkMultiBatchVmapTransform(
        {y, batched_x}, {y.view({1, 1, 2, 3}).expand({B0, B1, 2, 3}), x.view({B0, B1, 3})});
  }
  {
    // Scalar handling
    int64_t B0 = 5, B1 = 7;
    Tensor x = at::randn({B0, B1});
    Tensor y = at::randn({});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/0}, {/*lvl*/1, /*dim*/1}});

    checkMultiBatchVmapTransform({batched_x, y}, {x, y.view({1, 1}).expand({B0, B1})});
    checkMultiBatchVmapTransform({y, batched_x}, {y.view({1, 1}).expand({B0, B1}), x});
  }
}

TEST(VmapTest, TestMultiBatchVmapTransformMaxLevels) {
  {
    // Inputs have all 64 levels
    auto x = randn(std::vector<int64_t>(kVmapNumLevels, 1));
    auto y = randn(std::vector<int64_t>(kVmapNumLevels, 1));
    auto batched_x = makeBatched(x, maxBatchDimsAtFront());
    auto batched_y = makeBatched(y, maxBatchDimsAtFront());

    checkMultiBatchVmapTransform({batched_x, batched_y}, {x, y});
  }
  {
    // Inputs don't have all 64 levels, but the results do.
    int64_t split = 19;
    auto x = randn(std::vector<int64_t>(split, 1));
    auto y = randn(std::vector<int64_t>(kVmapNumLevels - split, 1));

    auto tmp = maxBatchDimsAtFront();
    BatchDims x_bdims(tmp.begin(), tmp.begin() + split);

    // Construct y_bdims.
    int64_t dim = 0;
    auto y_bdims_vector = fmap(
        ArrayRef<BatchDim>(tmp.begin() + split, tmp.end()),
        [&](const BatchDim& bdim) -> BatchDim {
          return { bdim.level(), dim++ };
        });
    BatchDims y_bdims(y_bdims_vector.begin(), y_bdims_vector.end());

    auto batched_x = makeBatched(x, x_bdims);
    auto batched_y = makeBatched(y, y_bdims);

    auto expected_size = std::vector<int64_t>(kVmapNumLevels, 1);
    checkMultiBatchVmapTransform(
        {batched_x, batched_y},
        {x.view(expected_size), y.view(expected_size)});
  }
}

TEST(VmapTest, TestMultiBatchVmapTransformMultipleTensors) {
  // Test with three (all batched) tensors
  {
    int64_t B0 = 5, B1 = 7, B2 = 9;
    Tensor x = at::randn({2, B0, 3, B1});
    Tensor y = at::randn({B1, 4});
    Tensor z = at::randn({2, B2});
    Tensor batched_x = makeBatched(x, {{/*lvl*/0, /*dim*/1}, {/*lvl*/1, /*dim*/3}});
    Tensor batched_y = makeBatched(y, {{/*lvl*/1, /*dim*/0}});
    Tensor batched_z = makeBatched(z, {{/*lvl*/2, /*dim*/1}});

    checkMultiBatchVmapTransform(
        {batched_x, batched_y, batched_z},
        {
          at::movedim(x, {1, 3}, {0, 1}).view({B0, B1, 1, 2, 3}).expand({B0, B1, B2, 2, 3}),
          y.view({1, B1, 1, 4}).expand({B0, B1, B2, 4}),
          z.t().view({1, 1, B2, 2}).expand({B0, B1, B2, 2}),
        });
  }
  // Test with three tensors, some batched, some unbatched
  {
    int64_t B0 = 5, B1 = 7, B2 = 9;
    Tensor x = at::randn({2, 3});
    Tensor y = at::randn({4, B0});
    Tensor z = at::randn({B1, 2, B2});
    Tensor batched_y = makeBatched(y, {{/*lvl*/0, /*dim*/1}});
    Tensor batched_z = makeBatched(z, {{/*lvl*/1, /*dim*/0}, {/*lvl*/2, /*dim*/2}});

    checkMultiBatchVmapTransform(
        {x, batched_y, batched_z},
        {
          x.view({1, 1, 1, 2, 3}).expand({B0, B1, B2, 2, 3}),
          y.t().view({B0, 1, 1, 4}).expand({B0, B1, B2, 4}),
          z.permute({0, 2, 1}).view({1, B1, B2, 2}).expand({B0, B1, B2, 2}),
        });
  }
}

} // namespace