xref: /aosp_15_r20/external/pytorch/test/cpp/lazy/test_lazy_ops.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/Device.h>
2 #include <c10/core/DeviceType.h>
3 #include <gtest/gtest.h>
4 #include <test/cpp/lazy/test_lazy_ops_util.h>
5 #include <torch/csrc/lazy/core/debug_util.h>
6 #include <torch/csrc/lazy/core/helpers.h>
7 #include <torch/csrc/lazy/core/ir_builder.h>
8 #include <torch/csrc/lazy/core/lazy_graph_executor.h>
9 #include <torch/csrc/lazy/core/metrics.h>
10 #include <torch/csrc/lazy/core/permutation_util.h>
11 #include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
12 #include <torch/csrc/lazy/ts_backend/ts_backend_impl.h>
13 #include <torch/torch.h>
14 #include <iostream>
15 
16 namespace torch {
17 namespace lazy {
18 
19 // Lazy Tensor is disabled in FBCODE until addressing non-virtual methods (e.g.
20 // sizes) in TensorImpl
21 #ifndef FBCODE_CAFFE2
22 
23 namespace {
24 // This registers the torchscript backend, without which lazy device won't work.
25 // FIXME: This registers the backend for the whole test binary. We should
26 // probably do it and undo it in the test fixture below.
init_backend()27 static bool inline init_backend() {
28   torch::lazy::InitTorchScriptBackend();
29   return true;
30 }
31 static const bool backend_initialized = init_backend();
32 
33 } // namespace
34 
35 class LazyTsTest : public ::testing::Test {
36  protected:
37   void SetUp() override;
38 
39   void TearDown() override;
40 
CommonSetup()41   static void CommonSetup() {}
42 
ExpectCounterNotChanged(const std::string & counter_regex,const std::unordered_set<std::string> * ignore_set)43   void ExpectCounterNotChanged(
44       const std::string& counter_regex,
45       const std::unordered_set<std::string>* ignore_set) {}
46 
ExpectCounterChanged(const std::string & counter_regex,const std::unordered_set<std::string> * ignore_set)47   void ExpectCounterChanged(
48       const std::string& counter_regex,
49       const std::unordered_set<std::string>* ignore_set) {}
50 
ResetCounters()51   void ResetCounters() {}
52 
53  private:
MakeEndSnapshot()54   void MakeEndSnapshot() {}
55 };
56 
57 class LazyOpsTestBase : public LazyTsTest {
58  protected:
SetUpTestCase()59   static void SetUpTestCase() {}
60 };
61 
SetUp()62 void LazyTsTest::SetUp() {
63   (void)backend_initialized; // avoid unused parameter warning
64   at::manual_seed(42);
65   torch::lazy::LazyGraphExecutor::Get()->SetRngSeed(
66       torch::lazy::BackendDevice(), 42);
67 }
68 
TearDown()69 void LazyTsTest::TearDown() {}
70 
71 namespace {
72 using torch::lazy::DebugUtil;
73 
74 class LazyOpsTest : public LazyOpsTestBase {};
75 
IsCuda()76 static inline bool IsCuda() {
77   return torch::lazy::getBackend()->EagerFallbackDeviceType() == at::kCUDA;
78 }
79 
DefaultDevice()80 static inline at::DeviceType DefaultDevice() {
81   return torch::lazy::getBackend()->EagerFallbackDeviceType();
82 }
83 
84 } // namespace
85 
// A scalar tensor created on the lazy device is compared with one created on
// the default (eager) device.
TEST_F(LazyOpsTest, TestScalarTensor) {
  const auto eager_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor expected = torch::scalar_tensor(1., eager_opts);
  // The callback's device argument is not used; the tensor is always created
  // on torch::kLazy.
  ForEachDevice([&](const torch::Device& /*device*/) {
    const auto lazy_opts =
        torch::TensorOptions(torch::kFloat).device(torch::kLazy);
    torch::Tensor actual = torch::scalar_tensor(1., lazy_opts);
    AllClose(expected, actual);
  });
}
95 
// clone() on a lazy tensor yields a copy that still matches the eager source
// after the tensor it was cloned from is modified in place.
TEST_F(LazyOpsTest, TestClone) {
  ForEachDevice([&](const torch::Device& device) {
    const auto opts =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    torch::Tensor source = torch::rand({2, 2}, opts);
    torch::Tensor lazy_source = CopyToDevice(source, device);
    torch::Tensor lazy_clone = lazy_source.clone();
    AllClose(source, lazy_clone);
    // Mutate the tensor the clone was taken from, then compare again.
    lazy_source.add_(1.0);
    AllClose(source, lazy_clone);
  });
}
107 
// Copying an eager tensor to the device under test and comparing it with the
// source.
TEST_F(LazyOpsTest, TestTo) {
  ForEachDevice([&](const torch::Device& device) {
    const auto opts =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    torch::Tensor source = torch::rand({2, 2}, opts);
    torch::Tensor lazy_source = CopyToDevice(source, device);
    AllClose(source, lazy_source);
  });
}
116 
// is_floating_point() must report the same answer for the eager tensor and
// its copy on the device under test.
TEST_F(LazyOpsTest, TestIsFloatingPoint) {
  ForEachDevice([&](const torch::Device& device) {
    const auto opts =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    torch::Tensor source = torch::rand({2, 2}, opts);
    torch::Tensor lazy_source = CopyToDevice(source, device);
    const bool expected = torch::is_floating_point(source);
    const bool actual = torch::is_floating_point(lazy_source);
    EXPECT_EQ(expected, actual);
  });
}
127 
// is_signed() must report the same answer for the eager tensor and its copy
// on the device under test.
TEST_F(LazyOpsTest, TestIsSigned) {
  ForEachDevice([&](const torch::Device& device) {
    const auto opts =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    torch::Tensor source = torch::rand({2, 2}, opts);
    torch::Tensor lazy_source = CopyToDevice(source, device);
    const bool expected = torch::is_signed(source);
    const bool actual = torch::is_signed(lazy_source);
    EXPECT_EQ(expected, actual);
  });
}
138 
// _cast_Byte on a random float tensor scaled by 100: the lazy result must
// equal the eager result.
TEST_F(LazyOpsTest, TestCastByte) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 2}, opts) * 100.0;
  torch::Tensor expected = torch::_cast_Byte(input);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::_cast_Byte(lazy_input);
    AllEqual(expected, actual);
  });
}
151 
// _cast_Char on a random float tensor scaled by 100: the lazy result must
// equal the eager result.
TEST_F(LazyOpsTest, TestCastChar) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 2}, opts) * 100.0;
  torch::Tensor expected = torch::_cast_Char(input);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::_cast_Char(lazy_input);
    AllEqual(expected, actual);
  });
}
164 
// _cast_Short on a random float tensor scaled by 100: the lazy result must
// equal the eager result.
TEST_F(LazyOpsTest, TestCastShort) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 2}, opts) * 100.0;
  torch::Tensor expected = torch::_cast_Short(input);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::_cast_Short(lazy_input);
    AllEqual(expected, actual);
  });
}
177 
// _cast_Int on a random float tensor scaled by 100: the lazy result must
// equal the eager result.
TEST_F(LazyOpsTest, TestCastInt) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 2}, opts) * 100.0;
  torch::Tensor expected = torch::_cast_Int(input);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::_cast_Int(lazy_input);
    AllEqual(expected, actual);
  });
}
190 
// _cast_Long on a random float tensor scaled by 100: the lazy result must
// equal the eager result.
TEST_F(LazyOpsTest, TestCastLong) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 2}, opts) * 100.0;
  torch::Tensor expected = torch::_cast_Long(input);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::_cast_Long(lazy_input);
    AllEqual(expected, actual);
  });
}
203 
// _cast_Float on a random float tensor scaled by 100: the lazy result must
// equal the eager result.
TEST_F(LazyOpsTest, TestCastFloat) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 2}, opts) * 100.0;
  torch::Tensor expected = torch::_cast_Float(input);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::_cast_Float(lazy_input);
    AllEqual(expected, actual);
  });
}
216 
// Adding two Byte tensors on the lazy device must produce a Byte tensor.
TEST_F(LazyOpsTest, TestRetainType) {
  const auto byte_opts =
      torch::TensorOptions(torch::kByte).device(torch::kLazy);
  torch::Tensor lazy_zeros = torch::zeros({2, 2}, byte_opts);
  torch::Tensor lazy_ones = torch::ones({2, 2}, byte_opts);
  torch::Tensor lazy_sum = lazy_zeros + lazy_ones;
  EXPECT_EQ(lazy_sum.scalar_type(), torch::ScalarType::Byte);
}
225 
// Dividing a Float matmul result by a Double scalar tensor and applying
// softmax must still yield a Float tensor on the lazy device.
TEST_F(LazyOpsTest, TestLogicalTypeWithInterop) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(torch::kLazy);
  torch::Tensor query = torch::rand({2, 12, 20, 64}, float_opts);
  torch::Tensor key = torch::rand({2, 12, 64, 20}, float_opts);
  // The divisor is deliberately a Double scalar tensor.
  torch::Tensor scale = torch::scalar_tensor(
      8, torch::TensorOptions(torch::kDouble).device(torch::kLazy));
  torch::Tensor scores = torch::matmul(query, key) / scale;
  torch::Tensor p_attn = torch::softmax(scores, /*dim=*/-1);
  EXPECT_EQ(p_attn.scalar_type(), torch::ScalarType::Float);
}
240 
// torch::add on two Float tensors: lazy result versus eager result.
TEST_F(LazyOpsTest, TestAdd) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor lhs = torch::rand({2, 2}, opts);
  torch::Tensor rhs = torch::rand({2, 2}, opts);
  torch::Tensor expected = torch::add(lhs, rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    torch::Tensor actual = torch::add(lazy_lhs, lazy_rhs);
    AllClose(expected, actual);
  });
}
254 
// torch::add on two Half tensors: lazy result versus eager result.
TEST_F(LazyOpsTest, TestAddHalf) {
  const auto opts =
      torch::TensorOptions(torch::kHalf).device(DefaultDevice());
  torch::Tensor lhs = torch::rand({2, 2}, opts);
  torch::Tensor rhs = torch::rand({2, 2}, opts);
  torch::Tensor expected = torch::add(lhs, rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    torch::Tensor actual = torch::add(lazy_lhs, lazy_rhs);
    AllClose(expected, actual);
  });
}
268 
// torch::add with a Float left operand and a Half right operand: lazy result
// versus eager result.
TEST_F(LazyOpsTest, TestAddMixedPrecision) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const auto half_opts =
      torch::TensorOptions(torch::kHalf).device(DefaultDevice());
  torch::Tensor lhs = torch::rand({2, 2}, float_opts);
  torch::Tensor rhs = torch::rand({2, 2}, half_opts);
  torch::Tensor expected = torch::add(lhs, rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    torch::Tensor actual = torch::add(lazy_lhs, lazy_rhs);
    AllClose(expected, actual);
  });
}
282 
// In-place add_: both the mutated tensor and the returned tensor are compared
// between eager and lazy execution.
TEST_F(LazyOpsTest, TestAddInPlace) {
  ForEachDevice([&](const torch::Device& device) {
    const auto opts =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    torch::Tensor lhs = torch::rand({2, 2}, opts);
    // Copy before the eager tensor is mutated below.
    torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
    torch::Tensor rhs = torch::rand({2, 2}, opts);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    torch::Tensor expected = lhs.add_(rhs);
    torch::Tensor actual = lazy_lhs.add_(lazy_rhs);
    AllClose(lhs, lazy_lhs);
    AllClose(expected, actual);
  });
}
297 
// torch::add with an integer Scalar operand: lazy result versus eager result.
TEST_F(LazyOpsTest, TestAddScalar) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 2}, opts);
  const torch::Scalar addend(1);
  torch::Tensor expected = torch::add(input, addend);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::add(lazy_input, addend);
    AllClose(expected, actual);
  });
}
309 
// In-place add_ with an integer Scalar: the mutated and returned tensors are
// compared between eager and lazy execution.
TEST_F(LazyOpsTest, TestAddScalarInPlace) {
  const torch::Scalar addend(1);
  ForEachDevice([&](const torch::Device& device) {
    const auto opts =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    torch::Tensor input = torch::rand({2, 2}, opts);
    // Copy before the eager tensor is mutated below.
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor expected = input.add_(addend);
    torch::Tensor actual = lazy_input.add_(addend);
    AllClose(input, lazy_input);
    AllClose(expected, actual);
  });
}
322 
// torch::add where the left operand has shape {0, 2} and the right operand
// has shape {1, 2}: lazy result versus eager result.
TEST_F(LazyOpsTest, TestAddZeroSizeDim) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor lhs = torch::rand({0, 2}, opts);
  torch::Tensor rhs = torch::rand({1, 2}, opts);
  torch::Tensor expected = torch::add(lhs, rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    torch::Tensor actual = torch::add(lazy_lhs, lazy_rhs);
    AllClose(expected, actual);
  });
}
336 
// torch::sub on two Float tensors: lazy result versus eager result.
TEST_F(LazyOpsTest, TestSub) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor lhs = torch::rand({2, 2}, opts);
  torch::Tensor rhs = torch::rand({2, 2}, opts);
  torch::Tensor expected = torch::sub(lhs, rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    torch::Tensor actual = torch::sub(lazy_lhs, lazy_rhs);
    AllClose(expected, actual);
  });
}
350 
// In-place sub_: both the mutated tensor and the returned tensor are compared
// between eager and lazy execution.
TEST_F(LazyOpsTest, TestSubInPlace) {
  ForEachDevice([&](const torch::Device& device) {
    const auto opts =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    torch::Tensor lhs = torch::rand({2, 2}, opts);
    // Copy before the eager tensor is mutated below.
    torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
    torch::Tensor rhs = torch::rand({2, 2}, opts);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    torch::Tensor expected = lhs.sub_(rhs);
    torch::Tensor actual = lazy_lhs.sub_(lazy_rhs);
    AllClose(lhs, lazy_lhs);
    AllClose(expected, actual);
  });
}
365 
// torch::sub with an integer Scalar operand: lazy result versus eager result.
TEST_F(LazyOpsTest, TestSubScalar) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 2}, opts);
  const torch::Scalar subtrahend(1);
  torch::Tensor expected = torch::sub(input, subtrahend);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::sub(lazy_input, subtrahend);
    AllClose(expected, actual);
  });
}
377 
// In-place sub_ with an integer Scalar: the mutated and returned tensors are
// compared between eager and lazy execution.
TEST_F(LazyOpsTest, TestSubScalarInPlace) {
  const torch::Scalar subtrahend(1);
  ForEachDevice([&](const torch::Device& device) {
    const auto opts =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    torch::Tensor input = torch::rand({2, 2}, opts);
    // Copy before the eager tensor is mutated below.
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor expected = input.sub_(subtrahend);
    torch::Tensor actual = lazy_input.sub_(subtrahend);
    AllClose(input, lazy_input);
    AllClose(expected, actual);
  });
}
390 
// torch::mul on two Float tensors: lazy result versus eager result.
TEST_F(LazyOpsTest, TestMul) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor lhs = torch::rand({2, 2}, opts);
  torch::Tensor rhs = torch::rand({2, 2}, opts);
  torch::Tensor expected = torch::mul(lhs, rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    torch::Tensor actual = torch::mul(lazy_lhs, lazy_rhs);
    AllClose(expected, actual);
  });
}
404 
// In-place mul_: both the mutated tensor and the returned tensor are compared
// between eager and lazy execution.
TEST_F(LazyOpsTest, TestMulInPlace) {
  ForEachDevice([&](const torch::Device& device) {
    const auto opts =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    torch::Tensor lhs = torch::rand({2, 2}, opts);
    // Copy before the eager tensor is mutated below.
    torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
    torch::Tensor rhs = torch::rand({2, 2}, opts);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    torch::Tensor expected = lhs.mul_(rhs);
    torch::Tensor actual = lazy_lhs.mul_(lazy_rhs);
    AllClose(lhs, lazy_lhs);
    AllClose(expected, actual);
  });
}
419 
// torch::mul with an integer Scalar operand: lazy result versus eager result.
TEST_F(LazyOpsTest, TestMulScalar) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 2}, opts);
  const torch::Scalar factor(3);
  torch::Tensor expected = torch::mul(input, factor);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::mul(lazy_input, factor);
    AllClose(expected, actual);
  });
}
431 
// In-place mul_ with an integer Scalar: the mutated and returned tensors are
// compared between eager and lazy execution.
TEST_F(LazyOpsTest, TestMulScalarInPlace) {
  const torch::Scalar factor(3);
  ForEachDevice([&](const torch::Device& device) {
    const auto opts =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    torch::Tensor input = torch::rand({2, 2}, opts);
    // Copy before the eager tensor is mutated below.
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor expected = input.mul_(factor);
    torch::Tensor actual = lazy_input.mul_(factor);
    AllClose(input, lazy_input);
    AllClose(expected, actual);
  });
}
444 
// torch::div over every pairing of the listed dtypes: lazy result versus
// eager result.
TEST_F(LazyOpsTest, TestDiv) {
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  for (const torch::ScalarType lhs_dtype : kDtypes) {
    const auto lhs_opts = torch::TensorOptions(lhs_dtype);
    torch::Tensor lhs = isFloatingType(lhs_dtype)
        ? torch::rand({3, 4}, lhs_opts)
        : torch::randint(0, 100, {3, 4}, lhs_opts);
    for (const torch::ScalarType rhs_dtype : kDtypes) {
      const auto rhs_opts = torch::TensorOptions(rhs_dtype);
      // Integer divisors are drawn with a lower bound of 1 (presumably to
      // keep zero out of the denominator).
      torch::Tensor rhs = isFloatingType(rhs_dtype)
          ? torch::rand({3, 4}, rhs_opts)
          : torch::randint(1, 100, {3, 4}, rhs_opts);
      torch::Tensor expected = torch::div(lhs, rhs);
      ForEachDevice([&](const torch::Device& device) {
        torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
        torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
        torch::Tensor actual = torch::div(lazy_lhs, lazy_rhs);
        AllClose(expected, actual);
      });
    }
  }
}
476 
// torch::div with each rounding mode ("trunc", "floor", none) over every
// pairing of the listed dtypes: lazy result versus eager result.
TEST_F(LazyOpsTest, TestDivWithRoundingMode) {
  std::optional<c10::string_view> rounding_modes[] = {
      "trunc", "floor", std::nullopt};
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  for (const auto& rounding_mode : rounding_modes) {
    for (const torch::ScalarType lhs_dtype : kDtypes) {
      // Byte tensors get a lower bound of 0; every other integer dtype gets
      // -100.
      const int lower_bound = (lhs_dtype == torch::kByte) ? 0 : -100;
      const auto lhs_opts = torch::TensorOptions(lhs_dtype);
      torch::Tensor lhs = isFloatingType(lhs_dtype)
          ? torch::rand({3, 4}, lhs_opts)
          : torch::randint(lower_bound, 50, {3, 4}, lhs_opts);
      for (const torch::ScalarType rhs_dtype : kDtypes) {
        const auto rhs_opts = torch::TensorOptions(rhs_dtype);
        torch::Tensor rhs = isFloatingType(rhs_dtype)
            ? torch::rand({3, 4}, rhs_opts)
            : torch::randint(51, 100, {3, 4}, rhs_opts);
        torch::Tensor expected = torch::div(lhs, rhs, rounding_mode);
        ForEachDevice([&](const torch::Device& device) {
          torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
          torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
          torch::Tensor actual =
              torch::div(lazy_lhs, lazy_rhs, rounding_mode);
          AllClose(expected, actual);
        });
      }
    }
  }
}
515 
// In-place div_ on Float tensors: the tensor returned by the lazy call is
// compared with the one returned by the eager call.
TEST_F(LazyOpsTest, TestDivInPlace) {
  for (torch::ScalarType lhs_dtype : {torch::kFloat}) {
    torch::Tensor lhs;
    if (isFloatingType(lhs_dtype)) {
      lhs = torch::rand({3, 4}, torch::TensorOptions(lhs_dtype));
    } else {
      lhs = torch::randint(0, 100, {3, 4}, torch::TensorOptions(lhs_dtype));
    }
    for (torch::ScalarType rhs_dtype : {torch::kFloat}) {
      torch::Tensor rhs;
      if (isFloatingType(rhs_dtype)) {
        rhs = torch::rand({3, 4}, torch::TensorOptions(rhs_dtype));
      } else {
        rhs = torch::randint(1, 100, {3, 4}, torch::TensorOptions(rhs_dtype));
      }
      ForEachDevice([&](const torch::Device& device) {
        // Copy before the eager tensor is mutated on the next line.
        torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
        torch::Tensor expected = lhs.div_(rhs);
        torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
        torch::Tensor actual = lazy_lhs.div_(lazy_rhs);
        AllClose(expected, actual);
      });
    }
  }
}
536 
// In-place div_ with each rounding mode ("trunc", "floor", none) on Float
// tensors: lazy return value versus eager return value.
TEST_F(LazyOpsTest, TestDivInPlaceWithRoundingMode) {
  std::optional<c10::string_view> rounding_modes[] = {
      "trunc", "floor", std::nullopt};
  for (const auto& rounding_mode : rounding_modes) {
    for (torch::ScalarType lhs_dtype : {torch::kFloat}) {
      torch::Tensor lhs;
      if (isFloatingType(lhs_dtype)) {
        lhs = torch::rand({3, 4}, torch::TensorOptions(lhs_dtype));
      } else {
        lhs = torch::randint(
            -100, 100, {3, 4}, torch::TensorOptions(lhs_dtype));
      }
      for (torch::ScalarType rhs_dtype : {torch::kFloat}) {
        torch::Tensor rhs;
        if (isFloatingType(rhs_dtype)) {
          rhs = torch::rand({3, 4}, torch::TensorOptions(rhs_dtype));
        } else {
          rhs = torch::randint(
              1, 100, {3, 4}, torch::TensorOptions(rhs_dtype));
        }
        ForEachDevice([&](const torch::Device& device) {
          // Copy before the eager tensor is mutated on the next line.
          torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
          torch::Tensor expected = lhs.div_(rhs, rounding_mode);
          torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
          torch::Tensor actual = lazy_lhs.div_(lazy_rhs, rounding_mode);
          AllClose(expected, actual);
        });
      }
    }
  }
}
562 
// torch::div of each listed dtype by a floating-point Scalar (3.0) and then
// an integer Scalar (3): lazy result versus eager result.
TEST_F(LazyOpsTest, TestDivScalar) {
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  // Same order as the original {true, false} loop: floating point first.
  const torch::Scalar divisors[] = {torch::Scalar(3.0), torch::Scalar(3)};
  for (const torch::ScalarType dtype : kDtypes) {
    const auto opts = torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor input = isFloatingType(dtype)
        ? torch::rand({3, 4}, opts)
        : torch::randint(1, 100, {3, 4}, opts);
    for (const torch::Scalar& divisor : divisors) {
      torch::Tensor expected = torch::div(input, divisor);
      ForEachDevice([&](const torch::Device& device) {
        torch::Tensor lazy_input = CopyToDevice(input, device);
        torch::Tensor actual = torch::div(lazy_input, divisor);
        AllClose(expected, actual);
      });
    }
  }
}
590 
// In-place div_ of a Float tensor by a floating-point Scalar (3.0) and then
// an integer Scalar (3): lazy return value versus eager return value.
TEST_F(LazyOpsTest, TestDivScalarInPlace) {
  // Same order as the original {true, false} loop: floating point first.
  const torch::Scalar divisors[] = {torch::Scalar(3.0), torch::Scalar(3)};
  for (torch::ScalarType dtype : {torch::kFloat}) {
    const auto opts = torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor input;
    if (isFloatingType(dtype)) {
      input = torch::rand({3, 4}, opts);
    } else {
      input = torch::randint(1, 100, {3, 4}, opts);
    }
    for (const torch::Scalar& divisor : divisors) {
      ForEachDevice([&](const torch::Device& device) {
        // Copy before the eager tensor is mutated on the next line.
        torch::Tensor lazy_input = CopyToDevice(input, device);
        torch::Tensor expected = input.div_(divisor);
        torch::Tensor actual = lazy_input.div_(divisor);
        AllClose(expected, actual);
      });
    }
  }
}
612 
// torch::div_out for Float and Double: the lazy output tensor is compared
// with the eager output tensor.
TEST_F(LazyOpsTest, TestDivOut) {
  for (torch::ScalarType dtype : {torch::kFloat, torch::kDouble}) {
    const auto opts = torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor lhs = torch::rand({3, 4}, opts);
    torch::Tensor rhs = torch::rand({3, 4}, opts);
    torch::Tensor expected = torch::empty({3, 4}, opts);
    torch::div_out(expected, lhs, rhs);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
      torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
      // The output buffer takes its options from the lazy right operand.
      torch::Tensor actual = torch::empty({3, 4}, lazy_rhs.options());
      torch::div_out(actual, lazy_lhs, lazy_rhs);
      AllClose(expected, actual);
    });
  }
}
631 
// torch::rsub with Scalar `other` (1.5) and `alpha` (2.5): lazy result versus
// eager result.
TEST_F(LazyOpsTest, TestRsubScalar) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 2}, opts);
  const torch::Scalar other(1.5);
  const torch::Scalar alpha(2.5);
  torch::Tensor expected = torch::rsub(input, other, alpha);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::rsub(lazy_input, other, alpha);
    AllClose(expected, actual);
  });
}
644 
// torch::ne on two independent random tensors: lazy result versus eager
// result.
TEST_F(LazyOpsTest, TestNe) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor lhs = torch::rand({2, 3}, opts);
  torch::Tensor rhs = torch::rand({2, 3}, opts);
  torch::Tensor expected = torch::ne(lhs, rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    torch::Tensor actual = torch::ne(lazy_lhs, lazy_rhs);
    AllEqual(expected, actual);
  });
}
658 
// In-place ne_: the right operand is a clone of the left with row 0 shifted
// by +1; the lazy result is compared with the eager result.
TEST_F(LazyOpsTest, TestNeInplace) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor lhs = torch::rand({2, 3}, opts);
  // Untouched copy used to seed the lazy tensor after lhs has been mutated.
  torch::Tensor lhs_pristine = lhs.clone();
  torch::Tensor rhs = lhs.clone();
  rhs[0] += 1;
  lhs.ne_(rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_lhs = CopyToDevice(lhs_pristine, device);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    lazy_lhs.ne_(lazy_rhs);
    AllClose(lhs, lazy_lhs);
  });
}
673 
// torch::eq on a tensor and its clone: lazy result versus eager result.
TEST_F(LazyOpsTest, TestEq) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor lhs = torch::rand({2, 3}, opts);
  torch::Tensor rhs = lhs.clone();
  torch::Tensor expected = torch::eq(lhs, rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    torch::Tensor actual = torch::eq(lazy_lhs, lazy_rhs);
    AllEqual(expected, actual);
  });
}
686 
// In-place eq_: the right operand is a clone of the left with row 0 shifted
// by +1; the lazy result is compared with the eager result.
TEST_F(LazyOpsTest, TestEqInplace) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor lhs = torch::rand({2, 3}, opts);
  torch::Tensor rhs = lhs.clone();
  rhs[0] += 1;
  // Untouched copy used to seed the lazy tensor after lhs has been mutated.
  torch::Tensor lhs_pristine = lhs.clone();
  lhs.eq_(rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_lhs = CopyToDevice(lhs_pristine, device);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    lazy_lhs.eq_(lazy_rhs);
    AllClose(lazy_lhs, lhs);
  });
}
701 
// torch::ge on a tensor and its clone: lazy result versus eager result.
TEST_F(LazyOpsTest, TestGe) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor lhs = torch::rand({2, 3}, opts);
  torch::Tensor rhs = lhs.clone();
  torch::Tensor expected = torch::ge(lhs, rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    torch::Tensor actual = torch::ge(lazy_lhs, lazy_rhs);
    AllEqual(expected, actual);
  });
}
714 
// In-place ge_: the right operand is a clone of the left with row 0 shifted
// by +1 and row 1 by -1; the lazy result is compared with the eager result.
TEST_F(LazyOpsTest, TestGeInplace) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor lhs = torch::rand({2, 3}, opts);
  torch::Tensor rhs = lhs.clone();
  rhs[0] += 1;
  rhs[1] -= 1;
  // Untouched copy used to seed the lazy tensor after lhs has been mutated.
  torch::Tensor lhs_pristine = lhs.clone();
  lhs.ge_(rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_lhs = CopyToDevice(lhs_pristine, device);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    lazy_lhs.ge_(lazy_rhs);
    AllClose(lazy_lhs, lhs);
  });
}
730 
// torch::le on a tensor and its clone: lazy result versus eager result.
TEST_F(LazyOpsTest, TestLe) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor lhs = torch::rand({2, 3}, opts);
  torch::Tensor rhs = lhs.clone();
  torch::Tensor expected = torch::le(lhs, rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    torch::Tensor actual = torch::le(lazy_lhs, lazy_rhs);
    AllEqual(expected, actual);
  });
}
743 
// In-place le_: the right operand is a clone of the left with row 0 shifted
// by +1 and row 1 by -1; the lazy result is compared with the eager result.
TEST_F(LazyOpsTest, TestLeInplace) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor lhs = torch::rand({2, 3}, opts);
  torch::Tensor rhs = lhs.clone();
  rhs[0] += 1;
  rhs[1] -= 1;
  // Untouched copy used to seed the lazy tensor after lhs has been mutated.
  torch::Tensor lhs_pristine = lhs.clone();
  lhs.le_(rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_lhs = CopyToDevice(lhs_pristine, device);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    lazy_lhs.le_(lazy_rhs);
    AllClose(lazy_lhs, lhs);
  });
}
759 
// torch::gt(b, a) where b is a + 1 elementwise: lazy result versus eager
// result.
TEST_F(LazyOpsTest, TestGt) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor smaller = torch::rand({2, 3}, opts);
  torch::Tensor larger =
      torch::add(smaller.clone(), torch::ones_like(smaller));
  torch::Tensor expected = torch::gt(larger, smaller);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_smaller = CopyToDevice(smaller, device);
    torch::Tensor lazy_larger = CopyToDevice(larger, device);
    torch::Tensor actual = torch::gt(lazy_larger, lazy_smaller);
    AllEqual(expected, actual);
  });
}
772 
// In-place gt_: the right operand is a clone of the left with row 0 shifted
// by +1 and row 1 by -1; the lazy result is compared with the eager result.
TEST_F(LazyOpsTest, TestGtInplace) {
  const auto opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor lhs = torch::rand({2, 3}, opts);
  torch::Tensor rhs = lhs.clone();
  rhs[0] += 1;
  rhs[1] -= 1;
  // Untouched copy used to seed the lazy tensor after lhs has been mutated.
  torch::Tensor lhs_pristine = lhs.clone();
  lhs.gt_(rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_lhs = CopyToDevice(lhs_pristine, device);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    lazy_lhs.gt_(lazy_rhs);
    AllClose(lazy_lhs, lhs);
  });
}
788 
// Checks that element-wise `lt` on lazy tensors matches the eager result.
TEST_F(LazyOpsTest, TestLt) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor smaller = torch::rand({2, 3}, options);
  // `larger` is `smaller` shifted up by one in every position.
  const torch::Tensor larger =
      torch::add(smaller.clone(), torch::ones_like(smaller));
  const torch::Tensor expected = torch::lt(smaller, larger);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_smaller = CopyToDevice(smaller, device);
    const torch::Tensor lazy_larger = CopyToDevice(larger, device);
    const torch::Tensor lazy_result = torch::lt(lazy_smaller, lazy_larger);
    AllEqual(expected, lazy_result);
  });
}
801 
// Checks that in-place `lt_` on a lazy tensor matches the eager result.
TEST_F(LazyOpsTest, TestLtInplace) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor lhs = torch::rand({2, 3}, options);
  // Shift the two rows of the right-hand side in opposite directions so the
  // comparison produces both outcomes.
  torch::Tensor rhs = lhs.clone();
  rhs[0] += 1;
  rhs[1] -= 1;
  // `lhs` is overwritten with the eager reference below, so keep a pristine
  // copy to feed the lazy device.
  const torch::Tensor lhs_before = lhs.clone();
  lhs.lt_(rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_lhs = CopyToDevice(lhs_before, device);
    const torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    lazy_lhs.lt_(lazy_rhs);
    AllClose(lazy_lhs, lhs);
  });
}
817 
// Checks tensor-vs-scalar `ne` on a lazy tensor against the eager result.
TEST_F(LazyOpsTest, TestNeScalar) {
  const torch::Tensor ones = torch::ones({2, 3});
  const torch::Scalar rhs(0.0f);
  const torch::Tensor expected = torch::ne(ones, rhs);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_ones = CopyToDevice(ones, device);
    const torch::Tensor lazy_result = torch::ne(lazy_ones, rhs);
    AllEqual(expected, lazy_result);
  });
}
828 
// Checks tensor-vs-scalar `eq` on a lazy tensor against the eager result.
TEST_F(LazyOpsTest, TestEqScalar) {
  const torch::Tensor ones = torch::ones({2, 3});
  const torch::Scalar rhs(1.0f);
  const torch::Tensor expected = torch::eq(ones, rhs);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_ones = CopyToDevice(ones, device);
    const torch::Tensor lazy_result = torch::eq(lazy_ones, rhs);
    AllEqual(expected, lazy_result);
  });
}
839 
// Checks tensor-vs-scalar `ge` on a lazy tensor against the eager result.
TEST_F(LazyOpsTest, TestGeScalar) {
  const torch::Tensor ones = torch::ones({2, 3});
  const torch::Scalar rhs(1.0f);
  const torch::Tensor expected = torch::ge(ones, rhs);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_ones = CopyToDevice(ones, device);
    const torch::Tensor lazy_result = torch::ge(lazy_ones, rhs);
    AllEqual(expected, lazy_result);
  });
}
850 
// Checks in-place tensor-vs-scalar `ge_` on a lazy tensor against eager mode.
TEST_F(LazyOpsTest, TestGeScalarInplace) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  // Values step from -1 towards 1.5 in increments of 0.5, so they straddle the
  // zero threshold used below.
  torch::Tensor values = torch::arange(-1., 1.5, 0.5, options);
  const torch::Scalar threshold(0.0f);
  // `values` becomes the eager reference; the lazy run starts from this copy.
  const torch::Tensor values_before = values.clone();
  values.ge_(threshold);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_values = CopyToDevice(values_before, device);
    lazy_values.ge_(threshold);
    AllClose(lazy_values, values);
  });
}
866 
// Checks tensor-vs-scalar `le` on a lazy tensor against the eager result.
TEST_F(LazyOpsTest, TestLeScalar) {
  const torch::Tensor ones = torch::ones({2, 3});
  const torch::Scalar rhs(1.0f);
  const torch::Tensor expected = torch::le(ones, rhs);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_ones = CopyToDevice(ones, device);
    const torch::Tensor lazy_result = torch::le(lazy_ones, rhs);
    AllEqual(expected, lazy_result);
  });
}
877 
// Checks in-place tensor-vs-scalar `le_` on a lazy tensor against eager mode.
TEST_F(LazyOpsTest, TestLeScalarInplace) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  // Values step from -1 towards 1.5 in increments of 0.5, so they straddle the
  // zero threshold used below.
  torch::Tensor values = torch::arange(-1., 1.5, 0.5, options);
  const torch::Scalar threshold(0.0f);
  // `values` becomes the eager reference; the lazy run starts from this copy.
  const torch::Tensor values_before = values.clone();
  values.le_(threshold);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_values = CopyToDevice(values_before, device);
    lazy_values.le_(threshold);
    AllClose(lazy_values, values);
  });
}
893 
// Checks tensor-vs-scalar `gt` on a lazy tensor against the eager result.
TEST_F(LazyOpsTest, TestGtScalar) {
  const torch::Tensor ones = torch::ones({2, 3});
  const torch::Scalar rhs(0.5f);
  const torch::Tensor expected = torch::gt(ones, rhs);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_ones = CopyToDevice(ones, device);
    const torch::Tensor lazy_result = torch::gt(lazy_ones, rhs);
    AllEqual(expected, lazy_result);
  });
}
904 
// Checks in-place tensor-vs-scalar `gt_` on a lazy tensor against eager mode.
TEST_F(LazyOpsTest, TestGtScalarInplace) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  // Values step from -1 towards 1.5 in increments of 0.5, so they straddle the
  // zero threshold used below.
  torch::Tensor values = torch::arange(-1., 1.5, 0.5, options);
  const torch::Scalar threshold(0.0f);
  // `values` becomes the eager reference; the lazy run starts from this copy.
  const torch::Tensor values_before = values.clone();
  values.gt_(threshold);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_values = CopyToDevice(values_before, device);
    lazy_values.gt_(threshold);
    AllClose(lazy_values, values);
  });
}
920 
// Checks tensor-vs-scalar `lt` on a lazy tensor against the eager result.
TEST_F(LazyOpsTest, TestLtScalar) {
  const torch::Tensor ones = torch::ones({2, 3});
  const torch::Scalar rhs(1.5f);
  const torch::Tensor expected = torch::lt(ones, rhs);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_ones = CopyToDevice(ones, device);
    const torch::Tensor lazy_result = torch::lt(lazy_ones, rhs);
    AllEqual(expected, lazy_result);
  });
}
931 
// Checks in-place tensor-vs-scalar `lt_` on a lazy tensor against eager mode.
TEST_F(LazyOpsTest, TestLtScalarInplace) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  // Values step from -1 towards 1.5 in increments of 0.5, so they straddle the
  // zero threshold used below.
  torch::Tensor values = torch::arange(-1., 1.5, 0.5, options);
  const torch::Scalar threshold(0.0f);
  // `values` becomes the eager reference; the lazy run starts from this copy.
  const torch::Tensor values_before = values.clone();
  values.lt_(threshold);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_values = CopyToDevice(values_before, device);
    lazy_values.lt_(threshold);
    AllClose(lazy_values, values);
  });
}
947 
// Checks that adding a scalar to a lazy tensor matches the eager result for
// each integral dtype listed below.
TEST_F(LazyOpsTest, TestIntegerAdd) {
  const std::vector<torch::ScalarType> types(
      {torch::kByte, torch::kChar, torch::kShort, torch::kInt, torch::kLong});

  ForEachDevice([&](const torch::Device& device) {
    for (const auto type : types) {
      // randint draws from [0, 63), so adding one stays well inside the range
      // of the narrowest dtype under test.
      const torch::Tensor b =
          torch::randint(0, 63, {2, 2}, torch::TensorOptions(type));
      // Use a scalar whose kind matches the tensor dtype.
      const torch::Scalar one =
          isIntegralType(type, false) ? torch::Scalar(1) : torch::Scalar(1.0);
      const torch::Tensor c = torch::add(b, one);

      const torch::Tensor lazy_b = CopyToDevice(b, device);
      const torch::Tensor lazy_c = torch::add(lazy_b, one);

      AllEqual(c, lazy_c);
    }
  });
}
970 
// Compares `torch::svd` on lazy tensors against eager mode for every
// combination of 4 and 7 as row and column counts.
TEST_F(LazyOpsTest, TestSVD) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const double rtol = 1e-3;
  const double atol = 1e-4;
  for (const int rows : {4, 7}) {
    for (const int cols : {4, 7}) {
      const torch::Tensor input = torch::rand({rows, cols}, options);
      const auto eager =
          torch::svd(input, /*some=*/true, /*compute_uv=*/true);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_input = CopyToDevice(input, device);
        const auto lazy =
            torch::svd(lazy_input, /*some=*/true, /*compute_uv=*/true);
        // Column vectors of U and V may differ in sign between the two runs,
        // so only their absolute values are comparable.
        AllClose(
            std::get<0>(eager).abs(), std::get<0>(lazy).abs(), rtol, atol);
        const torch::Tensor singular = std::get<1>(eager);
        const torch::Tensor lazy_singular = std::get<1>(lazy);
        ASSERT_EQ(singular.sizes(), lazy_singular.sizes());
        AllClose(singular, lazy_singular, rtol, atol);
        AllClose(
            std::get<2>(eager).abs(), std::get<2>(lazy).abs(), rtol, atol);
      });
    }
  }
}
1005 
// Compares `torch::qr` on lazy tensors against eager mode for every
// combination of 4 and 7 as row and column counts.
TEST_F(LazyOpsTest, TestQR) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const double rtol = 1e-3;
  const double atol = 1e-4;
  for (const int rows : {4, 7}) {
    for (const int cols : {4, 7}) {
      const torch::Tensor input = torch::rand({rows, cols}, options);
      const auto eager = torch::qr(input);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_input = CopyToDevice(input, device);
        const auto lazy = torch::qr(lazy_input);
        // Both factors are compared by absolute value only.
        AllClose(
            std::get<0>(eager).abs(), std::get<0>(lazy).abs(), rtol, atol);
        AllClose(
            std::get<1>(eager).abs(), std::get<1>(lazy).abs(), rtol, atol);
      });
    }
  }
}
1030 
// Compares `torch::cholesky` on lazy tensors against eager mode for batched
// inputs, in both the upper and lower variants.
TEST_F(LazyOpsTest, TestCholesky) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  for (const int size : {4, 7}) {
    for (const bool upper : {true, false}) {
      const torch::Tensor base = torch::rand({3, size, size}, options);
      // base * base^T + I gives a batch of symmetric positive-definite
      // matrices to factorize.
      const torch::Tensor pd_input =
          torch::matmul(base, torch::transpose(base, 1, 2)) +
          torch::eye(size, options);
      const auto expected = torch::cholesky(pd_input, upper);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_input = CopyToDevice(pd_input, device);
        const auto lazy_result = torch::cholesky(lazy_input, upper);
        AllClose(expected, lazy_result, /*rtol=*/1e-3, /*atol=*/1e-4);
      });
    }
  }
}
1051 
// Compares `torch::logdet` on lazy tensors against eager mode for batched
// inputs.
TEST_F(LazyOpsTest, TestLogDet) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  for (const int size : {4, 7}) {
    const torch::Tensor base = torch::rand({3, size, size}, options);
    // base * base^T + I gives a batch of symmetric positive-definite matrices.
    const torch::Tensor pd_input =
        torch::matmul(base, torch::transpose(base, 1, 2)) +
        torch::eye(size, options);
    const torch::Tensor expected = torch::logdet(pd_input);
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor lazy_input = CopyToDevice(pd_input, device);
      const torch::Tensor lazy_result = torch::logdet(lazy_input);
      AllClose(expected, lazy_result, /*rtol=*/1e-3, /*atol=*/1e-4);
    });
  }
}
1068 
// Compares `torch::triangular_solve` on lazy tensors against eager mode over
// every combination of: batched/unbatched coefficient and right-hand side,
// sizes 4 and 7, upper/lower, transposed or not, unitriangular or not.
TEST_F(LazyOpsTest, TestTriangularSolve) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const double rtol = 1e-3;
  const double atol = 1e-4;
  for (const bool batched_a : {true, false}) {
    for (const bool batched_b : {true, false}) {
      for (const int m : {4, 7}) {
        for (const int n : {4, 7}) {
          for (const bool upper : {true, false}) {
            for (const bool transpose : {true, false}) {
              for (const bool unitriangular : {true, false}) {
                torch::Tensor coeff = torch::randn({m, m}, options);
                torch::Tensor rhs = torch::randn({m, n}, options);
                // Optionally replicate each operand into a batch of three.
                if (batched_a) {
                  coeff = coeff.expand({3, m, m}).clone();
                }
                if (batched_b) {
                  rhs = rhs.expand({3, m, n}).clone();
                }
                const auto eager = torch::triangular_solve(
                    rhs,
                    coeff,
                    /*upper=*/upper,
                    /*transpose=*/transpose,
                    /*unitriangular=*/unitriangular);
                ForEachDevice([&](const torch::Device& device) {
                  const torch::Tensor lazy_coeff = CopyToDevice(coeff, device);
                  const torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
                  const auto lazy = torch::triangular_solve(
                      lazy_rhs,
                      lazy_coeff,
                      /*upper=*/upper,
                      /*transpose=*/transpose,
                      /*unitriangular=*/unitriangular);
                  // Both elements of the returned tuple must agree.
                  AllClose(std::get<0>(eager), std::get<0>(lazy), rtol, atol);
                  AllClose(std::get<1>(eager), std::get<1>(lazy), rtol, atol);
                });
              }
            }
          }
        }
      }
    }
  }
}
1122 
// Compares `torch::kthvalue` on lazy tensors against eager mode for k in
// [1, 3], every positive and negative dimension, and both keepdim settings.
TEST_F(LazyOpsTest, TestKthValue) {
  const torch::Tensor input = torch::rand(
      {4, 5, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const int rank = input.dim();
  for (int k = 1; k <= 3; ++k) {
    for (int dim = -rank; dim < rank; ++dim) {
      for (const bool keepdim : {false, true}) {
        const auto eager = torch::kthvalue(input, k, dim, keepdim);
        ForEachDevice([&](const torch::Device& device) {
          const torch::Tensor lazy_input = CopyToDevice(input, device);
          const auto [lazy_values, lazy_indices] =
              torch::kthvalue(lazy_input, k, dim, keepdim);
          AllClose(std::get<0>(eager), lazy_values);
          AllEqual(std::get<1>(eager), lazy_indices);
        });
      }
    }
  }
}
1141 
// Compares sorted `torch::topk` on lazy tensors against eager mode for k in
// [1, 3], every positive and negative dimension, and both `largest` settings.
TEST_F(LazyOpsTest, TestTopK) {
  const torch::Tensor input = torch::rand(
      {4, 5, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const int rank = input.dim();
  for (int k = 1; k <= 3; ++k) {
    for (int dim = -rank; dim < rank; ++dim) {
      for (const bool largest : {false, true}) {
        const auto eager =
            torch::topk(input, k, dim, largest, /*sorted=*/true);
        ForEachDevice([&](const torch::Device& device) {
          const torch::Tensor lazy_input = CopyToDevice(input, device);
          const auto [lazy_values, lazy_indices] =
              torch::topk(lazy_input, k, dim, largest, /*sorted=*/true);
          AllClose(std::get<0>(eager), lazy_values);
          AllEqual(std::get<1>(eager), lazy_indices);
        });
      }
    }
  }
}
1160 
// Compares `torch::sort` (values and indices) on lazy tensors against eager
// mode along each of the three dimensions, ascending and descending.
//
// The previous version wrapped this sweep in a `k` loop whose variable was
// never read, so the identical checks ran three times; sort has no `k`
// parameter, so the loop is dropped.
TEST_F(LazyOpsTest, TestSort) {
  const torch::Tensor a = torch::rand(
      {4, 5, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (int dim = 0; dim < 3; ++dim) {
    for (const bool descending : {false, true}) {
      const auto b = torch::sort(a, dim, descending);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_a = CopyToDevice(a, device);
        const auto lazy_b = torch::sort(lazy_a, dim, descending);
        AllClose(std::get<0>(b), std::get<0>(lazy_b));
        AllEqual(std::get<1>(b), std::get<1>(lazy_b));
      });
    }
  }
}
1178 
// Checks a descending sort of an int8 tensor that contains -128, the smallest
// value representable in that dtype.
TEST_F(LazyOpsTest, TestSortDescWithMinValue) {
  std::vector<int8_t> raw{-128, 100};
  const torch::Tensor input =
      torch::tensor(raw, torch::TensorOptions(torch::kChar));
  const auto eager = torch::sort(input, /*dim=*/0, /*descending=*/true);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_input = CopyToDevice(input, device);
    const auto [lazy_values, lazy_indices] =
        torch::sort(lazy_input, /*dim=*/0, /*descending=*/true);
    AllEqual(std::get<0>(eager), lazy_values);
    AllEqual(std::get<1>(eager), lazy_indices);
  });
}
1191 
// Compares `torch::argsort` on lazy tensors against eager mode along each of
// the three dimensions, ascending and descending.
//
// The previous version wrapped this sweep in a `k` loop whose variable was
// never read, so the identical checks ran three times; argsort has no `k`
// parameter, so the loop is dropped.
TEST_F(LazyOpsTest, TestArgSort) {
  const torch::Tensor a = torch::rand(
      {4, 5, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (int dim = 0; dim < 3; ++dim) {
    for (const bool descending : {false, true}) {
      const torch::Tensor b = torch::argsort(a, dim, descending);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_a = CopyToDevice(a, device);
        const torch::Tensor lazy_b = torch::argsort(lazy_a, dim, descending);
        AllEqual(b, lazy_b);
      });
    }
  }
}
1208 
// Checks element-wise binary `min` on lazy tensors against eager mode.
TEST_F(LazyOpsTest, TestMin) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor first = torch::rand({2, 2}, options);
  const torch::Tensor second = torch::rand({2, 2}, options);
  const torch::Tensor expected = torch::min(first, second);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_first = CopyToDevice(first, device);
    const torch::Tensor lazy_second = CopyToDevice(second, device);
    const torch::Tensor lazy_result = torch::min(lazy_first, lazy_second);
    AllClose(expected, lazy_result);
  });
}
1222 
// Checks element-wise binary `max` on lazy tensors against eager mode.
TEST_F(LazyOpsTest, TestMax) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor first = torch::rand({2, 2}, options);
  const torch::Tensor second = torch::rand({2, 2}, options);
  const torch::Tensor expected = torch::max(first, second);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_first = CopyToDevice(first, device);
    const torch::Tensor lazy_second = CopyToDevice(second, device);
    const torch::Tensor lazy_result = torch::max(lazy_first, lazy_second);
    AllClose(expected, lazy_result);
  });
}
1236 
// Checks the full-reduction (single-argument) `min` on a lazy tensor.
TEST_F(LazyOpsTest, TestUnaryMin) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor values = torch::rand({2, 2}, options);
  const torch::Tensor expected = torch::min(values);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_values = CopyToDevice(values, device);
    const torch::Tensor lazy_result = torch::min(lazy_values);
    AllClose(expected, lazy_result);
  });
}
1247 
// Checks the full-reduction (single-argument) `max` on a lazy tensor.
TEST_F(LazyOpsTest, TestUnaryMax) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor values = torch::rand({2, 2}, options);
  const torch::Tensor expected = torch::max(values);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_values = CopyToDevice(values, device);
    const torch::Tensor lazy_result = torch::max(lazy_values);
    AllClose(expected, lazy_result);
  });
}
1258 
// Checks the full-reduction `all` on lazy tensors for one floating-point and
// several integral dtypes.
TEST_F(LazyOpsTest, TestAll) {
  for (const torch::ScalarType scalar_type :
       {torch::kFloat,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const auto options =
        torch::TensorOptions(scalar_type).device(DefaultDevice());
    // Floating dtypes are sampled with rand, integral ones with randint.
    torch::Tensor input;
    if (isFloatingType(scalar_type)) {
      input = torch::rand({3, 4}, options);
    } else {
      input = torch::randint(100, {3, 4}, options);
    }
    const torch::Tensor expected = torch::all(input);
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor lazy_input = CopyToDevice(input, device);
      const torch::Tensor lazy_result = torch::all(lazy_input);
      EqualValues(expected, lazy_result);
    });
  }
}
1282 
// Checks `all` along every positive and negative dimension with the reduced
// dimension dropped.
TEST_F(LazyOpsTest, TestAllDim) {
  const auto options =
      torch::TensorOptions(torch::kByte).device(DefaultDevice());
  const torch::Tensor input = torch::randint(0, 5, {2, 3, 4}, options);
  const int rank = input.dim();
  for (int dim = -rank; dim < rank; ++dim) {
    const torch::Tensor expected = torch::all(input, dim, /*keepdim=*/false);
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor lazy_input = CopyToDevice(input, device);
      const torch::Tensor lazy_result =
          torch::all(lazy_input, dim, /*keepdim=*/false);
      EqualValues(expected, lazy_result);
    });
  }
}
1299 
// Checks `all` along every positive and negative dimension with the reduced
// dimension kept.
TEST_F(LazyOpsTest, TestAllDimKeep) {
  const auto options =
      torch::TensorOptions(torch::kByte).device(DefaultDevice());
  const torch::Tensor input = torch::randint(0, 5, {2, 3, 4}, options);
  const int rank = input.dim();
  for (int dim = -rank; dim < rank; ++dim) {
    const torch::Tensor expected = torch::all(input, dim, /*keepdim=*/true);
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor lazy_input = CopyToDevice(input, device);
      const torch::Tensor lazy_result =
          torch::all(lazy_input, dim, /*keepdim=*/true);
      EqualValues(expected, lazy_result);
    });
  }
}
1316 
// Compares `torch::amax` on lazy tensors against eager mode, reducing over
// every single dimension and every pair of distinct dimensions, with and
// without keepdim.
TEST_F(LazyOpsTest, TestAmax) {
  const torch::Tensor input = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const int rank = input.dim();
  // Maps a dimension in [-rank, rank) to its non-negative axis index.
  const auto axis_of = [rank](int dim) { return (dim + rank) % rank; };
  for (const bool keepdim : {false, true}) {
    // Single-dimension reductions.
    for (int dim = -rank; dim < rank; ++dim) {
      const torch::Tensor expected =
          torch::amax(input, {dim}, /*keepdim=*/keepdim);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_input = CopyToDevice(input, device);
        const torch::Tensor lazy_result =
            torch::amax(lazy_input, {dim}, /*keepdim=*/keepdim);
        AllClose(expected, lazy_result);
      });
    }
    // Two-dimension reductions.
    for (int dim1 = -rank; dim1 < rank; ++dim1) {
      for (int dim2 = -rank; dim2 < rank; ++dim2) {
        // Skip pairs that name the same axis, whether spelled identically or
        // as a positive/negative alias of each other.
        if (axis_of(dim1) == axis_of(dim2)) {
          continue;
        }
        const torch::Tensor expected =
            torch::amax(input, {dim1, dim2}, /*keepdim=*/keepdim);
        ForEachDevice([&](const torch::Device& device) {
          const torch::Tensor lazy_input = CopyToDevice(input, device);
          const torch::Tensor lazy_result =
              torch::amax(lazy_input, {dim1, dim2}, /*keepdim=*/keepdim);
          AllClose(expected, lazy_result);
        });
      }
    }
  }
  ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
  ExpectCounterChanged("xla::amax", GetIgnoredCounters());
}
1349 
// Compares `torch::amin` on lazy tensors against eager mode, reducing over
// every single dimension and every pair of distinct dimensions, with and
// without keepdim.
TEST_F(LazyOpsTest, TestAmin) {
  const torch::Tensor input = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const int rank = input.dim();
  // Maps a dimension in [-rank, rank) to its non-negative axis index.
  const auto axis_of = [rank](int dim) { return (dim + rank) % rank; };
  for (const bool keepdim : {false, true}) {
    // Single-dimension reductions.
    for (int dim = -rank; dim < rank; ++dim) {
      const torch::Tensor expected =
          torch::amin(input, {dim}, /*keepdim=*/keepdim);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_input = CopyToDevice(input, device);
        const torch::Tensor lazy_result =
            torch::amin(lazy_input, {dim}, /*keepdim=*/keepdim);
        AllClose(expected, lazy_result);
      });
    }
    // Two-dimension reductions.
    for (int dim1 = -rank; dim1 < rank; ++dim1) {
      for (int dim2 = -rank; dim2 < rank; ++dim2) {
        // Skip pairs that name the same axis, whether spelled identically or
        // as a positive/negative alias of each other.
        if (axis_of(dim1) == axis_of(dim2)) {
          continue;
        }
        const torch::Tensor expected =
            torch::amin(input, {dim1, dim2}, /*keepdim=*/keepdim);
        ForEachDevice([&](const torch::Device& device) {
          const torch::Tensor lazy_input = CopyToDevice(input, device);
          const torch::Tensor lazy_result =
              torch::amin(lazy_input, {dim1, dim2}, /*keepdim=*/keepdim);
          AllClose(expected, lazy_result);
        });
      }
    }
  }
  ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
  ExpectCounterChanged("xla::amin", GetIgnoredCounters());
}
1382 
// Checks the full-reduction `any` on lazy tensors for one floating-point and
// several integral dtypes.
TEST_F(LazyOpsTest, TestAny) {
  for (const torch::ScalarType scalar_type :
       {torch::kFloat,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const auto options =
        torch::TensorOptions(scalar_type).device(DefaultDevice());
    // Floating dtypes are sampled with rand, integral ones with randint.
    torch::Tensor input;
    if (isFloatingType(scalar_type)) {
      input = torch::rand({3, 4}, options);
    } else {
      input = torch::randint(100, {3, 4}, options);
    }
    const torch::Tensor expected = torch::any(input);
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor lazy_input = CopyToDevice(input, device);
      const torch::Tensor lazy_result = torch::any(lazy_input);
      EqualValues(expected, lazy_result);
    });
  }
}
1406 
// Checks `any` along every positive and negative dimension with the reduced
// dimension dropped.
TEST_F(LazyOpsTest, TestAnyDim) {
  const auto options =
      torch::TensorOptions(torch::kByte).device(DefaultDevice());
  const torch::Tensor input = torch::randint(0, 5, {2, 3, 4}, options);
  const int rank = input.dim();
  for (int dim = -rank; dim < rank; ++dim) {
    const torch::Tensor expected = torch::any(input, dim, /*keepdim=*/false);
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor lazy_input = CopyToDevice(input, device);
      const torch::Tensor lazy_result =
          torch::any(lazy_input, dim, /*keepdim=*/false);
      EqualValues(expected, lazy_result);
    });
  }
}
1423 
// Checks `any` along every positive and negative dimension with the reduced
// dimension kept.
TEST_F(LazyOpsTest, TestAnyDimKeep) {
  const auto options =
      torch::TensorOptions(torch::kByte).device(DefaultDevice());
  const torch::Tensor input = torch::randint(0, 5, {2, 3, 4}, options);
  const int rank = input.dim();
  for (int dim = -rank; dim < rank; ++dim) {
    const torch::Tensor expected = torch::any(input, dim, /*keepdim=*/true);
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor lazy_input = CopyToDevice(input, device);
      const torch::Tensor lazy_result =
          torch::any(lazy_input, dim, /*keepdim=*/true);
      EqualValues(expected, lazy_result);
    });
  }
}
1440 
// Checks the full-reduction `mean` on a lazy tensor: both the result shape and
// the value must match eager mode.
TEST_F(LazyOpsTest, TestMean) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  const torch::Tensor expected = torch::mean(input);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_input = CopyToDevice(input, device);
    const torch::Tensor lazy_result = torch::mean(lazy_input);
    ASSERT_EQ(expected.sizes(), lazy_result.sizes());
    AllClose(expected, lazy_result);
  });
}
1452 
// Checks the full-reduction `mean` of a float tensor when the result dtype is
// requested as double.
TEST_F(LazyOpsTest, TestMeanCast) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  const torch::Tensor expected = torch::mean(input, torch::kDouble);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_input = CopyToDevice(input, device);
    const torch::Tensor lazy_result = torch::mean(lazy_input, torch::kDouble);
    AllClose(expected, lazy_result);
  });
}
1463 
// Checks `mean` reduced over each single dimension, positive and negative.
TEST_F(LazyOpsTest, TestMeanInDim) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  const int rank = input.dim();
  for (int dim = -rank; dim < rank; ++dim) {
    const torch::Tensor expected = torch::mean(input, {dim});
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor lazy_input = CopyToDevice(input, device);
      const torch::Tensor lazy_result = torch::mean(lazy_input, {dim});
      AllClose(expected, lazy_result);
    });
  }
}
1477 
// Checks `mean` reduced over a pair of dimensions, given once with positive
// indices and once with the equivalent negative indices.
TEST_F(LazyOpsTest, TestMeanInDims) {
  const torch::Tensor a = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  // Bind by const reference: the dimension lists are only read, so copying the
  // inner vector on every iteration is unnecessary.
  for (const auto& dims :
       std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
    const torch::Tensor b = torch::mean(a, dims);
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor lazy_a = CopyToDevice(a, device);
      const torch::Tensor lazy_b = torch::mean(lazy_a, dims);
      AllClose(b, lazy_b);
    });
  }
}
1490 
// Checks `mean` reduced over a pair of dimensions with keepdim enabled and the
// result dtype requested as double.
TEST_F(LazyOpsTest, TestMeanInDimsKeepCast) {
  const torch::Tensor a = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  // Bind by const reference: the dimension lists are only read, so copying the
  // inner vector on every iteration is unnecessary.
  for (const auto& dims :
       std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
    const torch::Tensor b = torch::mean(a, dims, true, torch::kDouble);
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor lazy_a = CopyToDevice(a, device);
      const torch::Tensor lazy_b =
          torch::mean(lazy_a, dims, true, torch::kDouble);
      AllClose(b, lazy_b);
    });
  }
}
1503 
// Checks the out-variant `mean_out` reduced over each single dimension.
TEST_F(LazyOpsTest, TestMeanInDimOut) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  // The input is a cube, so the {4, 4} output buffer fits whichever dimension
  // is reduced.
  const torch::Tensor input = torch::rand({4, 4, 4}, options);
  const int rank = input.dim();
  for (int dim = -rank; dim < rank; ++dim) {
    torch::Tensor expected = torch::empty({4, 4}, options);
    torch::mean_out(expected, input, {dim});
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor lazy_input = CopyToDevice(input, device);
      torch::Tensor lazy_out = torch::empty({4, 4}, lazy_input.options());
      torch::mean_out(lazy_out, lazy_input, {dim});
      AllClose(expected, lazy_out);
    });
  }
}
1520 
// Checks the full-reduction `std` on a lazy tensor, both unbiased and biased.
TEST_F(LazyOpsTest, TestStd) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  for (const bool unbiased : {true, false}) {
    const torch::Tensor expected = torch::std(input, unbiased);
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor lazy_input = CopyToDevice(input, device);
      const torch::Tensor lazy_result = torch::std(lazy_input, unbiased);
      AllClose(expected, lazy_result);
    });
  }
}
1533 
// Checks `std` reduced over each single dimension for every combination of
// the unbiased and keepdim flags.
TEST_F(LazyOpsTest, TestStdInDim) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  const int rank = input.dim();
  for (const bool unbiased : {true, false}) {
    for (const bool keepdim : {true, false}) {
      for (int dim = -rank; dim < rank; ++dim) {
        const torch::Tensor expected =
            torch::std(input, {dim}, unbiased, keepdim);
        ForEachDevice([&](const torch::Device& device) {
          const torch::Tensor lazy_input = CopyToDevice(input, device);
          const torch::Tensor lazy_result =
              torch::std(lazy_input, {dim}, unbiased, keepdim);
          AllClose(expected, lazy_result);
        });
      }
    }
  }
}
1551 
// Checks the `correction` overload of `std` over two dimension lists, with
// corrections of 1, 2 and "not provided", and both keepdim settings.
TEST_F(LazyOpsTest, TestStdWithCorrection) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  std::optional<c10::Scalar> corrections[] = {1, 2, std::nullopt};
  // The same pair of axes, written with positive and with negative indices.
  const std::vector<std::vector<int64_t>> dim_lists{{0, 1}, {-3, -2}};
  for (const auto& correction : corrections) {
    for (const bool keepdim : {true, false}) {
      for (const auto& dims : dim_lists) {
        const torch::Tensor expected =
            torch::std(input, dims, correction, keepdim);
        ForEachDevice([&](const torch::Device& device) {
          const torch::Tensor lazy_input = CopyToDevice(input, device);
          const torch::Tensor lazy_result =
              torch::std(lazy_input, dims, correction, keepdim);
          AllClose(expected, lazy_result);
        });
      }
    }
  }
}
1571 
// torch::std_mean with an explicit `correction` argument (1, 2 or nullopt).
// Both elements of the returned tuple are compared (AllClose) against the
// values computed from the source tensor.
TEST_F(LazyOpsTest, TestStdMeanWithCorrection) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  const std::optional<c10::Scalar> corrections[] = {1, 2, std::nullopt};
  const std::vector<std::vector<int64_t>> dim_lists{{0, 1}, {-3, -2}};
  for (const auto& correction : corrections) {
    for (const bool keepdim : {true, false}) {
      for (const auto& dims : dim_lists) {
        const auto expected = torch::std_mean(input, dims, correction, keepdim);
        const torch::Tensor& expected_first = std::get<0>(expected);
        const torch::Tensor& expected_second = std::get<1>(expected);
        ForEachDevice([&](const torch::Device& device) {
          const torch::Tensor lazy_input = CopyToDevice(input, device);
          const auto actual =
              torch::std_mean(lazy_input, dims, correction, keepdim);
          AllClose(expected_first, std::get<0>(actual));
          AllClose(expected_second, std::get<1>(actual));
        });
      }
    }
  }
}
1592 
// Full-tensor torch::sum on a float tensor; the device result is compared
// (AllClose) with the result computed from the source tensor.
TEST_F(LazyOpsTest, TestSum) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  const torch::Tensor expected = torch::sum(input);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_input = CopyToDevice(input, device);
    const torch::Tensor actual = torch::sum(lazy_input);
    AllClose(expected, actual);
  });
}
1603 
// Full-tensor torch::sum on a float tensor with torch::kDouble passed as the
// dtype argument.
TEST_F(LazyOpsTest, TestSumCast) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  const torch::Tensor expected = torch::sum(input, torch::kDouble);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_input = CopyToDevice(input, device);
    const torch::Tensor actual = torch::sum(lazy_input, torch::kDouble);
    AllClose(expected, actual);
  });
}
1614 
// torch::sum over 256 ones stored as torch::kByte; the device result is
// compared with AllEqual against the result computed from the source tensor.
TEST_F(LazyOpsTest, TestSumU8) {
  const auto options =
      torch::TensorOptions(torch::kByte).device(DefaultDevice());
  const torch::Tensor input = torch::ones({256}, options);
  const torch::Tensor expected = torch::sum(input);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_input = CopyToDevice(input, device);
    const torch::Tensor actual = torch::sum(lazy_input);
    AllEqual(expected, actual);
  });
}
1625 
// torch::sum over a single dimension, for every dim in [-rank, rank).
TEST_F(LazyOpsTest, TestSumInDim) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  const int rank = input.dim();
  for (int dim = -rank; dim < rank; ++dim) {
    const torch::Tensor expected = torch::sum(input, {dim});
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor lazy_input = CopyToDevice(input, device);
      const torch::Tensor actual = torch::sum(lazy_input, {dim});
      AllClose(expected, actual);
    });
  }
}
1639 
// torch::sum over two dimensions, given once as positive and once as negative
// indices. The dimension lists are iterated by const reference so that no
// std::vector is copied on each iteration.
TEST_F(LazyOpsTest, TestSumInDims) {
  torch::Tensor a = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (const auto& dims :
       std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
    // Reference value computed from the source tensor.
    torch::Tensor b = torch::sum(a, dims);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_a = CopyToDevice(a, device);
      torch::Tensor lazy_b = torch::sum(lazy_a, dims);
      AllClose(b, lazy_b);
    });
  }
}
1652 
// torch::sum over two dimensions with keepdim=true. The dimension lists are
// iterated by const reference so that no std::vector is copied on each
// iteration.
TEST_F(LazyOpsTest, TestSumInDimsKeep) {
  torch::Tensor a = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (const auto& dims :
       std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
    // Reference value computed from the source tensor.
    torch::Tensor b = torch::sum(a, dims, /*keepdim=*/true);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_a = CopyToDevice(a, device);
      torch::Tensor lazy_b = torch::sum(lazy_a, dims, /*keepdim=*/true);
      AllClose(b, lazy_b);
    });
  }
}
1665 
// torch::sum over two dimensions with keepdim=true and torch::kDouble passed
// as the dtype argument. The dimension lists are iterated by const reference
// so that no std::vector is copied on each iteration.
TEST_F(LazyOpsTest, TestSumInDimsKeepCast) {
  torch::Tensor a = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (const auto& dims :
       std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
    // Reference value computed from the source tensor.
    torch::Tensor b = torch::sum(a, dims, /*keepdim=*/true, torch::kDouble);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_a = CopyToDevice(a, device);
      torch::Tensor lazy_b =
          torch::sum(lazy_a, dims, /*keepdim=*/true, torch::kDouble);
      AllClose(b, lazy_b);
    });
  }
}
1679 
// Full-tensor torch::var for both values of the `unbiased` flag.
TEST_F(LazyOpsTest, TestVar) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  for (const bool unbiased : {true, false}) {
    const torch::Tensor expected = torch::var(input, unbiased);
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor lazy_input = CopyToDevice(input, device);
      const torch::Tensor actual = torch::var(lazy_input, unbiased);
      AllClose(expected, actual);
    });
  }
}
1692 
// torch::var over two dimensions for every combination of keepDim and
// unbiased. The dimension lists are iterated by const reference so that no
// std::vector is copied on each iteration.
TEST_F(LazyOpsTest, TestVarWithDim) {
  torch::Tensor a = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (const auto& dims :
       std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
    for (bool keepDim : {true, false}) {
      for (bool unbiased : {true, false}) {
        // Reference value computed from the source tensor.
        torch::Tensor b = torch::var(a, dims, unbiased, keepDim);
        ForEachDevice([&](const torch::Device& device) {
          torch::Tensor lazy_a = CopyToDevice(a, device);
          torch::Tensor lazy_b = torch::var(lazy_a, dims, unbiased, keepDim);
          AllClose(b, lazy_b);
        });
      }
    }
  }
}
1709 
// torch::var with an explicit `correction` argument (1, 2 or nullopt) over
// two dimensions, with and without keepDim. After the comparisons the test
// runs the fixture's counter checks for "aten::.*" and "lazy::var".
TEST_F(LazyOpsTest, TestVarWithCorrection) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  const std::optional<c10::Scalar> corrections[] = {1, 2, std::nullopt};
  const std::vector<std::vector<int64_t>> dim_lists{{0, 1}, {-3, -2}};
  for (const auto& dims : dim_lists) {
    for (const bool keepDim : {true, false}) {
      for (const auto& correction : corrections) {
        const torch::Tensor expected =
            torch::var(input, dims, correction, keepDim);
        ForEachDevice([&](const torch::Device& device) {
          const torch::Tensor lazy_input = CopyToDevice(input, device);
          const torch::Tensor actual =
              torch::var(lazy_input, dims, correction, keepDim);
          AllClose(expected, actual);
        });
      }
    }
  }
  ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
  ExpectCounterChanged("lazy::var", GetIgnoredCounters());
}
1729 
// torch::var_mean with an explicit `correction` argument (1, 2 or nullopt).
// Both elements of the returned tuple are compared (AllClose) against the
// values computed from the source tensor.
TEST_F(LazyOpsTest, TestVarMeanWithCorrection) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  const std::optional<c10::Scalar> corrections[] = {1, 2, std::nullopt};
  const std::vector<std::vector<int64_t>> dim_lists{{0, 1}, {-3, -2}};
  for (const auto& dims : dim_lists) {
    for (const auto& correction : corrections) {
      for (const bool keepdim : {true, false}) {
        const auto expected = torch::var_mean(input, dims, correction, keepdim);
        const torch::Tensor& expected_first = std::get<0>(expected);
        const torch::Tensor& expected_second = std::get<1>(expected);
        ForEachDevice([&](const torch::Device& device) {
          const torch::Tensor lazy_input = CopyToDevice(input, device);
          const auto actual =
              torch::var_mean(lazy_input, dims, correction, keepdim);
          AllClose(expected_first, std::get<0>(actual));
          AllClose(expected_second, std::get<1>(actual));
        });
      }
    }
  }
}
1748 
// torch::max along a dimension, for every dim in [-rank, rank) with and
// without keepdim. Tuple element 0 is compared with AllClose and element 1
// with AllEqual.
TEST_F(LazyOpsTest, TestMaxInDim) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  const int rank = input.dim();
  for (int dim = -rank; dim < rank; ++dim) {
    for (const bool keepdim : {false, true}) {
      const auto expected = torch::max(input, dim, /*keepdim=*/keepdim);
      const torch::Tensor& expected_values = std::get<0>(expected);
      const torch::Tensor& expected_indices = std::get<1>(expected);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_input = CopyToDevice(input, device);
        const auto actual = torch::max(lazy_input, dim, /*keepdim=*/keepdim);
        AllClose(expected_values, std::get<0>(actual));
        AllEqual(expected_indices, std::get<1>(actual));
      });
    }
  }
}
1766 
// torch::min along a dimension, for every dim in [-rank, rank) with and
// without keepdim. Tuple element 0 is compared with AllClose and element 1
// with AllEqual.
TEST_F(LazyOpsTest, TestMinInDim) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  const int rank = input.dim();
  for (int dim = -rank; dim < rank; ++dim) {
    for (const bool keepdim : {false, true}) {
      const auto expected = torch::min(input, dim, /*keepdim=*/keepdim);
      const torch::Tensor& expected_values = std::get<0>(expected);
      const torch::Tensor& expected_indices = std::get<1>(expected);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_input = CopyToDevice(input, device);
        const auto actual = torch::min(lazy_input, dim, /*keepdim=*/keepdim);
        AllClose(expected_values, std::get<0>(actual));
        AllEqual(expected_indices, std::get<1>(actual));
      });
    }
  }
}
1784 
// torch::norm with default arguments on a 3-D float tensor.
TEST_F(LazyOpsTest, TestNorm) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  const torch::Tensor expected = torch::norm(input);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_input = CopyToDevice(input, device);
    const torch::Tensor actual = torch::norm(lazy_input);
    AllClose(expected, actual);
  });
}
1795 
// torch::norm with p = 2 over a single dimension (1 and -2), keepdim=false.
TEST_F(LazyOpsTest, TestNormInDim) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  for (int dim : {1, -2}) {
    const torch::Tensor expected =
        torch::norm(input, 2, {dim}, /*keepdim=*/false);
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor lazy_input = CopyToDevice(input, device);
      const torch::Tensor actual =
          torch::norm(lazy_input, 2, {dim}, /*keepdim=*/false);
      AllClose(expected, actual);
    });
  }
}
1808 
// torch::norm with p = 2 over two dimensions, keepdim=false. The dimension
// lists are iterated by const reference so that no std::vector is copied on
// each iteration.
TEST_F(LazyOpsTest, TestNormInDims) {
  torch::Tensor a = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (const auto& dims :
       std::vector<std::vector<int64_t>>{{1, 2}, {-2, -1}}) {
    // Reference value computed from the source tensor.
    torch::Tensor b = torch::norm(a, 2, dims, /*keepdim=*/false);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_a = CopyToDevice(a, device);
      torch::Tensor lazy_b = torch::norm(lazy_a, 2, dims, /*keepdim=*/false);
      AllClose(b, lazy_b);
    });
  }
}
1821 
// torch::norm with p = 2 over two dimensions, keepdim=true. The dimension
// lists are iterated by const reference so that no std::vector is copied on
// each iteration.
TEST_F(LazyOpsTest, TestNormInDimsKeep) {
  torch::Tensor a = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (const auto& dims :
       std::vector<std::vector<int64_t>>{{1, 2}, {-2, -1}}) {
    // Reference value computed from the source tensor.
    torch::Tensor b = torch::norm(a, 2, dims, /*keepdim=*/true);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_a = CopyToDevice(a, device);
      torch::Tensor lazy_b = torch::norm(lazy_a, 2, dims, /*keepdim=*/true);
      AllClose(b, lazy_b);
    });
  }
}
1834 
// at::normal with tensor-valued mean (all zeros) and std (all ones): the
// mean and standard deviation of the drawn sample must fall inside fixed
// bounds around 0 and 1.
TEST_F(LazyOpsTest, TestNormalTwoTensor) {
  const at::Tensor zero_mean = at::zeros({10, 10, 10}, at::dtype(at::kFloat));
  const at::Tensor unit_std = at::ones({10, 10, 10}, at::dtype(at::kFloat));
  ForEachDevice([&](const torch::Device& device) {
    const at::Tensor lazy_mean = CopyToDevice(zero_mean, device);
    const at::Tensor lazy_std = CopyToDevice(unit_std, device);
    const at::Tensor sample = at::normal(lazy_mean, lazy_std);
    const double res_mean = sample.mean().item().toDouble();
    const double res_std = sample.std().item().toDouble();
    EXPECT_GT(res_mean, -0.06);
    EXPECT_LT(res_mean, 0.06);
    EXPECT_GT(res_std, 0.94);
    EXPECT_LT(res_std, 1.06);
  });
}
1850 
// at::normal with a scalar mean (0) and a tensor-valued std (all ones): the
// sample mean and standard deviation must fall inside fixed bounds.
TEST_F(LazyOpsTest, TestNormalDoubleMean) {
  const at::Tensor unit_std = at::ones({10, 10, 10}, at::dtype(at::kFloat));
  ForEachDevice([&](const torch::Device& device) {
    const at::Tensor lazy_std = CopyToDevice(unit_std, device);
    const at::Tensor sample = at::normal(0, lazy_std);
    const double res_mean = sample.mean().item().toDouble();
    const double res_std = sample.std().item().toDouble();
    EXPECT_GT(res_mean, -0.06);
    EXPECT_LT(res_mean, 0.06);
    EXPECT_GT(res_std, 0.94);
    EXPECT_LT(res_std, 1.06);
  });
}
1864 
// at::normal with a tensor-valued mean (all zeros) and a scalar std (1): the
// sample mean and standard deviation must fall inside fixed bounds.
TEST_F(LazyOpsTest, TestNormalDoubleStd) {
  const at::Tensor zero_mean = at::zeros({10, 10, 10}, at::dtype(at::kFloat));
  ForEachDevice([&](const torch::Device& device) {
    const at::Tensor lazy_mean = CopyToDevice(zero_mean, device);
    const at::Tensor sample = at::normal(lazy_mean, 1);
    const double res_mean = sample.mean().item().toDouble();
    const double res_std = sample.std().item().toDouble();
    EXPECT_GT(res_mean, -0.06);
    EXPECT_LT(res_mean, 0.06);
    EXPECT_GT(res_std, 0.94);
    EXPECT_LT(res_std, 1.06);
  });
}
1878 
// In-place normal_(mean=0, std=1) on the device copy of a zero tensor: the
// resulting mean and standard deviation must fall inside fixed bounds.
TEST_F(LazyOpsTest, TestNormalInPlace) {
  const at::Tensor zeros = at::zeros({10, 10, 10}, at::dtype(at::kFloat));
  ForEachDevice([&](const torch::Device& device) {
    at::Tensor lazy_tensor = CopyToDevice(zeros, device);
    lazy_tensor.normal_(/*mean=*/0, /*std=*/1);
    const double res_mean = lazy_tensor.mean().item().toDouble();
    const double res_std = lazy_tensor.std().item().toDouble();
    EXPECT_GT(res_mean, -0.06);
    EXPECT_LT(res_mean, 0.06);
    EXPECT_GT(res_std, 0.94);
    EXPECT_LT(res_std, 1.06);
  });
}
1892 
// In-place uniform_(from=0, to=1) on the device copy of a zero tensor: the
// minimum and maximum of the result must lie within `eps` of [0, 1].
TEST_F(LazyOpsTest, TestUniformInPlace) {
  const double eps = 1e-3;
  const at::Tensor zeros = at::zeros({10, 10, 10}, at::dtype(at::kFloat));
  ForEachDevice([&](const torch::Device& device) {
    at::Tensor lazy_tensor = CopyToDevice(zeros, device);
    lazy_tensor.uniform_(/*from=*/0, /*to=*/1);
    const at::Tensor sampled = ToCpuTensor(lazy_tensor);
    const double res_min = sampled.min().item().toDouble();
    const double res_max = sampled.max().item().toDouble();
    EXPECT_GT(res_min, 0.0 - eps);
    EXPECT_LT(res_max, 1.0 + eps);
  });
}
1906 
// In-place random_(from=0, to=10) for several dtypes: the mean of the result
// must be within `eps` of 4.5, the minimum must be 0 and the maximum 9.
TEST_F(LazyOpsTest, TestRandomInPlace) {
  const double eps = 0.2;
  for (const auto dtype :
       {torch::kFloat,
        torch::kDouble,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const torch::Tensor zeros =
        torch::zeros({10, 10, 10}, torch::TensorOptions(dtype));
    const auto element_count = zeros.numel();
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_tensor = CopyToDevice(zeros, device);
      lazy_tensor.random_(/*from=*/0, /*to=*/10);
      const double res_mean =
          lazy_tensor.sum().item().toDouble() / element_count;
      const double res_min = lazy_tensor.min().item().toDouble();
      const double res_max = lazy_tensor.max().item().toDouble();
      EXPECT_GT(res_mean, 4.5 - eps);
      EXPECT_LT(res_mean, 4.5 + eps);
      EXPECT_EQ(res_min, 0.0);
      EXPECT_EQ(res_max, 9.0);
    });
  }
}
1931 
// In-place random_(to=10), with no `from` argument, for several dtypes: the
// mean of the result must be within `eps` of 4.5, the minimum must be 0 and
// the maximum 9.
TEST_F(LazyOpsTest, TestRandomInPlaceDefaultFrom) {
  const double eps = 0.2;
  for (const auto dtype :
       {torch::kFloat,
        torch::kDouble,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const torch::Tensor zeros =
        torch::zeros({10, 10, 10}, torch::TensorOptions(dtype));
    const auto element_count = zeros.numel();
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_tensor = CopyToDevice(zeros, device);
      lazy_tensor.random_(/*to=*/10);
      const double res_mean =
          lazy_tensor.sum().item().toDouble() / element_count;
      const double res_min = lazy_tensor.min().item().toDouble();
      const double res_max = lazy_tensor.max().item().toDouble();
      EXPECT_GT(res_mean, 4.5 - eps);
      EXPECT_LT(res_mean, 4.5 + eps);
      EXPECT_EQ(res_min, 0.0);
      EXPECT_EQ(res_max, 9.0);
    });
  }
}
1956 
// In-place random_() with no arguments, for several dtypes: every element of
// the result must differ from the all-zero input.
TEST_F(LazyOpsTest, TestRandomInPlaceDefault) {
  for (const auto dtype :
       {torch::kFloat,
        torch::kDouble,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const torch::Tensor input =
        torch::zeros({10}, torch::TensorOptions(dtype));
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_input = CopyToDevice(input, device);
      lazy_input.random_();
      const torch::Tensor output = ToCpuTensor(lazy_input);
      EXPECT_TRUE(torch::all(output.ne(input)).item<bool>());
    });
  }
}
1975 
// torch::norm with a non-integer exponent (p = 3.5) on a randn tensor.
TEST_F(LazyOpsTest, TestNormGeneral) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::randn({4, 3, 4}, options);
  const torch::Tensor expected = torch::norm(input, 3.5);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_input = CopyToDevice(input, device);
    const torch::Tensor actual = torch::norm(lazy_input, 3.5);
    AllClose(expected, actual);
  });
}
1986 
// Despite its name, this test calls torch::norm with p = 1 on a 3-D tensor;
// it does not call torch::nuclear_norm.
TEST_F(LazyOpsTest, TestNormNuclear) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  const torch::Tensor expected = torch::norm(input, 1);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_input = CopyToDevice(input, device);
    const torch::Tensor actual = torch::norm(lazy_input, 1);
    AllClose(expected, actual);
  });
}
1997 
// torch::frobenius_norm over a single dimension (1 and -2), keepdim=false.
TEST_F(LazyOpsTest, TestFrobeniusNormInDim) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3, 4}, options);
  for (int dim : {1, -2}) {
    const torch::Tensor expected =
        torch::frobenius_norm(input, {dim}, /*keepdim=*/false);
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor lazy_input = CopyToDevice(input, device);
      const torch::Tensor actual =
          torch::frobenius_norm(lazy_input, {dim}, /*keepdim=*/false);
      AllClose(expected, actual);
    });
  }
}
2011 
// torch::frobenius_norm over two dimensions, keepdim=false. The dimension
// lists are iterated by const reference so that no std::vector is copied on
// each iteration.
TEST_F(LazyOpsTest, TestFrobeniusNormInDims) {
  torch::Tensor a = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (const auto& dims :
       std::vector<std::vector<int64_t>>{{1, 2}, {-2, -1}}) {
    // Reference value computed from the source tensor.
    torch::Tensor b = torch::frobenius_norm(a, dims, /*keepdim=*/false);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_a = CopyToDevice(a, device);
      torch::Tensor lazy_b =
          torch::frobenius_norm(lazy_a, dims, /*keepdim=*/false);
      AllClose(b, lazy_b);
    });
  }
}
2025 
// torch::group_norm forward pass on a {20, 6, 10, 10} input with per-channel
// weight and bias, for group counts 3, 6 and 1 (cudnn_enabled=false).
TEST_F(LazyOpsTest, TestGroupNorm) {
  int num_channels = 6;
  const double eps = 1e-05;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input =
      torch::rand({20, num_channels, 10, 10}, options);
  const torch::Tensor weight = torch::rand({num_channels}, options);
  const torch::Tensor bias = torch::rand({num_channels}, options);
  for (int num_groups : {3, 6, 1}) {
    const torch::Tensor expected = torch::group_norm(
        input, num_groups, weight, bias, eps, /*cudnn_enabled=*/false);
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor lazy_input = CopyToDevice(input, device);
      const torch::Tensor lazy_weight = CopyToDevice(weight, device);
      const torch::Tensor lazy_bias = CopyToDevice(bias, device);
      const torch::Tensor actual = torch::group_norm(
          lazy_input,
          num_groups,
          lazy_weight,
          lazy_bias,
          eps,
          /*cudnn_enabled=*/false);
      AllClose(expected, actual, /*rtol=*/1e-3, /*atol=*/1e-5);
    });
  }
}
2061 
// Runs TestBackward (derivative_level=2) on torch::group_norm for group
// counts 3, 6 and 1, once with undefined weight/bias tensors and once with
// defined ones. All inputs are created with requires_grad(true).
TEST_F(LazyOpsTest, TestGroupNormBackward) {
  int num_channels = 6;
  const double eps = 1e-05;
  const auto grad_options = torch::TensorOptions(torch::kFloat)
                                .device(DefaultDevice())
                                .requires_grad(true);
  torch::Tensor input = torch::rand({2, num_channels, 5, 5}, grad_options);
  torch::Tensor weight = torch::rand({num_channels}, grad_options);
  torch::Tensor bias = torch::rand({num_channels}, grad_options);
  // Default-constructed tensor used in place of weight/bias.
  torch::Tensor undef;
  for (bool undef_weight : {true, false}) {
    const torch::Tensor& weight_arg = undef_weight ? undef : weight;
    const torch::Tensor& bias_arg = undef_weight ? undef : bias;
    for (int num_groups : {3, 6, 1}) {
      // Function under test: args are {input, weight, bias}.
      auto testfn =
          [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
        return torch::group_norm(
            /*input=*/args[0],
            num_groups,
            args[1],
            args[2],
            /*eps=*/eps,
            /*cudnn_enabled=*/false);
      };
      ForEachDevice([&](const torch::Device& device) {
        TestBackward(
            {input, weight_arg, bias_arg},
            device,
            testfn,
            /*rtol=*/1e-3,
            /*atol=*/1e-3,
            /*derivative_level=*/2);
      });
    }
  }
}
2105 
// torch::instance_norm forward pass with use_input_stats=true on a
// {5, 20, 10, 10} input, with per-channel weight/bias and running statistics
// initialised to zeros (mean) and ones (var); cudnn_enabled=false.
TEST_F(LazyOpsTest, TestInstanceNorm) {
  int batch = 5;
  int num_channels = 20;
  const double momentum = 0.1;
  const double eps = 1e-05;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input =
      torch::rand({batch, num_channels, 10, 10}, options);
  const torch::Tensor weight = torch::rand({num_channels}, options);
  const torch::Tensor bias = torch::rand({num_channels}, options);
  const torch::Tensor running_mean = torch::zeros({num_channels}, options);
  const torch::Tensor running_var = torch::ones({num_channels}, options);
  const torch::Tensor expected = torch::instance_norm(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      /*use_input_stats=*/true,
      momentum,
      eps,
      /*cudnn_enabled=*/false);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_input = CopyToDevice(input, device);
    const torch::Tensor lazy_weight = CopyToDevice(weight, device);
    const torch::Tensor lazy_bias = CopyToDevice(bias, device);
    const torch::Tensor lazy_running_mean = CopyToDevice(running_mean, device);
    const torch::Tensor lazy_running_var = CopyToDevice(running_var, device);
    const torch::Tensor actual = torch::instance_norm(
        lazy_input,
        lazy_weight,
        lazy_bias,
        lazy_running_mean,
        lazy_running_var,
        /*use_input_stats=*/true,
        momentum,
        eps,
        /*cudnn_enabled=*/false);
    AllClose(expected, actual, /*rtol=*/1e-3, /*atol=*/1e-5);
  });
}
2155 
// torch::layer_norm forward pass on a {20, 10, 10, 10} input, normalising
// over the last 2 or 3 dimensions, once with undefined weight/bias tensors
// and once with random ones (cudnn_enabled=false).
TEST_F(LazyOpsTest, TestLayerNorm) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({20, 10, 10, 10}, options);
  const double eps = 1e-05;
  // Default-constructed tensor used in place of weight/bias.
  torch::Tensor undef;
  for (bool undef_weight : {true, false}) {
    for (int64_t normalized_size : {2, 3}) {
      std::vector<int64_t> normalized_shape(normalized_size, 10);
      // weight and bias are generated on every iteration, even when the
      // undefined tensor is passed instead.
      torch::Tensor weight = torch::rand(normalized_shape, options);
      torch::Tensor bias = torch::rand(normalized_shape, options);
      const torch::Tensor& ref_weight = undef_weight ? undef : weight;
      const torch::Tensor& ref_bias = undef_weight ? undef : bias;
      const torch::Tensor expected = torch::layer_norm(
          input,
          normalized_shape,
          ref_weight,
          ref_bias,
          eps,
          /*cudnn_enabled=*/false);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_input = CopyToDevice(input, device);
        // Left default-constructed (undefined) when undef_weight is set.
        torch::Tensor lazy_weight;
        torch::Tensor lazy_bias;
        if (!undef_weight) {
          lazy_weight = CopyToDevice(weight, device);
          lazy_bias = CopyToDevice(bias, device);
        }
        const torch::Tensor actual = torch::layer_norm(
            lazy_input,
            normalized_shape,
            lazy_weight,
            lazy_bias,
            eps,
            /*cudnn_enabled=*/false);
        AllClose(expected, actual, /*rtol=*/1e-3, /*atol=*/1e-5);
      });
    }
  }
}
2196 
// Runs TestBackward (derivative_level=2) on torch::layer_norm for a
// {2, 3, 3, 3} input, normalising over the last 2 or 3 dimensions, once with
// undefined weight/bias tensors and once with defined ones.
TEST_F(LazyOpsTest, TestLayerNormBackward) {
  const auto grad_options = torch::TensorOptions(torch::kFloat)
                                .device(DefaultDevice())
                                .requires_grad(true);
  torch::Tensor input = torch::rand({2, 3, 3, 3}, grad_options);
  const double eps = 1e-05;
  // Default-constructed tensor used in place of weight/bias.
  torch::Tensor undef;
  for (bool undef_weight : {true, false}) {
    for (int64_t normalized_size : {2, 3}) {
      std::vector<int64_t> normalized_shape(normalized_size, 3);
      // Function under test: args are {input, weight, bias}.
      auto testfn =
          [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
        return torch::layer_norm(
            /*input=*/args[0],
            normalized_shape,
            args[1],
            args[2],
            /*eps=*/eps,
            /*cudnn_enabled=*/false);
      };
      torch::Tensor weight = torch::rand(normalized_shape, grad_options);
      torch::Tensor bias = torch::rand(normalized_shape, grad_options);
      const torch::Tensor& weight_arg = undef_weight ? undef : weight;
      const torch::Tensor& bias_arg = undef_weight ? undef : bias;
      ForEachDevice([&](const torch::Device& device) {
        TestBackward(
            {input, weight_arg, bias_arg},
            device,
            testfn,
            /*rtol=*/1e-3,
            /*atol=*/1e-4,
            /*derivative_level=*/2);
      });
    }
  }
}
2240 
// torch::nuclear_norm on a {4, 3} float matrix.
TEST_F(LazyOpsTest, TestNuclearNorm) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({4, 3}, options);
  const torch::Tensor expected = torch::nuclear_norm(input);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor lazy_input = CopyToDevice(input, device);
    const torch::Tensor actual = torch::nuclear_norm(lazy_input);
    AllClose(expected, actual);
  });
}
2251 
// torch::pairwise_distance for p in {1, 2, 3, 4}, with and without keepdim.
// The reference output does not depend on the device under test, so it is
// computed once per (keepdim, p) pair, outside the ForEachDevice callback.
TEST_F(LazyOpsTest, TestPairwiseDistance) {
  torch::Tensor x1 = torch::rand(
      {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor x2 = torch::rand(
      {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  double eps = 1e-6;
  for (bool keepdim : {false, true}) {
    for (double p : {1, 2, 3, 4}) {
      // Reference value computed from the source tensors.
      torch::Tensor output = torch::pairwise_distance(x1, x2, p, eps, keepdim);
      ForEachDevice([&](const torch::Device& device) {
        torch::Tensor lazy_x1 = CopyToDevice(x1, device);
        torch::Tensor lazy_x2 = CopyToDevice(x2, device);
        torch::Tensor lazy_output =
            torch::pairwise_distance(lazy_x1, lazy_x2, p, eps, keepdim);
        AllClose(output, lazy_output, /*rtol=*/1e-5, /*atol=*/1e-5);
      });
    }
  }
}
2272 
// torch::cosine_similarity along every dim in [-rank, rank). The reference
// output does not depend on the device under test, so it is computed once per
// dim, outside the ForEachDevice callback.
TEST_F(LazyOpsTest, TestCosineSimilarity) {
  torch::Tensor x1 = torch::rand(
      {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor x2 = torch::rand(
      {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  double eps = 1e-8;
  int rank = x1.dim();
  for (int dim = -rank; dim < rank; ++dim) {
    // Reference value computed from the source tensors.
    torch::Tensor output = torch::cosine_similarity(x1, x2, dim, eps);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_x1 = CopyToDevice(x1, device);
      torch::Tensor lazy_x2 = CopyToDevice(x2, device);
      torch::Tensor lazy_output =
          torch::cosine_similarity(lazy_x1, lazy_x2, dim, eps);
      AllClose(output, lazy_output);
    });
  }
}
2291 
// torch::cosine_embedding_loss for Mean and Sum reductions and margins 0 and
// 0.2. The reference output does not depend on the device under test, so it
// is computed once per (reduction, margin) pair, outside the ForEachDevice
// callback.
TEST_F(LazyOpsTest, TestCosineEmbeddingLoss) {
  torch::Tensor input1 = torch::rand(
      {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor input2 = torch::rand(
      {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor target = torch::rand(
      {4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (torch::Reduction::Reduction reduction :
       {torch::Reduction::Mean, torch::Reduction::Sum}) {
    for (double margin : {0., 0.2}) {
      // Reference value computed from the source tensors.
      torch::Tensor output = torch::cosine_embedding_loss(
          input1, input2, target, margin, reduction);
      ForEachDevice([&](const torch::Device& device) {
        torch::Tensor lazy_input1 = CopyToDevice(input1, device);
        torch::Tensor lazy_input2 = CopyToDevice(input2, device);
        torch::Tensor lazy_target = CopyToDevice(target, device);
        torch::Tensor lazy_output = torch::cosine_embedding_loss(
            lazy_input1, lazy_input2, lazy_target, margin, reduction);
        AllClose(output, lazy_output);
      });
    }
  }
}
2315 
// Compares torch::hinge_embedding_loss between eager and lazy execution for
// the Mean and Sum reductions and for margins 0 and 0.2.
TEST_F(LazyOpsTest, TestHingeEmbeddingLoss) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor samples = torch::rand({4, 3}, options);
  const torch::Tensor labels = torch::rand({4, 3}, options);
  for (const torch::Reduction::Reduction mode :
       {torch::Reduction::Mean, torch::Reduction::Sum}) {
    for (const double margin_value : {0., 0.2}) {
      const auto verify = [&](const torch::Device& dev) {
        const torch::Tensor expected =
            torch::hinge_embedding_loss(samples, labels, margin_value, mode);
        const torch::Tensor lazy_samples = CopyToDevice(samples, dev);
        const torch::Tensor lazy_labels = CopyToDevice(labels, dev);
        const torch::Tensor actual = torch::hinge_embedding_loss(
            lazy_samples, lazy_labels, margin_value, mode);
        AllClose(expected, actual);
      };
      ForEachDevice(verify);
    }
  }
}
2336 
// Compares torch::triplet_margin_loss between eager and lazy execution over
// the cross product of margin {0, 0.2}, p {1..4}, swap {false, true} and
// reduction {Mean, Sum}.
TEST_F(LazyOpsTest, TestTripletMarginLoss) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor anchor_t = torch::rand({4, 3}, options);
  // Positives are made non-negative, negatives non-positive.
  const torch::Tensor positive_t = torch::abs(torch::rand({4, 3}, options));
  const torch::Tensor negative_t =
      torch::neg(torch::abs(torch::rand({4, 3}, options)));
  const double epsilon = 1e-6;
  for (const double margin_value : {0., 0.2}) {
    for (const double norm_degree : {1., 2., 3., 4.}) {
      for (const bool use_swap : {false, true}) {
        for (const torch::Reduction::Reduction mode :
             {torch::Reduction::Mean, torch::Reduction::Sum}) {
          const auto verify = [&](const torch::Device& dev) {
            const torch::Tensor expected = torch::triplet_margin_loss(
                anchor_t,
                positive_t,
                negative_t,
                margin_value,
                norm_degree,
                epsilon,
                use_swap,
                mode);
            const torch::Tensor lazy_anchor = CopyToDevice(anchor_t, dev);
            const torch::Tensor lazy_positive = CopyToDevice(positive_t, dev);
            const torch::Tensor lazy_negative = CopyToDevice(negative_t, dev);
            const torch::Tensor actual = torch::triplet_margin_loss(
                lazy_anchor,
                lazy_positive,
                lazy_negative,
                margin_value,
                norm_degree,
                epsilon,
                use_swap,
                mode);
            AllClose(expected, actual);
          };
          ForEachDevice(verify);
        }
      }
    }
  }
}
2372 
// Compares torch::binary_cross_entropy between eager and lazy execution for
// the Mean, Sum and None reductions, both with a weight tensor and with an
// undefined weight.
TEST_F(LazyOpsTest, TestBinaryCrossEntropy) {
  const int num_rows = 10;
  const int num_classes = 5;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor predictions =
      torch::rand({num_rows, num_classes}, options);
  const torch::Tensor labels = torch::rand({num_rows, num_classes}, options);
  const torch::Tensor weights = torch::rand({num_rows, num_classes}, options);
  const torch::Tensor undefined_tensor;
  for (const torch::Reduction::Reduction mode :
       {torch::Reduction::Mean,
        torch::Reduction::Sum,
        torch::Reduction::None}) {
    for (const bool skip_weight : {false, true}) {
      const auto verify = [&](const torch::Device& dev) {
        const torch::Tensor& eager_weight =
            skip_weight ? undefined_tensor : weights;
        const torch::Tensor expected = torch::binary_cross_entropy(
            predictions, labels, eager_weight, mode);
        const torch::Tensor lazy_predictions = CopyToDevice(predictions, dev);
        const torch::Tensor lazy_labels = CopyToDevice(labels, dev);
        // A default-constructed tensor is undefined, i.e. "no weight".
        torch::Tensor lazy_weights;
        if (!skip_weight) {
          lazy_weights = CopyToDevice(weights, dev);
        }
        const torch::Tensor actual = torch::binary_cross_entropy(
            lazy_predictions, lazy_labels, lazy_weights, mode);
        AllClose(expected, actual, /*rtol=*/1e-4, /*atol=*/1e-5);
      };
      ForEachDevice(verify);
    }
  }
}
2405 
// Compares torch::margin_ranking_loss between eager and lazy execution for
// the Mean and Sum reductions and for margins 0 and 0.2.
TEST_F(LazyOpsTest, TestMarginRankingLoss) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor first = torch::rand({4, 3}, options);
  const torch::Tensor second = torch::rand({4, 3}, options);
  const torch::Tensor labels = torch::rand({4, 3}, options);
  for (const torch::Reduction::Reduction mode :
       {torch::Reduction::Mean, torch::Reduction::Sum}) {
    for (const double margin_value : {0., 0.2}) {
      const auto verify = [&](const torch::Device& dev) {
        const torch::Tensor expected = torch::margin_ranking_loss(
            first, second, labels, margin_value, mode);
        const torch::Tensor lazy_first = CopyToDevice(first, dev);
        const torch::Tensor lazy_second = CopyToDevice(second, dev);
        const torch::Tensor lazy_labels = CopyToDevice(labels, dev);
        const torch::Tensor actual = torch::margin_ranking_loss(
            lazy_first, lazy_second, lazy_labels, margin_value, mode);
        AllClose(expected, actual);
      };
      ForEachDevice(verify);
    }
  }
}
2429 
// Compares torch::binary_cross_entropy_with_logits between eager and lazy
// execution for the Mean and Sum reductions, with every combination of
// defined/undefined `weight` and `pos_weight`.
TEST_F(LazyOpsTest, TestBCEWithLogits) {
  int batch = 10;
  int classes = 5;
  torch::Tensor input = torch::rand(
      {batch, classes},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor target = torch::rand(
      {batch, classes},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor weight = torch::rand(
      {classes}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor pos_weight = torch::rand(
      {classes}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  // Default-constructed (undefined) tensor used to represent "no weight".
  torch::Tensor undef;
  for (torch::Reduction::Reduction reduction :
       {torch::Reduction::Mean, torch::Reduction::Sum}) {
    for (bool undef_weight : {false, true}) {
      for (bool undef_pos_weight : {false, true}) {
        ForEachDevice([&](const torch::Device& device) {
          torch::Tensor output = torch::binary_cross_entropy_with_logits(
              input,
              target,
              undef_weight ? undef : weight,
              undef_pos_weight ? undef : pos_weight,
              reduction);
          torch::Tensor lazy_input = CopyToDevice(input, device);
          torch::Tensor lazy_target = CopyToDevice(target, device);
          torch::Tensor lazy_weight =
              undef_weight ? undef : CopyToDevice(weight, device);
          torch::Tensor lazy_pos_weight =
              undef_pos_weight ? undef : CopyToDevice(pos_weight, device);
          torch::Tensor lazy_output = torch::binary_cross_entropy_with_logits(
              lazy_input, lazy_target, lazy_weight, lazy_pos_weight, reduction);
          // The comparison was previously missing, so the test could never
          // fail on a wrong result. Tolerances mirror TestBinaryCrossEntropy.
          AllClose(output, lazy_output, /*rtol=*/1e-4, /*atol=*/1e-5);
        });
      }
    }
  }
}
2468 
// Compares torch::kl_div between eager and lazy execution for both
// log_target settings and for the Mean and Sum reductions.
TEST_F(LazyOpsTest, TestKlDiv) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor samples = torch::rand({4, 3}, options);
  const torch::Tensor reference = torch::rand({4, 3}, options);
  for (const bool target_is_log : {true, false}) {
    for (const torch::Reduction::Reduction mode :
         {torch::Reduction::Mean, torch::Reduction::Sum}) {
      const auto verify = [&](const torch::Device& dev) {
        const torch::Tensor expected =
            torch::kl_div(samples, reference, mode, target_is_log);
        const torch::Tensor lazy_samples = CopyToDevice(samples, dev);
        const torch::Tensor lazy_reference = CopyToDevice(reference, dev);
        const torch::Tensor actual =
            torch::kl_div(lazy_samples, lazy_reference, mode, target_is_log);
        AllClose(expected, actual);
      };
      ForEachDevice(verify);
    }
  }
}
2489 
// Full reduction: torch::prod over all elements must agree between the eager
// tensor and its lazy-device copy.
TEST_F(LazyOpsTest, TestProd) {
  const torch::Tensor input = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor expected = torch::prod(input);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual = torch::prod(lazy_input);
    AllClose(expected, actual);
  };
  ForEachDevice(verify);
}
2500 
// Full-reduction torch::prod with an explicit kDouble result dtype, compared
// between eager and lazy execution.
TEST_F(LazyOpsTest, TestProdCast) {
  const torch::Tensor input = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor expected = torch::prod(input, torch::kDouble);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual = torch::prod(lazy_input, torch::kDouble);
    AllClose(expected, actual);
  };
  ForEachDevice(verify);
}
2511 
// torch::prod along a single dimension, for every dimension index in
// [-rank, rank), compared between eager and lazy execution.
TEST_F(LazyOpsTest, TestProdInDim) {
  const torch::Tensor input = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const int64_t ndim = input.dim();
  for (int64_t axis = -ndim; axis < ndim; ++axis) {
    const torch::Tensor expected = torch::prod(input, axis);
    const auto verify = [&](const torch::Device& dev) {
      const torch::Tensor lazy_input = CopyToDevice(input, dev);
      const torch::Tensor actual = torch::prod(lazy_input, axis);
      AllClose(expected, actual);
    };
    ForEachDevice(verify);
  }
}
2525 
// torch::prod along each dimension in [-rank, rank) with keepdim=true and a
// kDouble result dtype, compared between eager and lazy execution.
TEST_F(LazyOpsTest, TestProdInDimKeepCast) {
  const torch::Tensor input = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const int64_t ndim = input.dim();
  for (int64_t axis = -ndim; axis < ndim; ++axis) {
    const torch::Tensor expected =
        torch::prod(input, axis, /*keepdim=*/true, torch::kDouble);
    const auto verify = [&](const torch::Device& dev) {
      const torch::Tensor lazy_input = CopyToDevice(input, dev);
      const torch::Tensor actual =
          torch::prod(lazy_input, axis, /*keepdim=*/true, torch::kDouble);
      AllClose(expected, actual);
    };
    ForEachDevice(verify);
  }
}
2540 
// torch::prod along each dimension in [-rank, rank) with keepdim=true,
// compared between eager and lazy execution.
TEST_F(LazyOpsTest, TestProdInDimKeep) {
  const torch::Tensor input = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const int64_t ndim = input.dim();
  for (int64_t axis = -ndim; axis < ndim; ++axis) {
    const torch::Tensor expected = torch::prod(input, axis, /*keepdim=*/true);
    const auto verify = [&](const torch::Device& dev) {
      const torch::Tensor lazy_input = CopyToDevice(input, dev);
      const torch::Tensor actual =
          torch::prod(lazy_input, axis, /*keepdim=*/true);
      AllClose(expected, actual);
    };
    ForEachDevice(verify);
  }
}
2554 
// torch::cumsum along each dimension in [-rank, rank) on a float tensor,
// compared between eager and lazy execution.
TEST_F(LazyOpsTest, TestCumSum) {
  const torch::Tensor values = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const int64_t ndim = values.dim();
  for (int64_t axis = -ndim; axis < ndim; ++axis) {
    const torch::Tensor expected = torch::cumsum(values, axis);
    const auto verify = [&](const torch::Device& dev) {
      const torch::Tensor lazy_values = CopyToDevice(values, dev);
      const torch::Tensor actual = torch::cumsum(lazy_values, axis);
      AllClose(expected, actual);
    };
    ForEachDevice(verify);
  }
}
2568 
// torch::cumsum along each dimension in [-rank, rank) with a kDouble result
// dtype, compared between eager and lazy execution.
TEST_F(LazyOpsTest, TestCumSumCast) {
  const torch::Tensor values = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const int64_t ndim = values.dim();
  for (int64_t axis = -ndim; axis < ndim; ++axis) {
    const torch::Tensor expected = torch::cumsum(values, axis, torch::kDouble);
    const auto verify = [&](const torch::Device& dev) {
      const torch::Tensor lazy_values = CopyToDevice(values, dev);
      const torch::Tensor actual =
          torch::cumsum(lazy_values, axis, torch::kDouble);
      AllClose(expected, actual);
    };
    ForEachDevice(verify);
  }
}
2583 
// torch::cumsum along each dimension in [-rank, rank) on a kLong tensor of
// random integers below 1000; results are checked with AllEqual.
TEST_F(LazyOpsTest, TestCumSumLong) {
  const auto options =
      torch::TensorOptions(torch::kLong).device(DefaultDevice());
  const torch::Tensor values = torch::randint(1000, {4, 3, 4}, options);
  const int64_t ndim = values.dim();
  for (int64_t axis = -ndim; axis < ndim; ++axis) {
    const torch::Tensor expected = torch::cumsum(values, axis);
    const auto verify = [&](const torch::Device& dev) {
      const torch::Tensor lazy_values = CopyToDevice(values, dev);
      const torch::Tensor actual = torch::cumsum(lazy_values, axis);
      AllEqual(expected, actual);
    };
    ForEachDevice(verify);
  }
}
2599 
// torch::cumsum along each dimension in [-rank, rank) on a float tensor with
// a kLong result dtype; results are checked with AllEqual.
TEST_F(LazyOpsTest, TestCumSumCastLong) {
  const torch::Tensor values = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const int64_t ndim = values.dim();
  for (int64_t axis = -ndim; axis < ndim; ++axis) {
    const torch::Tensor expected = torch::cumsum(values, axis, torch::kLong);
    const auto verify = [&](const torch::Device& dev) {
      const torch::Tensor lazy_values = CopyToDevice(values, dev);
      const torch::Tensor actual =
          torch::cumsum(lazy_values, axis, torch::kLong);
      AllEqual(expected, actual);
    };
    ForEachDevice(verify);
  }
}
2613 
// torch::cumprod along each dimension in [-rank, rank) on a float tensor,
// compared between eager and lazy execution.
TEST_F(LazyOpsTest, TestCumProd) {
  const torch::Tensor values = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const int64_t ndim = values.dim();
  for (int64_t axis = -ndim; axis < ndim; ++axis) {
    const torch::Tensor expected = torch::cumprod(values, axis);
    const auto verify = [&](const torch::Device& dev) {
      const torch::Tensor lazy_values = CopyToDevice(values, dev);
      const torch::Tensor actual = torch::cumprod(lazy_values, axis);
      AllClose(expected, actual);
    };
    ForEachDevice(verify);
  }
}
2627 
// torch::cumprod along each dimension in [-rank, rank) with a kDouble result
// dtype, on random floats scaled by 10.
TEST_F(LazyOpsTest, TestCumProdCast) {
  const torch::Tensor unit_values = torch::rand(
      {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor values = torch::mul(unit_values, 10);
  const int64_t ndim = values.dim();
  for (int64_t axis = -ndim; axis < ndim; ++axis) {
    const torch::Tensor expected =
        torch::cumprod(values, axis, torch::kDouble);
    const auto verify = [&](const torch::Device& dev) {
      const torch::Tensor lazy_values = CopyToDevice(values, dev);
      const torch::Tensor actual =
          torch::cumprod(lazy_values, axis, torch::kDouble);
      AllClose(expected, actual);
    };
    ForEachDevice(verify);
  }
}
2645 
// torch::cumprod along each dimension in [-rank, rank) on a kLong tensor of
// random integers below 7; results are checked with AllEqual.
//
// This test previously called torch::cumsum, so cumprod on integer input was
// not covered despite the test's name.
TEST_F(LazyOpsTest, TestCumProdLong) {
  torch::Tensor input = torch::randint(
      7, {2, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  int rank = input.dim();
  for (int dim = -rank; dim < rank; ++dim) {
    torch::Tensor result = torch::cumprod(input, dim);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_input = CopyToDevice(input, device);
      torch::Tensor lazy_result = torch::cumprod(lazy_input, dim);
      AllEqual(result, lazy_result);
    });
  }
}
2659 
// torch::cumprod along each dimension in [-rank, rank) with a kLong result
// dtype, on random floats scaled by 7; results are checked with AllEqual.
//
// This test previously called torch::cumsum, so cumprod with a dtype cast to
// kLong was not covered despite the test's name.
TEST_F(LazyOpsTest, TestCumProdCastLong) {
  torch::Tensor input =
      torch::rand(
          {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
      7;
  int rank = input.dim();
  for (int dim = -rank; dim < rank; ++dim) {
    torch::Tensor result = torch::cumprod(input, dim, torch::kLong);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_input = CopyToDevice(input, device);
      torch::Tensor lazy_result =
          torch::cumprod(lazy_input, dim, torch::kLong);
      AllEqual(result, lazy_result);
    });
  }
}
2675 
// torch::argmin with no dimension (std::nullopt) and keepdim=false; eager
// and lazy index results are checked with AllEqual.
TEST_F(LazyOpsTest, TestArgMin) {
  const torch::Tensor input = torch::rand(
      {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor expected =
      torch::argmin(input, std::nullopt, /*keepdim=*/false);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual =
        torch::argmin(lazy_input, std::nullopt, /*keepdim=*/false);
    AllEqual(expected, actual);
  };
  ForEachDevice(verify);
}
2687 
// torch::argmin along dimensions 1 and -2 with keepdim=false; eager and lazy
// index results are checked with AllEqual.
TEST_F(LazyOpsTest, TestArgMinDim) {
  const torch::Tensor input = torch::rand(
      {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (const int axis : {1, -2}) {
    const torch::Tensor expected =
        torch::argmin(input, axis, /*keepdim=*/false);
    const auto verify = [&](const torch::Device& dev) {
      const torch::Tensor lazy_input = CopyToDevice(input, dev);
      const torch::Tensor actual =
          torch::argmin(lazy_input, axis, /*keepdim=*/false);
      AllEqual(expected, actual);
    };
    ForEachDevice(verify);
  }
}
2700 
// torch::argmin along dimensions 1 and -2 with keepdim=true; eager and lazy
// index results are checked with AllEqual.
TEST_F(LazyOpsTest, TestArgMinDimKeep) {
  const torch::Tensor input = torch::rand(
      {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (const int axis : {1, -2}) {
    const torch::Tensor expected =
        torch::argmin(input, axis, /*keepdim=*/true);
    const auto verify = [&](const torch::Device& dev) {
      const torch::Tensor lazy_input = CopyToDevice(input, dev);
      const torch::Tensor actual =
          torch::argmin(lazy_input, axis, /*keepdim=*/true);
      AllEqual(expected, actual);
    };
    ForEachDevice(verify);
  }
}
2713 
// torch::argmin on an all-ones tensor (every element ties); eager and lazy
// index results are checked with AllEqual.
TEST_F(LazyOpsTest, TestArgMinSameValue) {
  const torch::Tensor input = torch::ones(
      {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor expected = torch::argmin(input);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual = torch::argmin(lazy_input);
    AllEqual(expected, actual);
  };
  ForEachDevice(verify);
}
2724 
// torch::argmin along dimensions 1 and -2 with keepdim=false, exercised
// through the free-function entry point; checked with AllEqual.
TEST_F(LazyOpsTest, TestArgMinWrapper) {
  const torch::Tensor input = torch::rand(
      {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (const int axis : {1, -2}) {
    const torch::Tensor expected =
        torch::argmin(input, axis, /*keepdim=*/false);
    const auto verify = [&](const torch::Device& dev) {
      const torch::Tensor lazy_input = CopyToDevice(input, dev);
      const torch::Tensor actual =
          torch::argmin(lazy_input, axis, /*keepdim=*/false);
      AllEqual(expected, actual);
    };
    ForEachDevice(verify);
  }
}
2737 
// torch::argmax with no dimension (std::nullopt) and keepdim=false; eager
// and lazy index results are checked with AllEqual.
TEST_F(LazyOpsTest, TestArgMax) {
  const torch::Tensor input = torch::rand(
      {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor expected =
      torch::argmax(input, std::nullopt, /*keepdim=*/false);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual =
        torch::argmax(lazy_input, std::nullopt, /*keepdim=*/false);
    AllEqual(expected, actual);
  };
  ForEachDevice(verify);
}
2749 
// torch::argmax along dimensions 1 and -2 with keepdim=false; eager and lazy
// index results are checked with AllEqual.
TEST_F(LazyOpsTest, TestArgMaxDim) {
  const torch::Tensor input = torch::rand(
      {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (const int axis : {1, -2}) {
    const torch::Tensor expected =
        torch::argmax(input, axis, /*keepdim=*/false);
    const auto verify = [&](const torch::Device& dev) {
      const torch::Tensor lazy_input = CopyToDevice(input, dev);
      const torch::Tensor actual =
          torch::argmax(lazy_input, axis, /*keepdim=*/false);
      AllEqual(expected, actual);
    };
    ForEachDevice(verify);
  }
}
2762 
// torch::argmax along dimensions 1 and -2 with keepdim=true; eager and lazy
// index results are checked with AllEqual.
TEST_F(LazyOpsTest, TestArgMaxDimKeep) {
  const torch::Tensor input = torch::rand(
      {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (const int axis : {1, -2}) {
    const torch::Tensor expected =
        torch::argmax(input, axis, /*keepdim=*/true);
    const auto verify = [&](const torch::Device& dev) {
      const torch::Tensor lazy_input = CopyToDevice(input, dev);
      const torch::Tensor actual =
          torch::argmax(lazy_input, axis, /*keepdim=*/true);
      AllEqual(expected, actual);
    };
    ForEachDevice(verify);
  }
}
2775 
// torch::argmax on an all-ones tensor (every element ties), with no
// dimension and keepdim=false; checked with AllEqual.
TEST_F(LazyOpsTest, TestArgMaxSameValue) {
  const torch::Tensor input = torch::ones(
      {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor expected =
      torch::argmax(input, std::nullopt, /*keepdim=*/false);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual =
        torch::argmax(lazy_input, std::nullopt, /*keepdim=*/false);
    AllEqual(expected, actual);
  };
  ForEachDevice(verify);
}
2787 
// torch::argmax along dimensions 1 and -2 with keepdim=false, exercised
// through the free-function entry point; checked with AllEqual.
TEST_F(LazyOpsTest, TestArgMaxWrapper) {
  const torch::Tensor input = torch::rand(
      {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (const int axis : {1, -2}) {
    const torch::Tensor expected =
        torch::argmax(input, axis, /*keepdim=*/false);
    const auto verify = [&](const torch::Device& dev) {
      const torch::Tensor lazy_input = CopyToDevice(input, dev);
      const torch::Tensor actual =
          torch::argmax(lazy_input, axis, /*keepdim=*/false);
      AllEqual(expected, actual);
    };
    ForEachDevice(verify);
  }
}
2800 
// torch::asin on a lazy device must match the eager result within
// rtol=1e-3 / atol=1e-5.
TEST_F(LazyOpsTest, TestAsin) {
  constexpr double kRelTol = 1e-3;
  constexpr double kAbsTol = 1e-5;
  const torch::Tensor input = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor expected = torch::asin(input);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual = torch::asin(lazy_input);
    AllClose(expected, actual, kRelTol, kAbsTol);
  };
  ForEachDevice(verify);
}
2811 
// torch::asinh on a lazy device must match the eager result within
// rtol=1e-3 / atol=1e-5.
TEST_F(LazyOpsTest, TestAsinh) {
  constexpr double kRelTol = 1e-3;
  constexpr double kAbsTol = 1e-5;
  const torch::Tensor input = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor expected = torch::asinh(input);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual = torch::asinh(lazy_input);
    AllClose(expected, actual, kRelTol, kAbsTol);
  };
  ForEachDevice(verify);
}
2822 
// In-place torch::asinh_: both the mutated operand and the returned tensor
// must match between eager and lazy execution (rtol=1e-3 / atol=1e-5).
TEST_F(LazyOpsTest, TestAsinhInPlace) {
  constexpr double kRelTol = 1e-3;
  constexpr double kAbsTol = 1e-5;
  torch::Tensor eager_operand = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const auto verify = [&](const torch::Device& dev) {
    // Copy to the device before the eager operand is mutated.
    torch::Tensor lazy_operand = CopyToDevice(eager_operand, dev);
    torch::Tensor eager_result = torch::asinh_(eager_operand);
    torch::Tensor lazy_result = torch::asinh_(lazy_operand);
    AllClose(eager_operand, lazy_operand, kRelTol, kAbsTol);
    AllClose(eager_result, lazy_result, kRelTol, kAbsTol);
  };
  ForEachDevice(verify);
}
2834 
// torch::sin on a lazy device must match the eager result within
// rtol=1e-3 / atol=1e-5.
TEST_F(LazyOpsTest, TestSin) {
  constexpr double kRelTol = 1e-3;
  constexpr double kAbsTol = 1e-5;
  const torch::Tensor input = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor expected = torch::sin(input);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual = torch::sin(lazy_input);
    AllClose(expected, actual, kRelTol, kAbsTol);
  };
  ForEachDevice(verify);
}
2845 
// torch::sinh on a lazy device must match the eager result within
// rtol=1e-3 / atol=1e-5.
TEST_F(LazyOpsTest, TestSinh) {
  constexpr double kRelTol = 1e-3;
  constexpr double kAbsTol = 1e-5;
  const torch::Tensor input = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor expected = torch::sinh(input);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual = torch::sinh(lazy_input);
    AllClose(expected, actual, kRelTol, kAbsTol);
  };
  ForEachDevice(verify);
}
2856 
// torch::acos on a lazy device must match the eager result within
// rtol=1e-3 / atol=1e-5.
TEST_F(LazyOpsTest, TestAcos) {
  constexpr double kRelTol = 1e-3;
  constexpr double kAbsTol = 1e-5;
  const torch::Tensor input = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor expected = torch::acos(input);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual = torch::acos(lazy_input);
    AllClose(expected, actual, kRelTol, kAbsTol);
  };
  ForEachDevice(verify);
}
2867 
// torch::acosh on random floats scaled by 100; the lazy result must match
// the eager result within rtol=1e-3 / atol=1e-5.
TEST_F(LazyOpsTest, TestAcosh) {
  constexpr double kRelTol = 1e-3;
  constexpr double kAbsTol = 1e-5;
  const torch::Tensor unit_values = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor input = unit_values * 100;
  const torch::Tensor expected = torch::acosh(input);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual = torch::acosh(lazy_input);
    AllClose(expected, actual, kRelTol, kAbsTol);
  };
  ForEachDevice(verify);
}
2880 
// In-place torch::acosh_ on random floats scaled by 100: both the mutated
// operand and the returned tensor must match between eager and lazy
// execution (rtol=1e-3 / atol=1e-5).
TEST_F(LazyOpsTest, TestAcoshInPlace) {
  constexpr double kRelTol = 1e-3;
  constexpr double kAbsTol = 1e-5;
  const torch::Tensor unit_values = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor eager_operand = unit_values * 100;
  const auto verify = [&](const torch::Device& dev) {
    // Copy to the device before the eager operand is mutated.
    torch::Tensor lazy_operand = CopyToDevice(eager_operand, dev);
    torch::Tensor eager_result = torch::acosh_(eager_operand);
    torch::Tensor lazy_result = torch::acosh_(lazy_operand);
    AllClose(eager_operand, lazy_operand, kRelTol, kAbsTol);
    AllClose(eager_result, lazy_result, kRelTol, kAbsTol);
  };
  ForEachDevice(verify);
}
2894 
// torch::cos on a lazy device must match the eager result within
// rtol=1e-3 / atol=1e-5.
TEST_F(LazyOpsTest, TestCos) {
  constexpr double kRelTol = 1e-3;
  constexpr double kAbsTol = 1e-5;
  const torch::Tensor input = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor expected = torch::cos(input);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual = torch::cos(lazy_input);
    AllClose(expected, actual, kRelTol, kAbsTol);
  };
  ForEachDevice(verify);
}
2905 
// torch::cosh on a lazy device must match the eager result within
// rtol=1e-3 / atol=1e-5.
TEST_F(LazyOpsTest, TestCosh) {
  constexpr double kRelTol = 1e-3;
  constexpr double kAbsTol = 1e-5;
  const torch::Tensor input = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor expected = torch::cosh(input);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual = torch::cosh(lazy_input);
    AllClose(expected, actual, kRelTol, kAbsTol);
  };
  ForEachDevice(verify);
}
2916 
// torch::atan on a lazy device must match the eager result within
// rtol=1e-3 / atol=1e-5.
TEST_F(LazyOpsTest, TestAtan) {
  constexpr double kRelTol = 1e-3;
  constexpr double kAbsTol = 1e-5;
  const torch::Tensor input = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor expected = torch::atan(input);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual = torch::atan(lazy_input);
    AllClose(expected, actual, kRelTol, kAbsTol);
  };
  ForEachDevice(verify);
}
2927 
// torch::atanh on a lazy device must match the eager result within
// rtol=1e-3 / atol=1e-5.
TEST_F(LazyOpsTest, TestAtanh) {
  constexpr double kRelTol = 1e-3;
  constexpr double kAbsTol = 1e-5;
  const torch::Tensor input = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor expected = torch::atanh(input);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual = torch::atanh(lazy_input);
    AllClose(expected, actual, kRelTol, kAbsTol);
  };
  ForEachDevice(verify);
}
2938 
// In-place torch::atanh_: both the mutated operand and the returned tensor
// must match between eager and lazy execution (rtol=1e-3 / atol=1e-5).
TEST_F(LazyOpsTest, TestAtanhInPlace) {
  constexpr double kRelTol = 1e-3;
  constexpr double kAbsTol = 1e-5;
  torch::Tensor eager_operand = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const auto verify = [&](const torch::Device& dev) {
    // Copy to the device before the eager operand is mutated.
    torch::Tensor lazy_operand = CopyToDevice(eager_operand, dev);
    torch::Tensor eager_result = torch::atanh_(eager_operand);
    torch::Tensor lazy_result = torch::atanh_(lazy_operand);
    AllClose(eager_operand, lazy_operand, kRelTol, kAbsTol);
    AllClose(eager_result, lazy_result, kRelTol, kAbsTol);
  };
  ForEachDevice(verify);
}
2950 
// torch::atan2 on two randn tensors; the lazy result must match the eager
// result within rtol=1e-3 / atol=1e-5.
TEST_F(LazyOpsTest, TestAtan2) {
  constexpr double kRelTol = 1e-3;
  constexpr double kAbsTol = 1e-5;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor first = torch::randn({2, 2}, options);
  const torch::Tensor second = torch::randn({2, 2}, options);
  const torch::Tensor expected = torch::atan2(first, second);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_first = CopyToDevice(first, dev);
    const torch::Tensor lazy_second = CopyToDevice(second, dev);
    const torch::Tensor actual = torch::atan2(lazy_first, lazy_second);
    AllClose(expected, actual, kRelTol, kAbsTol);
  };
  ForEachDevice(verify);
}
2964 
// torch::tan on a lazy device must match the eager result within
// rtol=1e-3 / atol=1e-5.
TEST_F(LazyOpsTest, TestTan) {
  constexpr double kRelTol = 1e-3;
  constexpr double kAbsTol = 1e-5;
  const torch::Tensor input = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor expected = torch::tan(input);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual = torch::tan(lazy_input);
    AllClose(expected, actual, kRelTol, kAbsTol);
  };
  ForEachDevice(verify);
}
2975 
// torch::tanh on a lazy device must match the eager result within
// rtol=1e-3 / atol=1e-5.
TEST_F(LazyOpsTest, TestTanh) {
  constexpr double kRelTol = 1e-3;
  constexpr double kAbsTol = 1e-5;
  const torch::Tensor input = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor expected = torch::tanh(input);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual = torch::tanh(lazy_input);
    AllClose(expected, actual, kRelTol, kAbsTol);
  };
  ForEachDevice(verify);
}
2986 
// torch::clamp with both a lower (0.311) and an upper (0.409) scalar bound,
// compared between eager and lazy execution.
TEST_F(LazyOpsTest, TestClampMinMax) {
  const torch::Tensor input = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Scalar lower_bound(0.311);
  const torch::Scalar upper_bound(0.409);
  const torch::Tensor expected = torch::clamp(input, lower_bound, upper_bound);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual =
        torch::clamp(lazy_input, lower_bound, upper_bound);
    AllClose(expected, actual);
  };
  ForEachDevice(verify);
}
2999 
// torch::clamp with only a lower scalar bound (0.311; max is std::nullopt),
// compared between eager and lazy execution.
TEST_F(LazyOpsTest, TestClampMin) {
  const torch::Tensor input = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Scalar lower_bound(0.311);
  const torch::Tensor expected =
      torch::clamp(input, lower_bound, std::nullopt);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual =
        torch::clamp(lazy_input, lower_bound, std::nullopt);
    AllClose(expected, actual);
  };
  ForEachDevice(verify);
}
3011 
// torch::clamp with only an upper scalar bound (0.409; min is std::nullopt),
// compared between eager and lazy execution.
TEST_F(LazyOpsTest, TestClampMax) {
  const torch::Tensor input = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Scalar upper_bound(0.409);
  const torch::Tensor expected =
      torch::clamp(input, std::nullopt, upper_bound);
  const auto verify = [&](const torch::Device& dev) {
    const torch::Tensor lazy_input = CopyToDevice(input, dev);
    const torch::Tensor actual =
        torch::clamp(lazy_input, std::nullopt, upper_bound);
    AllClose(expected, actual);
  };
  ForEachDevice(verify);
}
3023 
// Exercises the dedicated torch::clamp_min entry point.
TEST_F(LazyOpsTest, TestClampMinExplicit) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({2, 2}, float_opts);
  const torch::Scalar lower(0.311);
  const torch::Tensor expected = torch::clamp_min(input, lower);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_input = CopyToDevice(input, device);
    const auto lazy_result = torch::clamp_min(lazy_input, lower);
    AllClose(expected, lazy_result);
  });
}
3035 
// Exercises the dedicated torch::clamp_max entry point.
TEST_F(LazyOpsTest, TestClampMaxExplicit) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({2, 2}, float_opts);
  const torch::Scalar upper(0.409);
  const torch::Tensor expected = torch::clamp_max(input, upper);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_input = CopyToDevice(input, device);
    const auto lazy_result = torch::clamp_max(lazy_input, upper);
    AllClose(expected, lazy_result);
  });
}
3047 
// In-place clamp_min_: both the mutated operand and the returned tensor are
// compared between the reference tensor and the per-device copy.
TEST_F(LazyOpsTest, TestClampMinExplicitInPlace) {
  torch::Tensor input = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Scalar lower(0.311);
  ForEachDevice([&](const torch::Device& device) {
    // The device copy is taken before the reference tensor is modified.
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor returned = torch::clamp_min_(input, lower);
    torch::Tensor lazy_returned = torch::clamp_min_(lazy_input, lower);
    AllClose(input, lazy_input);
    AllClose(returned, lazy_returned);
  });
}
3060 
// In-place clamp_max_: both the mutated operand and the returned tensor are
// compared between the reference tensor and the per-device copy.
TEST_F(LazyOpsTest, TestClampMaxExplicitInPlace) {
  torch::Tensor input = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Scalar upper(0.409);
  ForEachDevice([&](const torch::Device& device) {
    // The device copy is taken before the reference tensor is modified.
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor returned = torch::clamp_max_(input, upper);
    torch::Tensor lazy_returned = torch::clamp_max_(lazy_input, upper);
    AllClose(input, lazy_input);
    AllClose(returned, lazy_returned);
  });
}
3073 
// torch::ceil on normally distributed values scaled by 100.
TEST_F(LazyOpsTest, TestCeil) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::randn({2, 2}, float_opts) * 100.0;
  const torch::Tensor expected = torch::ceil(input);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_input = CopyToDevice(input, device);
    const auto lazy_result = torch::ceil(lazy_input);
    AllClose(expected, lazy_result);
  });
}
3086 
// torch::floor on normally distributed values scaled by 100.
TEST_F(LazyOpsTest, TestFloor) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::randn({2, 2}, float_opts) * 100.0;
  const torch::Tensor expected = torch::floor(input);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_input = CopyToDevice(input, device);
    const auto lazy_result = torch::floor(lazy_input);
    AllClose(expected, lazy_result);
  });
}
3099 
// torch::round on eight scaled random values concatenated with the two
// half-way inputs -0.5 and 0.5.
TEST_F(LazyOpsTest, TestRound) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor random_part = torch::randn({8}, float_opts) * 100.0;
  // Special case: 0.5, -0.5. lazy::Round impl rounds to -1/1 whereas
  // lazy::RoundToEven properly implements bankers rounding.
  const torch::Tensor half_way = torch::tensor({-0.5, 0.5}, float_opts);
  const torch::Tensor input = torch::cat({random_part, half_way}, 0);
  const torch::Tensor expected = torch::round(input);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_input = CopyToDevice(input, device);
    const auto lazy_result = torch::round(lazy_input);
    AllClose(expected, lazy_result);
  });
}
3118 
// torch::trunc on normally distributed values scaled by 100.
TEST_F(LazyOpsTest, TestTrunc) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::randn({2, 2}, float_opts) * 100.0;
  const torch::Tensor expected = torch::trunc(input);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_input = CopyToDevice(input, device);
    const auto lazy_result = torch::trunc(lazy_input);
    AllClose(expected, lazy_result);
  });
}
3131 
// torch::frac on normally distributed values scaled by 100.
TEST_F(LazyOpsTest, TestFrac) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::randn({2, 2}, float_opts) * 100.0;
  const torch::Tensor expected = torch::frac(input);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_input = CopyToDevice(input, device);
    const auto lazy_result = torch::frac(lazy_input);
    AllClose(expected, lazy_result);
  });
}
3144 
// torch::neg on a random float tensor.
TEST_F(LazyOpsTest, TestNeg) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({2, 2}, float_opts);
  const torch::Tensor expected = torch::neg(input);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_input = CopyToDevice(input, device);
    const auto lazy_result = torch::neg(lazy_input);
    AllClose(expected, lazy_result);
  });
}
3155 
// torch::bitwise_not across the integral scalar types listed below; results
// are compared with AllEqual.
TEST_F(LazyOpsTest, TestBitwiseNot) {
  const std::vector<torch::ScalarType> integral_types(
      {torch::kByte, torch::kChar, torch::kShort, torch::kInt, torch::kLong});

  ForEachDevice([&](const torch::Device& device) {
    for (const torch::ScalarType dtype : integral_types) {
      const torch::Tensor input =
          torch::randint(0, 63, {2, 2}, torch::TensorOptions(dtype));
      const torch::Tensor expected = torch::bitwise_not(input);
      const auto lazy_input = CopyToDevice(input, device);
      const auto lazy_result = torch::bitwise_not(lazy_input);
      AllEqual(expected, lazy_result);
    }
  });
}
3171 
// In-place bitwise_not_ across the integral scalar types listed below; the
// mutated tensors are compared with AllEqual.
TEST_F(LazyOpsTest, TestBitwiseNotInPlace) {
  const std::vector<torch::ScalarType> integral_types(
      {torch::kByte, torch::kChar, torch::kShort, torch::kInt, torch::kLong});

  ForEachDevice([&](const torch::Device& device) {
    for (const torch::ScalarType dtype : integral_types) {
      torch::Tensor reference =
          torch::randint(0, 63, {2, 2}, torch::TensorOptions(dtype));
      // Copy to the device before the reference tensor is modified.
      torch::Tensor lazy_tensor = CopyToDevice(reference, device);
      reference.bitwise_not_();
      lazy_tensor.bitwise_not_();
      AllEqual(reference, lazy_tensor);
    }
  });
}
3187 
// torch::sign on normally distributed values scaled by 100.
TEST_F(LazyOpsTest, TestSign) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::randn({2, 2}, float_opts) * 100.0;
  const torch::Tensor expected = torch::sign(input);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_input = CopyToDevice(input, device);
    const auto lazy_result = torch::sign(lazy_input);
    AllClose(expected, lazy_result);
  });
}
3200 
// torch::sign on a kByte tensor produced by randint(256, ...); compared with
// AllEqual.
TEST_F(LazyOpsTest, TestSignByte) {
  const auto byte_opts =
      torch::TensorOptions(torch::kByte).device(DefaultDevice());
  const torch::Tensor input = torch::randint(256, {2, 2}, byte_opts);
  const torch::Tensor expected = torch::sign(input);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_input = CopyToDevice(input, device);
    const auto lazy_result = torch::sign(lazy_input);
    AllEqual(expected, lazy_result);
  });
}
3211 
// torch::abs on a normally distributed float tensor.
TEST_F(LazyOpsTest, TestAbs) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::randn({2, 2}, float_opts);
  const torch::Tensor expected = torch::abs(input);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_input = CopyToDevice(input, device);
    const auto lazy_result = torch::abs(lazy_input);
    AllClose(expected, lazy_result);
  });
}
3222 
// torch::abs on a kByte tensor produced by randint(256, ...); compared with
// AllEqual.
TEST_F(LazyOpsTest, TestAbsByte) {
  const auto byte_opts =
      torch::TensorOptions(torch::kByte).device(DefaultDevice());
  const torch::Tensor input = torch::randint(256, {2, 2}, byte_opts);
  const torch::Tensor expected = torch::abs(input);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_input = CopyToDevice(input, device);
    const auto lazy_result = torch::abs(lazy_input);
    AllEqual(expected, lazy_result);
  });
}
3233 
// torch::empty_like: only the sizes of the two results are compared.
TEST_F(LazyOpsTest, TestEmptyLike) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({2, 2}, float_opts);
  const torch::Tensor reference = torch::empty_like(input);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_input = CopyToDevice(input, device);
    const auto lazy_empty = torch::empty_like(lazy_input);
    EXPECT_EQ(reference.sizes(), lazy_empty.sizes());
  });
}
3244 
// torch::empty_like with explicit TensorOptions: only the sizes of the two
// results are compared.
TEST_F(LazyOpsTest, TestEmptyLikeOptions) {
  const torch::Tensor input = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const torch::Tensor reference = torch::empty_like(
      input, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_input = CopyToDevice(input, device);
    const auto like_opts =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    const auto lazy_empty = torch::empty_like(lazy_input, like_opts);
    EXPECT_EQ(reference.sizes(), lazy_empty.sizes());
  });
}
3257 
// torch::empty created directly on each device: only the sizes are compared
// against a 2x2 reference tensor.
TEST_F(LazyOpsTest, TestEmpty) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor reference = torch::zeros({2, 2}, float_opts);
  ForEachDevice([&](const torch::Device& device) {
    const auto device_opts = torch::TensorOptions(torch::kFloat).device(device);
    const torch::Tensor lazy_empty = torch::empty({2, 2}, device_opts);
    EXPECT_EQ(reference.sizes(), lazy_empty.sizes());
  });
}
3267 
// In-place torch::zero_ applied to a tensor of ones and to its device copy.
TEST_F(LazyOpsTest, TestZeroInPlace) {
  torch::Tensor reference = torch::ones(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));

  ForEachDevice([&](const torch::Device& device) {
    // Copy to the device before the reference tensor is zeroed.
    torch::Tensor lazy_tensor = CopyToDevice(reference, device);
    auto& zeroed = torch::zero_(reference);
    auto& lazy_zeroed = torch::zero_(lazy_tensor);
    AllClose(zeroed, lazy_zeroed);
  });
}
3279 
// Checks torch::zeros_like: the result computed from the per-device copy is
// compared against the result computed from the reference tensor.
TEST_F(LazyOpsTest, TestZerosLike) {
  torch::Tensor a = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor b = torch::zeros_like(a);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_a = CopyToDevice(a, device);
    torch::Tensor lazy_b = torch::zeros_like(lazy_a);
    // Compare the op outputs. Comparing the inputs (a vs. lazy_a) would pass
    // no matter what zeros_like produced.
    AllClose(b, lazy_b);
  });
}
3290 
// Checks torch::zeros_like with explicit TensorOptions: the result computed
// from the per-device copy is compared against the reference result.
TEST_F(LazyOpsTest, TestZerosLikeOptions) {
  torch::Tensor a = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor b = torch::zeros_like(
      a, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_a = CopyToDevice(a, device);
    // NOTE(review): the options below target DefaultDevice() rather than
    // `device`; kept unchanged -- confirm whether that is intentional.
    torch::Tensor lazy_b = torch::zeros_like(
        lazy_a, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
    // Compare the op outputs. Comparing the inputs (a vs. lazy_a) would pass
    // no matter what zeros_like produced.
    AllClose(b, lazy_b);
  });
}
3303 
// torch::zeros created directly on each device vs. on the default device.
TEST_F(LazyOpsTest, TestZeros) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor expected = torch::zeros({2, 2}, float_opts);
  ForEachDevice([&](const torch::Device& device) {
    const auto device_opts = torch::TensorOptions(torch::kFloat).device(device);
    const torch::Tensor lazy_zeros = torch::zeros({2, 2}, device_opts);
    AllClose(expected, lazy_zeros);
  });
}
3313 
// torch::ones created directly on each device vs. on the default device.
TEST_F(LazyOpsTest, TestOnes) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor expected = torch::ones({2, 2}, float_opts);
  ForEachDevice([&](const torch::Device& device) {
    const auto device_opts = torch::TensorOptions(torch::kFloat).device(device);
    const torch::Tensor lazy_ones = torch::ones({2, 2}, device_opts);
    AllClose(expected, lazy_ones);
  });
}
3323 
// Checks torch::ones_like: the result computed from the per-device copy is
// compared against the result computed from the reference tensor.
TEST_F(LazyOpsTest, TestOnesLike) {
  torch::Tensor a = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor b = torch::ones_like(a);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_a = CopyToDevice(a, device);
    torch::Tensor lazy_b = torch::ones_like(lazy_a);
    // Compare the op outputs. Comparing the inputs (a vs. lazy_a) would pass
    // no matter what ones_like produced.
    AllClose(b, lazy_b);
  });
}
3334 
// Checks torch::ones_like with explicit TensorOptions: the result computed
// from the per-device copy is compared against the reference result.
TEST_F(LazyOpsTest, TestOnesLikeOptions) {
  torch::Tensor a = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor b = torch::ones_like(
      a, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_a = CopyToDevice(a, device);
    // NOTE(review): the options below target DefaultDevice() rather than
    // `device`; kept unchanged -- confirm whether that is intentional.
    torch::Tensor lazy_b = torch::ones_like(
        lazy_a, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
    // Compare the op outputs. Comparing the inputs (a vs. lazy_a) would pass
    // no matter what ones_like produced.
    AllClose(b, lazy_b);
  });
}
3347 
// torch::full with fill value 3.1165, created directly on each device vs. on
// the default device.
TEST_F(LazyOpsTest, TestFull) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor expected = torch::full({2, 2}, 3.1165, float_opts);
  ForEachDevice([&](const torch::Device& device) {
    const auto device_opts = torch::TensorOptions(torch::kFloat).device(device);
    const torch::Tensor lazy_full = torch::full({2, 2}, 3.1165, device_opts);
    AllClose(expected, lazy_full);
  });
}
3359 
// Checks torch::full_like with fill value 3.1165: the result computed from
// the per-device copy is compared against the reference result.
TEST_F(LazyOpsTest, TestFullLike) {
  torch::Tensor a = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor b = torch::full_like(a, 3.1165);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_a = CopyToDevice(a, device);
    torch::Tensor lazy_b = torch::full_like(lazy_a, 3.1165);
    // Compare the op outputs. Comparing the inputs (a vs. lazy_a) would pass
    // no matter what full_like produced.
    AllClose(b, lazy_b);
  });
}
3370 
// Checks torch::full_like with fill value 3.1165 and explicit TensorOptions:
// the result computed from the per-device copy is compared against the
// reference result.
TEST_F(LazyOpsTest, TestFullLikeOptions) {
  torch::Tensor a = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor b = torch::full_like(
      a, 3.1165, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_a = CopyToDevice(a, device);
    // NOTE(review): the options below target DefaultDevice() rather than
    // `device`; kept unchanged -- confirm whether that is intentional.
    torch::Tensor lazy_b = torch::full_like(
        lazy_a,
        3.1165,
        torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
    // Compare the op outputs. Comparing the inputs (a vs. lazy_a) would pass
    // no matter what full_like produced.
    AllClose(b, lazy_b);
  });
}
3385 
// torch::arange for an ascending and a descending {start, end, step} triple.
TEST_F(LazyOpsTest, TestARange) {
  const std::vector<std::vector<float>> range_specs{
      {0.0, 100.0, 0.5}, {0.0, -100.0, -0.5}};
  for (const auto& spec : range_specs) {
    const float start = spec[0];
    const float end = spec[1];
    const float step = spec[2];
    const torch::Tensor expected = torch::arange(
        start,
        end,
        step,
        torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor lazy_range = torch::arange(
          start,
          end,
          step,
          torch::TensorOptions(torch::kFloat).device(device));
      AllClose(expected, lazy_range);
    });
  }
}
3404 
// torch::arange_out writing into an existing tensor, for an ascending and a
// descending {start, end, step} triple.
TEST_F(LazyOpsTest, TestARangeOut) {
  torch::Tensor out = torch::randn(
      {4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  const std::vector<std::vector<float>> range_specs{
      {0.0, 100.0, 0.5}, {0.0, -100.0, -0.5}};
  for (const auto& spec : range_specs) {
    const float start = spec[0];
    const float end = spec[1];
    const float step = spec[2];
    torch::Tensor expected = torch::arange_out(out, start, end, step);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_out = CopyToDevice(out, device);
      torch::Tensor lazy_result = torch::arange_out(lazy_out, start, end, step);
      AllClose(expected, lazy_result);
    });
  }
}
3419 
// torch::_dim_arange along dimension 1 of a 2x2 template tensor.
TEST_F(LazyOpsTest, TestDimARange) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor shape_like = torch::rand({2, 2}, float_opts);
  const torch::Tensor expected = torch::_dim_arange(shape_like, 1);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_shape_like = CopyToDevice(shape_like, device);
    const auto lazy_result = torch::_dim_arange(lazy_shape_like, 1);
    AllClose(expected, lazy_result);
  });
}
3430 
// torch::bartlett_window of length 10, for periodic = false and true.
TEST_F(LazyOpsTest, TestBartlettWindow) {
  const int length = 10;
  for (const bool is_periodic : {false, true}) {
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor expected = torch::bartlett_window(
          length,
          is_periodic,
          torch::TensorOptions(torch::kFloat).device(DefaultDevice()));

      const torch::Tensor lazy_window = torch::bartlett_window(
          length,
          is_periodic,
          torch::TensorOptions(torch::kFloat).device(device));
      AllClose(expected, lazy_window, /*rtol=*/1e-5, /*atol=*/1e-7);
    });
  }
}
3448 
// torch::blackman_window of length 10, for periodic = false and true.
TEST_F(LazyOpsTest, TestBlackmanWindow) {
  const int length = 10;
  for (const bool is_periodic : {false, true}) {
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor expected = torch::blackman_window(
          length,
          is_periodic,
          torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
      const torch::Tensor lazy_window = torch::blackman_window(
          length,
          is_periodic,
          torch::TensorOptions(torch::kFloat).device(device));
      AllClose(expected, lazy_window, /*rtol=*/1e-5, /*atol=*/1e-7);
    });
  }
}
3465 
// torch::hamming_window of length 10 with alpha = 0.54 and beta = 0.46, for
// periodic = false and true.
TEST_F(LazyOpsTest, TestHammingWindow) {
  const double alpha_coeff = 0.54;
  const double beta_coeff = 0.46;
  const int length = 10;
  for (const bool is_periodic : {false, true}) {
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor expected = torch::hamming_window(
          length,
          is_periodic,
          alpha_coeff,
          beta_coeff,
          torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
      const torch::Tensor lazy_window = torch::hamming_window(
          length,
          is_periodic,
          alpha_coeff,
          beta_coeff,
          torch::TensorOptions(torch::kFloat).device(device));
      AllClose(expected, lazy_window);
    });
  }
}
3488 
// torch::hann_window of length 10, for periodic = false and true.
TEST_F(LazyOpsTest, TestHannWindow) {
  const int length = 10;
  for (const bool is_periodic : {false, true}) {
    ForEachDevice([&](const torch::Device& device) {
      const torch::Tensor expected = torch::hann_window(
          length,
          is_periodic,
          torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
      const torch::Tensor lazy_window = torch::hann_window(
          length,
          is_periodic,
          torch::TensorOptions(torch::kFloat).device(device));
      AllClose(expected, lazy_window);
    });
  }
}
3505 
// torch::log_sigmoid on values drawn uniformly from [-1, 1).
TEST_F(LazyOpsTest, TestLogSigmoid) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::empty({2, 2}, float_opts);
  input.uniform_(-1.0, 1.0);
  const torch::Tensor expected = torch::log_sigmoid(input);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_input = CopyToDevice(input, device);
    const auto lazy_result = torch::log_sigmoid(lazy_input);
    AllClose(expected, lazy_result, /*rtol=*/1e-3, /*atol=*/1e-5);
  });
}
3517 
// torch::log_sigmoid_forward: both elements of the returned tuple are
// compared between the reference and the per-device computation.
TEST_F(LazyOpsTest, TestLogSigmoidForward) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::empty({2, 2}, float_opts);
  input.uniform_(-1.0, 1.0);
  const auto expected_pair = torch::log_sigmoid_forward(input);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_input = CopyToDevice(input, device);
    const auto lazy_pair = torch::log_sigmoid_forward(lazy_input);
    const auto& expected_first = std::get<0>(expected_pair);
    const auto& lazy_first = std::get<0>(lazy_pair);
    AllClose(expected_first, lazy_first, /*rtol=*/1e-3, /*atol=*/1e-5);
    const auto& expected_second = std::get<1>(expected_pair);
    const auto& lazy_second = std::get<1>(lazy_pair);
    AllClose(expected_second, lazy_second, /*rtol=*/1e-3, /*atol=*/1e-5);
  });
}
3538 
// torch::logsumexp over two dimension lists ({0, 1} and {-3, -2}), each with
// keepdim = false and true.
TEST_F(LazyOpsTest, TestLogsumexp) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({3, 4, 3}, float_opts);
  const std::vector<std::vector<int64_t>> dim_lists{{0, 1}, {-3, -2}};
  for (const auto& reduce_dims : dim_lists) {
    for (const bool keep : {false, true}) {
      const torch::Tensor expected =
          torch::logsumexp(input, reduce_dims, keep);
      ForEachDevice([&](const torch::Device& device) {
        const auto lazy_input = CopyToDevice(input, device);
        const auto lazy_result =
            torch::logsumexp(lazy_input, reduce_dims, keep);
        AllClose(expected, lazy_result);
      });
    }
  }
}
3553 
// torch::silu on a random float tensor, followed by a counter expectation for
// "lazy::silu_out".
TEST_F(LazyOpsTest, TestSiLU) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({2, 2}, float_opts);
  const torch::Tensor expected = torch::silu(input);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_input = CopyToDevice(input, device);
    const auto lazy_result = torch::silu(lazy_input);
    AllClose(expected, lazy_result, /*rtol=*/1e-3, /*atol=*/1e-5);
  });
  ExpectCounterChanged("lazy::silu_out", GetIgnoredCounters());
}
3565 
// torch::sigmoid on a random float tensor.
TEST_F(LazyOpsTest, TestSigmoid) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({2, 2}, float_opts);
  const torch::Tensor expected = torch::sigmoid(input);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_input = CopyToDevice(input, device);
    const auto lazy_result = torch::sigmoid(lazy_input);
    AllClose(expected, lazy_result, /*rtol=*/1e-3, /*atol=*/1e-5);
  });
}
3576 
// torch::matmul with two rank-1 operands of shape {4}.
TEST_F(LazyOpsTest, TestMatmul_1x1) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor lhs = torch::rand({4}, float_opts);
  const torch::Tensor rhs = torch::rand({4}, float_opts);
  const torch::Tensor expected = torch::matmul(lhs, rhs);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_lhs = CopyToDevice(lhs, device);
    const auto lazy_rhs = CopyToDevice(rhs, device);
    const auto lazy_product = torch::matmul(lazy_lhs, lazy_rhs);
    AllClose(expected, lazy_product);
  });
}
3590 
// torch::matmul with a {3, 4} matrix on the left and a {4} vector on the
// right.
TEST_F(LazyOpsTest, TestMatmul_2x1) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor lhs = torch::rand({3, 4}, float_opts);
  const torch::Tensor rhs = torch::rand({4}, float_opts);
  const torch::Tensor expected = torch::matmul(lhs, rhs);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_lhs = CopyToDevice(lhs, device);
    const auto lazy_rhs = CopyToDevice(rhs, device);
    const auto lazy_product = torch::matmul(lazy_lhs, lazy_rhs);
    AllClose(expected, lazy_product);
  });
}
3604 
// torch::matmul with a {4} vector on the left and a {4, 3} matrix on the
// right.
TEST_F(LazyOpsTest, TestMatmul_1x2) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor lhs = torch::rand({4}, float_opts);
  const torch::Tensor rhs = torch::rand({4, 3}, float_opts);
  const torch::Tensor expected = torch::matmul(lhs, rhs);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_lhs = CopyToDevice(lhs, device);
    const auto lazy_rhs = CopyToDevice(rhs, device);
    const auto lazy_product = torch::matmul(lazy_lhs, lazy_rhs);
    AllClose(expected, lazy_product);
  });
}
3618 
// torch::matmul with a {2, 4} matrix times a {4, 3} matrix; compared with
// explicit rtol/atol.
TEST_F(LazyOpsTest, TestMatmul_2x2) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor lhs = torch::rand({2, 4}, float_opts);
  const torch::Tensor rhs = torch::rand({4, 3}, float_opts);
  const torch::Tensor expected = torch::matmul(lhs, rhs);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_lhs = CopyToDevice(lhs, device);
    const auto lazy_rhs = CopyToDevice(rhs, device);
    const auto lazy_product = torch::matmul(lazy_lhs, lazy_rhs);
    AllClose(expected, lazy_product, /*rtol=*/1e-3, /*atol=*/1e-4);
  });
}
3632 
// torch::matmul with operands of shape {4, 2, 3, 2, 4} and {2, 1, 4, 3}.
TEST_F(LazyOpsTest, TestMatmulBcast) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor lhs = torch::rand({4, 2, 3, 2, 4}, float_opts);
  const torch::Tensor rhs = torch::rand({2, 1, 4, 3}, float_opts);
  const torch::Tensor expected = torch::matmul(lhs, rhs);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_lhs = CopyToDevice(lhs, device);
    const auto lazy_rhs = CopyToDevice(rhs, device);
    const auto lazy_product = torch::matmul(lazy_lhs, lazy_rhs);
    AllClose(expected, lazy_product);
  });
}
3648 
// torch::dot of two random vectors of shape {4}.
TEST_F(LazyOpsTest, TestDot) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor lhs = torch::rand({4}, float_opts);
  const torch::Tensor rhs = torch::rand({4}, float_opts);
  const torch::Tensor expected = torch::dot(lhs, rhs);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_lhs = CopyToDevice(lhs, device);
    const auto lazy_rhs = CopyToDevice(rhs, device);
    const auto lazy_product = torch::dot(lazy_lhs, lazy_rhs);
    AllClose(expected, lazy_product);
  });
}
3662 
// torch::tensordot contracting dims {1, 2} of a {6, 4, 8} tensor with dims
// {0, 2} of a {4, 7, 8} tensor.
TEST_F(LazyOpsTest, TestTensorDot) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor lhs = torch::rand({6, 4, 8}, float_opts);
  const torch::Tensor rhs = torch::rand({4, 7, 8}, float_opts);
  const std::vector<int64_t> lhs_dims = {1, 2};
  const std::vector<int64_t> rhs_dims = {0, 2};
  const torch::Tensor expected =
      torch::tensordot(lhs, rhs, lhs_dims, rhs_dims);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_lhs = CopyToDevice(lhs, device);
    const auto lazy_rhs = CopyToDevice(rhs, device);
    const auto lazy_product =
        torch::tensordot(lazy_lhs, lazy_rhs, lhs_dims, rhs_dims);
    AllClose(expected, lazy_product);
  });
}
3678 
// torch::ger of a {4} vector and a {5} vector.
TEST_F(LazyOpsTest, TestGer) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor lhs = torch::rand({4}, float_opts);
  const torch::Tensor rhs = torch::rand({5}, float_opts);
  const torch::Tensor expected = torch::ger(lhs, rhs);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_lhs = CopyToDevice(lhs, device);
    const auto lazy_rhs = CopyToDevice(rhs, device);
    const auto lazy_product = torch::ger(lazy_lhs, lazy_rhs);
    AllClose(expected, lazy_product);
  });
}
3692 
// torch::mv of a {4, 3} matrix and a {3} vector.
TEST_F(LazyOpsTest, TestMv) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor matrix = torch::rand({4, 3}, float_opts);
  const torch::Tensor vec = torch::rand({3}, float_opts);
  const torch::Tensor expected = torch::mv(matrix, vec);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_matrix = CopyToDevice(matrix, device);
    const auto lazy_vec = CopyToDevice(vec, device);
    const auto lazy_product = torch::mv(lazy_matrix, lazy_vec);
    AllClose(expected, lazy_product);
  });
}
3706 
// torch::mv_out writing into a preallocated {4} tensor; the per-device output
// buffer is created with the options of the device-side vector.
TEST_F(LazyOpsTest, TestMvOut) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor matrix = torch::rand({4, 3}, float_opts);
  const torch::Tensor vec = torch::rand({3}, float_opts);
  torch::Tensor expected = torch::empty({4}, float_opts);
  torch::mv_out(expected, matrix, vec);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_matrix = CopyToDevice(matrix, device);
    const auto lazy_vec = CopyToDevice(vec, device);
    torch::Tensor lazy_out = torch::empty({4}, lazy_vec.options());
    torch::mv_out(lazy_out, lazy_matrix, lazy_vec);
    AllClose(expected, lazy_out);
  });
}
3723 
// torch::baddbmm with beta = 1.5 and alpha = 0.5 on a {3, 6, 5} self tensor
// and {3, 6, 4} / {3, 4, 5} batches.
TEST_F(LazyOpsTest, TestBatchAddBatchMatMul) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor self = torch::rand({3, 6, 5}, float_opts);
  const torch::Tensor batch1 = torch::rand({3, 6, 4}, float_opts);
  const torch::Tensor batch2 = torch::rand({3, 4, 5}, float_opts);
  const torch::Scalar alpha = 0.5;
  const torch::Scalar beta = 1.5;
  const torch::Tensor expected =
      torch::baddbmm(self, batch1, batch2, beta, alpha);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_self = CopyToDevice(self, device);
    const auto lazy_batch1 = CopyToDevice(batch1, device);
    const auto lazy_batch2 = CopyToDevice(batch2, device);
    const auto lazy_result =
        torch::baddbmm(lazy_self, lazy_batch1, lazy_batch2, beta, alpha);
    AllClose(expected, lazy_result, /*rtol=*/1e-3, /*atol=*/1e-4);
  });
}
3742 
// In-place baddbmm_ with beta = 1.5 and alpha = 0.5; both the returned tensor
// and the mutated self tensor are compared.
TEST_F(LazyOpsTest, TestBatchAddBatchMatMulInPlace) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor self = torch::rand({3, 6, 5}, float_opts);
  torch::Tensor batch1 = torch::rand({3, 6, 4}, float_opts);
  torch::Tensor batch2 = torch::rand({3, 4, 5}, float_opts);
  const torch::Scalar alpha = 0.5;
  const torch::Scalar beta = 1.5;
  ForEachDevice([&](const torch::Device& device) {
    // Device copies are taken before the reference self tensor is modified.
    torch::Tensor lazy_self = CopyToDevice(self, device);
    torch::Tensor lazy_batch1 = CopyToDevice(batch1, device);
    torch::Tensor lazy_batch2 = CopyToDevice(batch2, device);
    torch::Tensor returned = self.baddbmm_(batch1, batch2, beta, alpha);
    torch::Tensor lazy_returned =
        lazy_self.baddbmm_(lazy_batch1, lazy_batch2, beta, alpha);
    AllClose(returned, lazy_returned, /*rtol=*/1e-3, /*atol=*/1e-4);
    AllClose(self, lazy_self, /*rtol=*/1e-3, /*atol=*/1e-4);
  });
}
3762 
// torch::bmm of a {3, 6, 4} batch with a {3, 4, 5} batch.
TEST_F(LazyOpsTest, TestBatchMatMul) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor lhs = torch::rand({3, 6, 4}, float_opts);
  const torch::Tensor rhs = torch::rand({3, 4, 5}, float_opts);
  const torch::Tensor expected = torch::bmm(lhs, rhs);
  ForEachDevice([&](const torch::Device& device) {
    const auto lazy_lhs = CopyToDevice(lhs, device);
    const auto lazy_rhs = CopyToDevice(rhs, device);
    const auto lazy_product = torch::bmm(lazy_lhs, lazy_rhs);
    AllClose(expected, lazy_product, /*rtol=*/1e-3, /*atol=*/1e-4);
  });
}
3776 
// Compares torch::chain_matmul over four matrices ({5, 4}, {4, 6}, {6, 2},
// {2, 7}) on lazy tensors with the eager result.
TEST_F(LazyOpsTest, TestChainMatMul) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor m0 = torch::rand({5, 4}, float_opts);
  torch::Tensor m1 = torch::rand({4, 6}, float_opts);
  torch::Tensor m2 = torch::rand({6, 2}, float_opts);
  torch::Tensor m3 = torch::rand({2, 7}, float_opts);
  torch::Tensor expected = torch::chain_matmul({m0, m1, m2, m3});
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_m0 = CopyToDevice(m0, device);
    torch::Tensor lazy_m1 = CopyToDevice(m1, device);
    torch::Tensor lazy_m2 = CopyToDevice(m2, device);
    torch::Tensor lazy_m3 = CopyToDevice(m3, device);
    torch::Tensor lazy_product =
        torch::chain_matmul({lazy_m0, lazy_m1, lazy_m2, lazy_m3});
    AllClose(expected, lazy_product, /*rtol=*/1e-3, /*atol=*/1e-4);
  });
}
3797 
// Compares torch::linear on lazy tensors with the eager result, once without
// a bias and once with one.
TEST_F(LazyOpsTest, TestLinear) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 4}, float_opts);
  torch::Tensor weight = torch::rand({3, 4}, float_opts);
  torch::Tensor bias = torch::rand({3}, float_opts);
  torch::Tensor expected_no_bias = torch::linear(input, weight);
  torch::Tensor expected_biased = torch::linear(input, weight, bias);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor lazy_weight = CopyToDevice(weight, device);
    torch::Tensor lazy_bias = CopyToDevice(bias, device);
    torch::Tensor lazy_no_bias = torch::linear(lazy_input, lazy_weight);
    torch::Tensor lazy_biased =
        torch::linear(lazy_input, lazy_weight, lazy_bias);
    AllClose(expected_no_bias, lazy_no_bias, /*rtol=*/1e-2, /*atol=*/1e-4);
    AllClose(expected_biased, lazy_biased, /*rtol=*/1e-2, /*atol=*/1e-4);
  });
}
3822 
// Compares torch::pinverse of a {4, 6} matrix on a lazy tensor with the eager
// result.
TEST_F(LazyOpsTest, TestPinverse) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor matrix = torch::rand({4, 6}, float_opts);
  torch::Tensor expected = torch::pinverse(matrix);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_matrix = CopyToDevice(matrix, device);
    torch::Tensor lazy_pinv = torch::pinverse(lazy_matrix);
    AllClose(expected, lazy_pinv, /*rtol=*/1e-4);
  });
}
3833 
// Compares torch::einsum with the equation "i,j->ij" on two length-5 lazy
// vectors against the eager result.
TEST_F(LazyOpsTest, TestEinsumOuter) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor u = torch::rand({5}, float_opts);
  torch::Tensor v = torch::rand({5}, float_opts);
  const std::string equation = "i,j->ij";
  torch::Tensor expected = torch::einsum(equation, {u, v});
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_u = CopyToDevice(u, device);
    torch::Tensor lazy_v = CopyToDevice(v, device);
    torch::Tensor lazy_outer = torch::einsum(equation, {lazy_u, lazy_v});
    AllClose(expected, lazy_outer);
  });
}
3848 
// Runs TestBackward on torch::einsum with the equation "i,j->ij" for two
// length-5 inputs that require gradients.
TEST_F(LazyOpsTest, TestEinsumOuterBackward) {
  const auto grad_opts = torch::TensorOptions(torch::kFloat)
                             .device(DefaultDevice())
                             .requires_grad(true);
  torch::Tensor u = torch::rand({5}, grad_opts);
  torch::Tensor v = torch::rand({5}, grad_opts);
  const std::string equation = "i,j->ij";
  auto einsum_fn =
      [&](const std::vector<torch::Tensor>& operands) -> torch::Tensor {
    return torch::einsum(equation, operands);
  };
  ForEachDevice([&](const torch::Device& device) {
    TestBackward({u, v}, device, einsum_fn, /*rtol=*/1e-3, /*atol=*/1e-4);
  });
}
3868 
// Compares torch::einsum with the equation "bij,bjk->bik" on lazy tensors of
// shape {3, 2, 5} and {3, 5, 4} against the eager result.
TEST_F(LazyOpsTest, TestEinsumBatchMatMul) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor lhs = torch::rand({3, 2, 5}, float_opts);
  torch::Tensor rhs = torch::rand({3, 5, 4}, float_opts);
  const std::string equation = "bij,bjk->bik";
  torch::Tensor expected = torch::einsum(equation, {lhs, rhs});
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
    torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
    torch::Tensor lazy_product = torch::einsum(equation, {lazy_lhs, lazy_rhs});
    AllClose(expected, lazy_product);
  });
}
3883 
// Compares torch::einsum with the three-operand equation "bn,anm,bm->ba" on
// lazy tensors against the eager result.
TEST_F(LazyOpsTest, TestEinsumPyTorchLowerBilinear) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor middle = torch::rand({3, 5, 4}, float_opts);
  torch::Tensor left = torch::rand({2, 5}, float_opts);
  torch::Tensor right = torch::rand({2, 4}, float_opts);
  const std::string equation = "bn,anm,bm->ba";
  torch::Tensor expected = torch::einsum(equation, {left, middle, right});
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_left = CopyToDevice(left, device);
    torch::Tensor lazy_middle = CopyToDevice(middle, device);
    torch::Tensor lazy_right = CopyToDevice(right, device);
    torch::Tensor lazy_out =
        torch::einsum(equation, {lazy_left, lazy_middle, lazy_right});
    AllClose(expected, lazy_out);
  });
}
3901 
// Compares torch::einsum with the equation "ii->i" on a {3, 3} lazy tensor
// against the eager result.
TEST_F(LazyOpsTest, TestEinsumPyTorchLowerDiagonal) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor square = torch::rand({3, 3}, float_opts);
  const std::string equation = "ii->i";
  torch::Tensor expected = torch::einsum(equation, {square});
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_square = CopyToDevice(square, device);
    torch::Tensor lazy_diag = torch::einsum(equation, {lazy_square});
    AllClose(expected, lazy_diag);
  });
}
3913 
// Compares torch::einsum with the ellipsis equation "...ii->...i" on a
// {4, 3, 3} lazy tensor against the eager result.
TEST_F(LazyOpsTest, TestEinsumPyTorchLowerBatchDiagonal) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor batched = torch::rand({4, 3, 3}, float_opts);
  const std::string equation = "...ii->...i";
  torch::Tensor expected = torch::einsum(equation, {batched});
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_batched = CopyToDevice(batched, device);
    torch::Tensor lazy_diag = torch::einsum(equation, {lazy_batched});
    AllClose(expected, lazy_diag);
  });
}
3925 
// Compares torch::einsum with the ellipsis equation "...ij->...ji" on a
// {2, 3, 4, 5} lazy tensor against the eager result.
TEST_F(LazyOpsTest, TestEinsumPyTorchLowerBatchPermute) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor batched = torch::rand({2, 3, 4, 5}, float_opts);
  const std::string equation = "...ij->...ji";
  torch::Tensor expected = torch::einsum(equation, {batched});
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_batched = CopyToDevice(batched, device);
    torch::Tensor lazy_permuted = torch::einsum(equation, {lazy_batched});
    AllClose(expected, lazy_permuted);
  });
}
3938 
// Compares torch::einsum with the equation "ijj,k->ik" (a repeated index in
// the first operand) on lazy tensors against the eager result.
TEST_F(LazyOpsTest, TestEinsumPyTorchLowerRepeatedAxis) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor first = torch::rand({2, 3, 3}, float_opts);
  torch::Tensor second = torch::rand({4}, float_opts);
  const std::string equation = "ijj,k->ik";
  torch::Tensor expected = torch::einsum(equation, {first, second});
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_first = CopyToDevice(first, device);
    torch::Tensor lazy_second = CopyToDevice(second, device);
    torch::Tensor lazy_out = torch::einsum(equation, {lazy_first, lazy_second});
    AllClose(expected, lazy_out);
  });
}
3953 
// Compares torch::bilinear (with bias) on lazy tensors against the eager
// result; the eager reference is computed inside the per-device callback.
TEST_F(LazyOpsTest, TestBilinear) {
  const int batch = 16;
  const int features1 = 4;
  const int features2 = 6;
  const int features_out = 8;
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input1 = torch::rand({batch, features1}, float_opts);
  torch::Tensor input2 = torch::rand({batch, features2}, float_opts);
  torch::Tensor weight =
      torch::rand({features_out, features1, features2}, float_opts);
  torch::Tensor bias = torch::rand({features_out}, float_opts);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input1 = CopyToDevice(input1, device);
    torch::Tensor lazy_input2 = CopyToDevice(input2, device);
    torch::Tensor lazy_weight = CopyToDevice(weight, device);
    torch::Tensor lazy_bias = CopyToDevice(bias, device);
    torch::Tensor expected = torch::bilinear(input1, input2, weight, bias);
    torch::Tensor lazy_out =
        torch::bilinear(lazy_input1, lazy_input2, lazy_weight, lazy_bias);
    AllClose(expected, lazy_out);
  });
}
3982 
// Compares torch::upsample_nearest2d with an explicit {8, 8} output size on a
// {2, 2, 5, 5} lazy tensor against the eager result.
TEST_F(LazyOpsTest, TestUpsampleNearest2D) {
  const int batch = 2;
  const int in_h = 5;
  const int in_w = 5;
  const int out_h = 8;
  const int out_w = 8;
  const int channels = 2;
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor image = torch::rand({batch, channels, in_h, in_w}, float_opts);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_image = CopyToDevice(image, device);
    torch::Tensor expected = torch::upsample_nearest2d(image, {out_h, out_w});
    torch::Tensor lazy_out =
        torch::upsample_nearest2d(lazy_image, {out_h, out_w});
    AllClose(expected, lazy_out);
  });
}
4000 
// Runs TestBackward on torch::upsample_nearest2d with an explicit {8, 8}
// output size; a fresh random {2, 2, 5, 5} input is drawn per device.
TEST_F(LazyOpsTest, TestUpsampleNearest2DBackward) {
  const int batch = 2;
  const int in_h = 5;
  const int in_w = 5;
  const int out_h = 8;
  const int out_w = 8;
  const int channels = 2;
  auto upsample_fn =
      [&](const std::vector<torch::Tensor>& operands) -> torch::Tensor {
    return torch::upsample_nearest2d(operands[0], {out_h, out_w});
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto grad_opts = torch::TensorOptions(torch::kFloat)
                               .device(DefaultDevice())
                               .requires_grad(true);
    TestBackward(
        {torch::rand({batch, channels, in_h, in_w}, grad_opts)},
        device,
        upsample_fn);
  });
}
4022 
// Compares torch::upsample_nearest2d driven by scale factors (2.5, 3.4) and
// no explicit output size on a lazy tensor against the eager result.
TEST_F(LazyOpsTest, TestUpsampleNearest2DWithScale) {
  const int batch = 2;
  const int in_h = 5;
  const int in_w = 5;
  const int channels = 2;
  const double factor_h = 2.5;
  const double factor_w = 3.4;
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor image = torch::rand({batch, channels, in_h, in_w}, float_opts);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_image = CopyToDevice(image, device);
    torch::Tensor expected = torch::upsample_nearest2d(
        image, std::nullopt, at::ArrayRef<double>{factor_h, factor_w});
    torch::Tensor lazy_out = torch::upsample_nearest2d(
        lazy_image, std::nullopt, at::ArrayRef<double>{factor_h, factor_w});
    AllClose(expected, lazy_out);
  });
}
4042 
// Runs TestBackward on torch::upsample_nearest2d driven by scale factors
// (2.5, 3.4); a fresh random {2, 2, 5, 5} input is drawn per device.
TEST_F(LazyOpsTest, TestUpsampleNearest2DBackwardWithScale) {
  const int batch = 2;
  const int in_h = 5;
  const int in_w = 5;
  const int channels = 2;
  const double factor_h = 2.5;
  const double factor_w = 3.4;
  auto upsample_fn =
      [&](const std::vector<torch::Tensor>& operands) -> torch::Tensor {
    return torch::upsample_nearest2d(
        operands[0], std::nullopt, at::ArrayRef<double>{factor_h, factor_w});
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto grad_opts = torch::TensorOptions(torch::kFloat)
                               .device(DefaultDevice())
                               .requires_grad(true);
    TestBackward(
        {torch::rand({batch, channels, in_h, in_w}, grad_opts)},
        device,
        upsample_fn);
  });
}
4065 
// Compares torch::upsample_bilinear2d to an {8, 8} output on a lazy tensor
// against the eager result, for both align_corners settings. A new random
// input is drawn for each setting.
TEST_F(LazyOpsTest, TestUpsampleBilinear2D) {
  const int batch = 2;
  const int in_h = 5;
  const int in_w = 5;
  const int out_h = 8;
  const int out_w = 8;
  const int channels = 2;
  for (bool align_corners : {true, false}) {
    const auto float_opts =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    torch::Tensor image =
        torch::rand({batch, channels, in_h, in_w}, float_opts);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_image = CopyToDevice(image, device);
      torch::Tensor expected =
          torch::upsample_bilinear2d(image, {out_h, out_w}, align_corners);
      torch::Tensor lazy_out =
          torch::upsample_bilinear2d(lazy_image, {out_h, out_w}, align_corners);
      AllClose(expected, lazy_out);
    });
  }
}
4087 
// Runs TestBackward on torch::upsample_bilinear2d to an {8, 8} output for
// both align_corners settings; a fresh random input is drawn per device.
TEST_F(LazyOpsTest, TestUpsampleBilinear2DBackward) {
  const int batch = 2;
  const int in_h = 5;
  const int in_w = 5;
  const int out_h = 8;
  const int out_w = 8;
  const int channels = 2;
  for (bool align_corners : {true, false}) {
    auto upsample_fn =
        [&](const std::vector<torch::Tensor>& operands) -> torch::Tensor {
      return torch::upsample_bilinear2d(
          operands[0], {out_h, out_w}, align_corners);
    };
    ForEachDevice([&](const torch::Device& device) {
      const auto grad_opts = torch::TensorOptions(torch::kFloat)
                                 .device(DefaultDevice())
                                 .requires_grad(true);
      TestBackward(
          {torch::rand({batch, channels, in_h, in_w}, grad_opts)},
          device,
          upsample_fn);
    });
  }
}
4112 
// Compares torch::addcmul with the scalar 3.1165 on {2, 2} lazy tensors
// against the eager result.
TEST_F(LazyOpsTest, TestAddCMul) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor self = torch::rand({2, 2}, float_opts);
  torch::Tensor factor1 = torch::rand({2, 2}, float_opts);
  torch::Tensor factor2 = torch::rand({2, 2}, float_opts);
  torch::Tensor expected = torch::addcmul(self, factor1, factor2, 3.1165);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_self = CopyToDevice(self, device);
    torch::Tensor lazy_factor1 = CopyToDevice(factor1, device);
    torch::Tensor lazy_factor2 = CopyToDevice(factor2, device);
    torch::Tensor lazy_out =
        torch::addcmul(lazy_self, lazy_factor1, lazy_factor2, 3.1165);
    AllClose(expected, lazy_out);
  });
}
4129 
// Compares torch::addcdiv with the scalar 3.1165 on {2, 2} lazy tensors
// against the eager result. The divisor is built as abs(rand) + 1.0.
TEST_F(LazyOpsTest, TestAddCDiv) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor self = torch::rand({2, 2}, float_opts);
  torch::Tensor numerator = torch::rand({2, 2}, float_opts);
  torch::Tensor divisor = torch::abs(torch::rand({2, 2}, float_opts)) + 1.0;
  torch::Tensor expected = torch::addcdiv(self, numerator, divisor, 3.1165);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_self = CopyToDevice(self, device);
    torch::Tensor lazy_numerator = CopyToDevice(numerator, device);
    torch::Tensor lazy_divisor = CopyToDevice(divisor, device);
    torch::Tensor lazy_out =
        torch::addcdiv(lazy_self, lazy_numerator, lazy_divisor, 3.1165);
    AllClose(expected, lazy_out);
  });
}
4149 
// Compares torch::addcdiv on lazy tensors against the eager result when the
// operands have differing shapes ({1, 3}, {3, 1}, {1, 3}). The divisor is
// built as abs(rand) + 1.0.
TEST_F(LazyOpsTest, TestAddCDivWithBroadcast) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor self = torch::rand({1, 3}, float_opts);
  torch::Tensor numerator = torch::rand({3, 1}, float_opts);
  torch::Tensor divisor = torch::abs(torch::rand({1, 3}, float_opts)) + 1.0;
  torch::Tensor expected = torch::addcdiv(self, numerator, divisor, 3.1165);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_self = CopyToDevice(self, device);
    torch::Tensor lazy_numerator = CopyToDevice(numerator, device);
    torch::Tensor lazy_divisor = CopyToDevice(divisor, device);
    torch::Tensor lazy_out =
        torch::addcdiv(lazy_self, lazy_numerator, lazy_divisor, 3.1165);
    AllClose(expected, lazy_out);
  });
}
4169 
// Checks that torch::size reports the same extent for a lazy tensor as for
// the eager tensor, for every dimension index in [-rank, rank).
TEST_F(LazyOpsTest, TestSize) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager = torch::rand({2, 1, 4, 6}, float_opts);
  const int num_dims = eager.dim();
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy = CopyToDevice(eager, device);
    for (int axis = -num_dims; axis < num_dims; ++axis) {
      EXPECT_EQ(torch::size(eager, axis), torch::size(lazy, axis));
    }
  });
}
4182 
// Runs TestBackward on torch::select(input, dim, 0) for every dimension index
// in [-rank, rank) of a {14, 24, 8} input. A fresh random input that requires
// gradients is drawn for each device; unlike most tests in this file, the
// input options do not set an explicit device.
TEST_F(LazyOpsTest, TestSelect) {
  std::vector<int64_t> input_sizes = {14, 24, 8};
  int rank = input_sizes.size();
  for (int dim = -rank; dim < rank; ++dim) {
    auto testfn =
        [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
      return torch::select(inputs[0], dim, 0);
    };
    ForEachDevice([&](const torch::Device& device) {
      TestBackward(
          {torch::rand(
              input_sizes,
              torch::TensorOptions(torch::kFloat).requires_grad(true))},
          device,
          testfn);
    });
  }  // No trailing ';' here: it was an empty statement (-Wextra-semi-stmt).
}
4201 
// Draws torch::bernoulli with scalar probability 0.1 over 1000 lazy elements
// and checks that the fraction of ones lies strictly between 0.06 and 0.14.
TEST_F(LazyOpsTest, TestBernoulliScalarProb) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor zeros = torch::zeros(1000, float_opts);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_zeros = CopyToDevice(zeros, device);
    torch::Tensor lazy_sample = torch::bernoulli(lazy_zeros, 0.1);
    const double ones = lazy_sample.sum().item().toDouble();
    const double fraction = ones / zeros.numel();
    EXPECT_GT(fraction, 0.06);
    EXPECT_LT(fraction, 0.14);
  });
}
4213 
// Draws torch::bernoulli from a lazy tensor of 1000 probabilities equal to
// 0.1 and checks that the fraction of ones lies strictly between 0.06 and
// 0.14.
TEST_F(LazyOpsTest, TestBernoulliTensorProb) {
  std::vector<float> probabilities(1000, 0.1);
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor prob_tensor = torch::tensor(probabilities, float_opts);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_prob = CopyToDevice(prob_tensor, device);
    torch::Tensor lazy_sample = torch::bernoulli(lazy_prob);
    const double ones = lazy_sample.sum().item().toDouble();
    const double fraction = ones / prob_tensor.numel();
    EXPECT_GT(fraction, 0.06);
    EXPECT_LT(fraction, 0.14);
  });
}
4226 
// Applies the in-place bernoulli_ with scalar probability 0.1 to 1000 lazy
// elements and checks that the fraction of ones lies strictly between 0.06
// and 0.14.
TEST_F(LazyOpsTest, TestBernoulliScalarProbInPlace) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor zeros = torch::zeros(1000, float_opts);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_sample = CopyToDevice(zeros, device);
    lazy_sample.bernoulli_(0.1);
    const double ones = lazy_sample.sum().item().toDouble();
    const double fraction = ones / zeros.numel();
    EXPECT_GT(fraction, 0.06);
    EXPECT_LT(fraction, 0.14);
  });
}
4238 
// Applies the in-place bernoulli_ with a scalar-tensor probability of 0.1 to
// 1000 lazy elements and checks that the fraction of ones lies strictly
// between 0.06 and 0.14.
TEST_F(LazyOpsTest, TestBernoulliTensorProbInPlace) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor zeros = torch::zeros(1000, float_opts);
  torch::Tensor probability = torch::scalar_tensor(0.1, float_opts);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_sample = CopyToDevice(zeros, device);
    torch::Tensor lazy_probability = CopyToDevice(probability, device);
    lazy_sample.bernoulli_(lazy_probability);
    const double ones = lazy_sample.sum().item().toDouble();
    const double fraction = ones / zeros.numel();
    EXPECT_GT(fraction, 0.06);
    EXPECT_LT(fraction, 0.14);
  });
}
4253 
// Applies torch::dropout with p = 0.1 in training mode to a {17, 21} lazy
// tensor and checks that the fraction of non-zero outputs lies strictly
// between 0.86 and 0.94.
TEST_F(LazyOpsTest, TestDropout) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor source = torch::rand({17, 21}, float_opts);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_source = CopyToDevice(source, device);
    torch::Tensor lazy_dropped =
        torch::dropout(lazy_source, 0.1, /*train=*/true);
    const double kept = lazy_dropped.cpu().ne(0.0f).sum().item().toDouble();
    const double kept_fraction = kept / source.numel();
    EXPECT_GT(kept_fraction, 0.86);
    EXPECT_LT(kept_fraction, 0.94);
  });
}
4267 
// Applies the in-place torch::dropout_ with p = 0.1 in training mode to a
// {17, 21} lazy tensor and checks that the fraction of non-zero entries lies
// strictly between 0.85 and 0.94.
TEST_F(LazyOpsTest, TestDropoutInPlace) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor source = torch::rand({17, 21}, float_opts);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_target = CopyToDevice(source, device);
    torch::dropout_(lazy_target, 0.1, /*train=*/true);
    const double kept = lazy_target.cpu().ne(0.0f).sum().item().toDouble();
    const double kept_fraction = kept / source.numel();
    EXPECT_GT(kept_fraction, 0.85);
    EXPECT_LT(kept_fraction, 0.94);
  });
}
4281 
// Creates torch::randperm(5) directly on the lazy device, copies it to the
// CPU and checks that the 5 values pass torch::lazy::IsPermutation.
TEST_F(LazyOpsTest, TestRandperm) {
  const unsigned count = 5;
  const auto lazy_long_opts =
      torch::TensorOptions(torch::kLong).device(torch::kLazy);
  torch::Tensor lazy_perm = torch::randperm(count, lazy_long_opts);
  torch::Tensor cpu_perm = CopyToDevice(lazy_perm, torch::kCPU);
  const int64_t* first = cpu_perm.data_ptr<int64_t>();
  std::vector<int64_t> perm_values(first, first + count);
  EXPECT_TRUE(
      perm_values.size() == count && torch::lazy::IsPermutation(perm_values));
}
4292 
// Compares torch::slice along dim 1 (start 0, end 16, step 1) of a
// {32, 24, 16} lazy tensor against the eager result.
TEST_F(LazyOpsTest, TestSlice) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor source = torch::rand({32, 24, 16}, float_opts);
  torch::Tensor expected = torch::slice(source, 1, 0, 16, 1);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_source = CopyToDevice(source, device);
    torch::Tensor lazy_slice = torch::slice(lazy_source, 1, 0, 16, 1);
    AllClose(expected, lazy_slice);
  });
}
4304 
// Compares torch::take on a {4, 4} lazy tensor, using 5 random long indices
// drawn with an upper bound of 16, against the eager result.
TEST_F(LazyOpsTest, TestTake) {
  torch::Tensor source = torch::rand(
      {4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor indices = torch::randint(
      16, {5}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  torch::Tensor expected = torch::take(source, indices);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_source = CopyToDevice(source, device);
    torch::Tensor lazy_indices = CopyToDevice(indices, device);
    torch::Tensor lazy_taken = torch::take(lazy_source, lazy_indices);
    AllClose(expected, lazy_taken);
  });
}
4318 
// Runs TestBackward on torch::take; the {4, 4} float input (requiring
// gradients) and the 5 random long indices are drawn afresh per device.
TEST_F(LazyOpsTest, TestTakeBackward) {
  auto take_fn =
      [&](const std::vector<torch::Tensor>& operands) -> torch::Tensor {
    return torch::take(operands[0], operands[1]);
  };
  ForEachDevice([&](const torch::Device& device) {
    TestBackward(
        {torch::rand(
             {4, 4},
             torch::TensorOptions(torch::kFloat)
                 .device(DefaultDevice())
                 .requires_grad(true)),
         torch::randint(
             16,
             {5},
             torch::TensorOptions(torch::kLong).device(DefaultDevice()))},
        device,
        take_fn);
  });
}
4338 
// Compares torch::stack of three {2, 4, 3} lazy tensors against the eager
// result for every insertion dimension in [-(rank + 1), rank + 1).
TEST_F(LazyOpsTest, TestStack) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor t0 = torch::rand({2, 4, 3}, float_opts);
  torch::Tensor t1 = torch::rand({2, 4, 3}, float_opts);
  torch::Tensor t2 = torch::rand({2, 4, 3}, float_opts);
  // The stacked result has one more dimension than each input.
  const int out_rank = t0.dim() + 1;
  for (int axis = -out_rank; axis < out_rank; ++axis) {
    torch::Tensor expected = torch::stack({t0, t1, t2}, axis);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_t0 = CopyToDevice(t0, device);
      torch::Tensor lazy_t1 = CopyToDevice(t1, device);
      torch::Tensor lazy_t2 = CopyToDevice(t2, device);
      torch::Tensor lazy_stacked =
          torch::stack({lazy_t0, lazy_t1, lazy_t2}, axis);
      AllClose(expected, lazy_stacked);
    });
  }
}
4358 
// Compares torch::cat of tensors shaped {2, 1, 3}, {2, 2, 3} and {2, 3, 3}
// along dim 1 and dim -2 on lazy tensors against the eager result, also
// checking that sizes and dtype agree.
TEST_F(LazyOpsTest, TestCat) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor t0 = torch::rand({2, 1, 3}, float_opts);
  torch::Tensor t1 = torch::rand({2, 2, 3}, float_opts);
  torch::Tensor t2 = torch::rand({2, 3, 3}, float_opts);
  for (int axis : {1, -2}) {
    torch::Tensor expected = torch::cat({t0, t1, t2}, axis);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_t0 = CopyToDevice(t0, device);
      torch::Tensor lazy_t1 = CopyToDevice(t1, device);
      torch::Tensor lazy_t2 = CopyToDevice(t2, device);
      torch::Tensor lazy_joined = torch::cat({lazy_t0, lazy_t1, lazy_t2}, axis);
      EXPECT_TRUE(
          expected.sizes() == lazy_joined.sizes() &&
          expected.dtype() == lazy_joined.dtype());
      AllClose(expected, lazy_joined);
    });
  }
}
4378 
// Compares torch::unbind of a {4, 3, 7} lazy tensor against the eager result
// for every dimension index in [-rank, rank), slice by slice.
TEST_F(LazyOpsTest, TestUnbind) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor source = torch::rand({4, 3, 7}, float_opts);
  const int num_dims = source.dim();
  for (int axis = -num_dims; axis < num_dims; ++axis) {
    std::vector<torch::Tensor> expected = torch::unbind(source, axis);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_source = CopyToDevice(source, device);
      std::vector<torch::Tensor> lazy_slices = torch::unbind(lazy_source, axis);
      // Stop this callback early on a count mismatch so the indexed
      // comparison below never reads out of range.
      ASSERT_EQ(expected.size(), lazy_slices.size());
      for (size_t idx = 0; idx < expected.size(); ++idx) {
        AllClose(expected[idx], lazy_slices[idx]);
      }
    });
  }
}
4395 
// Compares Tensor::repeat on lazy tensors against the eager result for every
// combination of the repeat patterns {4, 2}, {4, 2, 3} and the input shapes
// {3}, {2, 4}.
TEST_F(LazyOpsTest, TestRepeat) {
  const std::vector<std::vector<int64_t>> repeat_patterns = {
      {4, 2}, {4, 2, 3}};
  const std::vector<std::vector<int64_t>> input_shapes = {{3}, {2, 4}};
  for (const auto& pattern : repeat_patterns) {
    for (const auto& shape : input_shapes) {
      const auto float_opts =
          torch::TensorOptions(torch::kFloat).device(DefaultDevice());
      torch::Tensor source = torch::rand(shape, float_opts);
      torch::Tensor expected = source.repeat(pattern);
      ForEachDevice([&](const torch::Device& device) {
        torch::Tensor lazy_source = CopyToDevice(source, device);
        torch::Tensor lazy_repeated = lazy_source.repeat(pattern);
        AllClose(expected, lazy_repeated);
      });
    }
  }
}
4413 
// Compares torch::gather along dim 1 of a {3, 3} lazy tensor against the
// eager result, with sparse_grad both off and on. The index tensor holds
// (row + col) % 3 at each position.
TEST_F(LazyOpsTest, TestGather) {
  torch::Tensor source = torch::rand(
      {3, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor indices = torch::empty(
      {3, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  for (int row = 0; row < 3; row++) {
    for (int col = 0; col < 3; col++) {
      indices[row][col] = (row + col) % 3;
    }
  }
  for (bool sparse_grad : {false, true}) {
    torch::Tensor expected = torch::gather(source, 1, indices, sparse_grad);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_source = CopyToDevice(source, device);
      torch::Tensor lazy_indices = CopyToDevice(indices, device);
      torch::Tensor lazy_gathered =
          torch::gather(lazy_source, 1, lazy_indices, sparse_grad);
      AllClose(expected, lazy_gathered);
    });
  }
}
4434 
// Compares torch::scatter with a tensor source on {3, 5} lazy tensors against
// the eager result along dim 0 and dim 1. The index tensor is refilled for
// each dim with (row + col) modulo the index tensor's extent along that dim.
TEST_F(LazyOpsTest, TestScatter) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor base = torch::rand({3, 5}, float_opts);
  torch::Tensor updates = torch::rand({3, 5}, float_opts);
  torch::Tensor indices = torch::empty(
      {3, 5}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  for (int axis = 0; axis < 2; ++axis) {
    for (int row = 0; row < 3; row++) {
      for (int col = 0; col < 5; col++) {
        indices[row][col] = (row + col) % indices.sizes()[axis];
      }
    }
    torch::Tensor expected = torch::scatter(base, axis, indices, updates);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_base = CopyToDevice(base, device);
      torch::Tensor lazy_updates = CopyToDevice(updates, device);
      torch::Tensor lazy_indices = CopyToDevice(indices, device);
      torch::Tensor lazy_scattered =
          torch::scatter(lazy_base, axis, lazy_indices, lazy_updates);
      AllClose(expected, lazy_scattered);
    });
  }
}
4458 
// Compares torch::scatter on rank-1 lazy tensors against the eager result:
// two source values are written to positions 1 and 3 of a length-5 tensor.
TEST_F(LazyOpsTest, TestScatterR1) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor base = torch::rand({5}, float_opts);
  torch::Tensor updates = torch::rand({2}, float_opts);
  torch::Tensor indices = torch::empty(
      {2}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  indices[0] = 1;
  indices[1] = 3;
  torch::Tensor expected = torch::scatter(base, 0, indices, updates);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_base = CopyToDevice(base, device);
    torch::Tensor lazy_updates = CopyToDevice(updates, device);
    torch::Tensor lazy_indices = CopyToDevice(indices, device);
    torch::Tensor lazy_scattered =
        torch::scatter(lazy_base, 0, lazy_indices, lazy_updates);
    AllClose(expected, lazy_scattered);
  });
}
4477 
// Compares torch::scatter along dim 1 on rank-3 lazy tensors against the
// eager result. The {3, 4, 2} index tensor holds (i + j + k) % 4 and the
// destination is {3, 5, 2}.
TEST_F(LazyOpsTest, TestScatterR3) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor base = torch::rand({3, 5, 2}, float_opts);
  torch::Tensor updates = torch::rand({3, 4, 2}, float_opts);
  torch::Tensor indices = torch::empty(
      {3, 4, 2}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  for (int plane = 0; plane < 3; plane++) {
    for (int row = 0; row < 4; row++) {
      for (int col = 0; col < 2; col++) {
        indices[plane][row][col] = (plane + row + col) % 4;
      }
    }
  }
  torch::Tensor expected = torch::scatter(base, 1, indices, updates);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_base = CopyToDevice(base, device);
    torch::Tensor lazy_updates = CopyToDevice(updates, device);
    torch::Tensor lazy_indices = CopyToDevice(indices, device);
    torch::Tensor lazy_scattered =
        torch::scatter(lazy_base, 1, lazy_indices, lazy_updates);
    AllClose(expected, lazy_scattered);
  });
}
4501 
// Scatter where the source ({8, 8}) is larger than the input and index
// ({4, 4}); exercised along dim 0 and dim 1.
TEST_F(LazyOpsTest, TestScatterBiggerSource) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({4, 4}, float_opts);
  torch::Tensor src = torch::rand({8, 8}, float_opts);
  torch::Tensor index = torch::empty(
      {4, 4}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  for (int row = 0; row < 4; ++row) {
    for (int col = 0; col < 4; ++col) {
      index[row][col] = (row + col) % 4;
    }
  }

  for (int dim : {0, 1}) {
    torch::Tensor expected = torch::scatter(input, dim, index, src);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_input = CopyToDevice(input, device);
      torch::Tensor lazy_src = CopyToDevice(src, device);
      torch::Tensor lazy_index = CopyToDevice(index, device);
      torch::Tensor lazy_result =
          torch::scatter(lazy_input, dim, lazy_index, lazy_src);
      AllClose(expected, lazy_result);
    });
  }
}
4525 
// Scatter of a scalar value (1.0f) into a {4, 4} input, along dim 0 and
// dim 1. The scalar itself is passed unchanged to both scatter calls.
TEST_F(LazyOpsTest, TestScatterScalar) {
  torch::Tensor input = torch::rand(
      {4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Scalar value = 1.0f;
  torch::Tensor index = torch::empty(
      {4, 4}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  for (int row = 0; row < 4; ++row) {
    for (int col = 0; col < 4; ++col) {
      index[row][col] = (row + col) % 4;
    }
  }

  for (int dim : {0, 1}) {
    torch::Tensor expected = torch::scatter(input, dim, index, value);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_input = CopyToDevice(input, device);
      torch::Tensor lazy_index = CopyToDevice(index, device);
      torch::Tensor lazy_result =
          torch::scatter(lazy_input, dim, lazy_index, value);
      AllClose(expected, lazy_result);
    });
  }
}
4547 
// Scatter with the "add" reduction on {3, 5} tensors, along dim 0 and dim 1.
// The index tensor is refilled for each dim so that every entry is smaller
// than the size of that dim. Finishes with the fixture's counter checks.
TEST_F(LazyOpsTest, TestScatterReduceAdd) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({3, 5}, float_opts);
  torch::Tensor src = torch::rand({3, 5}, float_opts);
  torch::Tensor index = torch::empty(
      {3, 5}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));

  for (int dim : {0, 1}) {
    for (int row = 0; row < 3; ++row) {
      for (int col = 0; col < 5; ++col) {
        index[row][col] = (row + col) % index.size(dim);
      }
    }
    torch::Tensor expected = torch::scatter(input, dim, index, src, "add");
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_input = CopyToDevice(input, device);
      torch::Tensor lazy_src = CopyToDevice(src, device);
      torch::Tensor lazy_index = CopyToDevice(index, device);
      torch::Tensor lazy_result =
          torch::scatter(lazy_input, dim, lazy_index, lazy_src, "add");
      AllClose(expected, lazy_result);
    });
  }

  ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
  ExpectCounterChanged("lazy::scatter_out", GetIgnoredCounters());
}
4574 
// torch::scatter_add on {3, 5} tensors along dim 0 and dim 1; the index
// tensor is rebuilt per dim with values below that dim's size.
TEST_F(LazyOpsTest, TestScatterAdd) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({3, 5}, float_opts);
  torch::Tensor src = torch::rand({3, 5}, float_opts);
  torch::Tensor index = torch::empty(
      {3, 5}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));

  for (int dim : {0, 1}) {
    for (int row = 0; row < 3; ++row) {
      for (int col = 0; col < 5; ++col) {
        index[row][col] = (row + col) % index.size(dim);
      }
    }
    torch::Tensor expected = torch::scatter_add(input, dim, index, src);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_input = CopyToDevice(input, device);
      torch::Tensor lazy_src = CopyToDevice(src, device);
      torch::Tensor lazy_index = CopyToDevice(index, device);
      torch::Tensor lazy_result =
          torch::scatter_add(lazy_input, dim, lazy_index, lazy_src);
      AllClose(expected, lazy_result);
    });
  }
}
4598 
// In-place scatter_add_ along dim 0 and dim 1. A fresh random input is drawn
// for every device so the in-place update starts from unmodified data; both
// the returned tensor and the mutated input are compared.
TEST_F(LazyOpsTest, TestScatterAddInPlace) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor src = torch::rand({4, 4}, float_opts);
  torch::Tensor index = torch::empty(
      {4, 4}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  for (int row = 0; row < 4; ++row) {
    for (int col = 0; col < 4; ++col) {
      index[row][col] = (row + col) % 4;
    }
  }

  for (int dim : {0, 1}) {
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor input = torch::rand({4, 4}, float_opts);
      // Copy before the reference tensor is modified in place.
      torch::Tensor lazy_input = CopyToDevice(input, device);
      torch::Tensor expected = input.scatter_add_(dim, index, src);
      torch::Tensor lazy_src = CopyToDevice(src, device);
      torch::Tensor lazy_index = CopyToDevice(index, device);
      torch::Tensor lazy_result =
          lazy_input.scatter_add_(dim, lazy_index, lazy_src);
      AllClose(expected, lazy_result);
      AllClose(input, lazy_input);
    });
  }
}
4623 
// torch::index_select on a {3, 4} tensor for several element types, both
// kInt and kLong index tensors, and for dims given as {0, 1} as well as their
// negative equivalents {-2, -1}. Results are compared with AllEqual.
TEST_F(LazyOpsTest, TestIndexSelect) {
  for (const torch::ScalarType dtype :
       {torch::kFloat,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const auto data_opts = torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor input;
    if (isFloatingType(dtype)) {
      input = torch::rand({3, 4}, data_opts);
    } else {
      input = torch::randint(100, {3, 4}, data_opts);
    }

    for (const torch::ScalarType index_dtype : {torch::kInt, torch::kLong}) {
      torch::Tensor index = torch::empty(
          {2}, torch::TensorOptions(index_dtype).device(DefaultDevice()));
      index[0] = 0;
      index[1] = 2;

      for (const int offset : {-2, 0}) {
        const int first_dim = 0 + offset;
        const int second_dim = 1 + offset;
        torch::Tensor expected0 = torch::index_select(input, first_dim, index);
        torch::Tensor expected1 =
            torch::index_select(input, second_dim, index);
        ForEachDevice([&](const torch::Device& device) {
          torch::Tensor lazy_input = CopyToDevice(input, device);
          torch::Tensor lazy_index = CopyToDevice(index, device);
          torch::Tensor lazy_result0 =
              torch::index_select(lazy_input, first_dim, lazy_index);
          torch::Tensor lazy_result1 =
              torch::index_select(lazy_input, second_dim, lazy_index);
          AllEqual(expected0, lazy_result0);
          AllEqual(expected1, lazy_result1);
        });
      }
    }
  }
}
4661 
// torch::index_select where the index is a scalar (rank-0) kLong tensor
// holding 2, applied along dim 0 and dim 1 for several element types.
TEST_F(LazyOpsTest, TestIndexSelectRank0) {
  for (const torch::ScalarType dtype :
       {torch::kFloat,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const auto data_opts = torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor input;
    if (isFloatingType(dtype)) {
      input = torch::rand({3, 4}, data_opts);
    } else {
      input = torch::randint(100, {3, 4}, data_opts);
    }
    torch::Tensor index = torch::scalar_tensor(
        2, torch::TensorOptions(torch::kLong).device(DefaultDevice()));

    torch::Tensor expected0 = torch::index_select(input, 0, index);
    torch::Tensor expected1 = torch::index_select(input, 1, index);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_input = CopyToDevice(input, device);
      torch::Tensor lazy_index = CopyToDevice(index, device);
      torch::Tensor lazy_result0 =
          torch::index_select(lazy_input, 0, lazy_index);
      torch::Tensor lazy_result1 =
          torch::index_select(lazy_input, 1, lazy_index);
      AllEqual(expected0, lazy_result0);
      AllEqual(expected1, lazy_result1);
    });
  }
}
4691 
// torch::inverse on a random 5x5 matrix, compared with explicit tolerances
// (rtol 1e-3, atol 1e-4). Skipped entirely when IsCuda() is true.
TEST_F(LazyOpsTest, TestInverse) {
  if (IsCuda()) {
    // TODO(whc) debug failure on cuda, lazy_b comes back transposed
    GTEST_SKIP();
  }
  torch::Tensor matrix = torch::randn(
      {5, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor expected = torch::inverse(matrix);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_matrix = CopyToDevice(matrix, device);
    torch::Tensor lazy_result = torch::inverse(lazy_matrix);
    AllClose(expected, lazy_result, /*rtol=*/1e-3, /*atol=*/1e-4);
  });
}
4706 
// torch::isnan on a 4-element float tensor whose third element is NaN;
// results are compared with AllEqual, followed by the fixture's counter
// checks for aten::.* and lazy::isnan.
TEST_F(LazyOpsTest, TestIsnan) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input =
      torch::tensor({1.0, 2.0, std::nan("1"), 4.0}, float_opts);
  torch::Tensor expected = torch::isnan(input);

  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor lazy_result = torch::isnan(lazy_input);
    AllEqual(expected, lazy_result);
  });

  ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
  ExpectCounterChanged("lazy::isnan", GetIgnoredCounters());
}
4720 
// Tensor::expand from {3, 4} to {2, 3, 4} (a new leading dimension).
TEST_F(LazyOpsTest, TestExpand) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({3, 4}, float_opts);
  torch::Tensor expected = input.expand({2, 3, 4}, /*implicit=*/false);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor lazy_result =
        lazy_input.expand({2, 3, 4}, /*implicit=*/false);
    AllClose(expected, lazy_result);
  });
}
4731 
// Tensor::expand from {3, 1} to {3, 4} (the trailing size-1 dimension grows).
TEST_F(LazyOpsTest, TestExpandBack) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({3, 1}, float_opts);
  torch::Tensor expected = input.expand({3, 4}, /*implicit=*/false);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor lazy_result = lazy_input.expand({3, 4}, /*implicit=*/false);
    AllClose(expected, lazy_result);
  });
}
4742 
// torch::native::expand_as: a {3, 4} tensor is expanded to the shape of a
// {2, 3, 4} tensor.
TEST_F(LazyOpsTest, TestExpandAs) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({3, 4}, float_opts);
  torch::Tensor target = torch::rand({2, 3, 4}, float_opts);
  torch::Tensor expected = torch::native::expand_as(input, target);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor lazy_target = CopyToDevice(target, device);
    torch::Tensor lazy_result =
        torch::native::expand_as(lazy_input, lazy_target);
    AllClose(expected, lazy_result);
  });
}
4756 
// torch::eye(5): the matrix created on DefaultDevice() is compared with the
// one created directly on each device from ForEachDevice.
TEST_F(LazyOpsTest, TestEye) {
  const int size = 5;
  ForEachDevice([&](const torch::Device& device) {
    const auto reference_opts =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    const auto device_opts = torch::TensorOptions(torch::kFloat).device(device);
    torch::Tensor expected = torch::eye(size, reference_opts);
    torch::Tensor lazy_result = torch::eye(size, device_opts);
    AllClose(expected, lazy_result);
  });
}
4767 
// torch::eye with more columns than rows (3 x 5).
TEST_F(LazyOpsTest, TestEyeWide) {
  const int rows = 3;
  const int columns = 5;
  ForEachDevice([&](const torch::Device& device) {
    const auto reference_opts =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    const auto device_opts = torch::TensorOptions(torch::kFloat).device(device);
    torch::Tensor expected = torch::eye(rows, columns, reference_opts);
    torch::Tensor lazy_result = torch::eye(rows, columns, device_opts);
    AllClose(expected, lazy_result);
  });
}
4781 
// torch::eye with more rows than columns (5 x 3).
TEST_F(LazyOpsTest, TestEyeNarrow) {
  const int rows = 5;
  const int columns = 3;
  ForEachDevice([&](const torch::Device& device) {
    const auto reference_opts =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    const auto device_opts = torch::TensorOptions(torch::kFloat).device(device);
    torch::Tensor expected = torch::eye(rows, columns, reference_opts);
    torch::Tensor lazy_result = torch::eye(rows, columns, device_opts);
    AllClose(expected, lazy_result);
  });
}
4795 
// torch::broadcast_tensors on a {2, 1, 1} and a {2, 1} tensor. The number of
// outputs must match, and each output is compared pairwise with AllClose.
TEST_F(LazyOpsTest, TestBroadcastTensors) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor first = torch::rand({2, 1, 1}, float_opts);
  torch::Tensor second = torch::rand({2, 1}, float_opts);
  std::vector<torch::Tensor> expected =
      torch::broadcast_tensors({first, second});

  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_first = CopyToDevice(first, device);
    torch::Tensor lazy_second = CopyToDevice(second, device);
    std::vector<torch::Tensor> lazy_results =
        torch::broadcast_tensors({lazy_first, lazy_second});
    ASSERT_EQ(expected.size(), lazy_results.size());
    for (size_t pos = 0; pos < expected.size(); ++pos) {
      AllClose(expected[pos], lazy_results[pos]);
    }
  });
}
4813 
// torch::index with a single kLong index tensor ({2, 4, 3}, values drawn from
// randint(-3, 3)) applied to a {4, 3, 5, 6, 7} tensor, for several element
// types.
TEST_F(LazyOpsTest, TestOneIndex) {
  for (const torch::ScalarType dtype :
       {torch::kFloat,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const auto data_opts = torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor params;
    if (isFloatingType(dtype)) {
      params = torch::rand({4, 3, 5, 6, 7}, data_opts);
    } else {
      params = torch::randint(100, {4, 3, 5, 6, 7}, data_opts);
    }
    const auto index_opts =
        torch::TensorOptions(torch::kLong).device(DefaultDevice());
    torch::Tensor indices = torch::randint(-3, 3, {2, 4, 3}, index_opts);

    torch::Tensor expected = torch::index(params, {indices});
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_params = CopyToDevice(params, device);
      torch::Tensor lazy_indices = CopyToDevice(indices, device);
      torch::Tensor lazy_result = torch::index(lazy_params, {lazy_indices});
      AllEqual(expected, lazy_result);
    });
  }
}
4844 
// Same setup as a single-tensor torch::index, except that the index tensor is
// not copied to the device under test: `indices.cpu()` is passed alongside
// the device-resident params.
TEST_F(LazyOpsTest, TestOneIndexTransfer) {
  for (const torch::ScalarType dtype :
       {torch::kFloat,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const auto data_opts = torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor params;
    if (isFloatingType(dtype)) {
      params = torch::rand({4, 3, 5, 6, 7}, data_opts);
    } else {
      params = torch::randint(100, {4, 3, 5, 6, 7}, data_opts);
    }
    const auto index_opts =
        torch::TensorOptions(torch::kLong).device(DefaultDevice());
    torch::Tensor indices = torch::randint(-3, 3, {2, 4, 3}, index_opts);

    torch::Tensor expected = torch::index(params, {indices});
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_params = CopyToDevice(params, device);
      torch::Tensor lazy_result = torch::index(lazy_params, {indices.cpu()});
      AllEqual(expected, lazy_result);
    });
  }
}
4874 
// torch::nonzero on a {4, 2} zero tensor with three entries set. Counters are
// reset after every device.
TEST_F(LazyOpsTest, TestNonzero) {
  torch::Tensor input = torch::zeros(
      {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  input[0][1] = 1.0;
  input[1][0] = 2.0;
  input[3][1] = 3.0;
  torch::Tensor expected = torch::nonzero(input);

  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor lazy_result = torch::nonzero(lazy_input);
    AllClose(expected, lazy_result);

    // The aten:: counter check only applies when the "nonzero" experiment is
    // enabled.
    const bool nonzero_enabled = DebugUtil::ExperimentEnabled("nonzero");
    if (nonzero_enabled) {
      ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
    }
    ResetCounters();
  });
}
4894 
// torch::masked_select on a {3, 5} float tensor with a {5} bool mask.
// Counters are reset after every device.
TEST_F(LazyOpsTest, TestMaskedSelect) {
  torch::Tensor input = torch::rand(
      {3, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor mask = torch::randint(
      0, 2, {5}, torch::TensorOptions(torch::kBool).device(DefaultDevice()));
  torch::Tensor expected = torch::masked_select(input, mask);

  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor lazy_mask = CopyToDevice(mask, device);
    torch::Tensor lazy_result = torch::masked_select(lazy_input, lazy_mask);
    AllClose(expected, lazy_result);

    // The aten:: counter check only applies when the "masked_select"
    // experiment is enabled.
    const bool masked_select_enabled =
        DebugUtil::ExperimentEnabled("masked_select");
    if (masked_select_enabled) {
      ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
    }
    ResetCounters();
  });
}
4915 
// torch::masked_scatter on a {3, 5} float tensor with a {3, 5} bool mask and
// a 15-element source. Counters are reset after every device.
TEST_F(LazyOpsTest, TestMaskedScatter) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({3, 5}, float_opts);
  torch::Tensor mask = torch::randint(
      0, 2, {3, 5}, torch::TensorOptions(torch::kBool).device(DefaultDevice()));
  torch::Tensor source = torch::rand({15}, float_opts);
  torch::Tensor expected = torch::masked_scatter(input, mask, source);

  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor lazy_mask = CopyToDevice(mask, device);
    torch::Tensor lazy_source = CopyToDevice(source, device);
    torch::Tensor lazy_result =
        torch::masked_scatter(lazy_input, lazy_mask, lazy_source);
    AllClose(expected, lazy_result);

    // The aten:: counter check only applies when the "masked_scatter"
    // experiment is enabled.
    const bool masked_scatter_enabled =
        DebugUtil::ExperimentEnabled("masked_scatter");
    if (masked_scatter_enabled) {
      ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
    }
    ResetCounters();
  });
}
4939 
// torch::index with three index slots where the FIRST slot is a
// default-constructed (undefined) tensor and the other two are {2, 4, 3}
// kLong tensors.
TEST_F(LazyOpsTest, TestMultiIndexHeadNull) {
  for (const torch::ScalarType dtype :
       {torch::kFloat,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const auto data_opts = torch::TensorOptions(dtype).device(DefaultDevice());
    const auto index_opts =
        torch::TensorOptions(torch::kLong).device(DefaultDevice());
    torch::Tensor params;
    if (isFloatingType(dtype)) {
      params = torch::rand({4, 3, 5, 6, 7}, data_opts);
    } else {
      params = torch::randint(100, {4, 3, 5, 6, 7}, data_opts);
    }
    torch::Tensor no_index;  // intentionally left undefined
    torch::Tensor index_a = torch::randint(-3, 3, {2, 4, 3}, index_opts);
    torch::Tensor index_b = torch::randint(-3, 3, {2, 4, 3}, index_opts);

    torch::Tensor expected =
        torch::index(params, {no_index, index_a, index_b});
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_params = CopyToDevice(params, device);
      torch::Tensor lazy_index_a = CopyToDevice(index_a, device);
      torch::Tensor lazy_index_b = CopyToDevice(index_b, device);
      torch::Tensor lazy_result =
          torch::index(lazy_params, {no_index, lazy_index_a, lazy_index_b});
      AllEqual(expected, lazy_result);
    });
  }
}
4979 
// torch::index with three index slots where the MIDDLE slot is a
// default-constructed (undefined) tensor and the outer two are {2, 4, 3}
// kLong tensors.
TEST_F(LazyOpsTest, TestMultiIndexMiddleNull) {
  for (const torch::ScalarType dtype :
       {torch::kFloat,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const auto data_opts = torch::TensorOptions(dtype).device(DefaultDevice());
    const auto index_opts =
        torch::TensorOptions(torch::kLong).device(DefaultDevice());
    torch::Tensor params;
    if (isFloatingType(dtype)) {
      params = torch::rand({4, 3, 5, 6, 7}, data_opts);
    } else {
      params = torch::randint(100, {4, 3, 5, 6, 7}, data_opts);
    }
    torch::Tensor index_a = torch::randint(-3, 3, {2, 4, 3}, index_opts);
    torch::Tensor no_index;  // intentionally left undefined
    torch::Tensor index_b = torch::randint(-3, 3, {2, 4, 3}, index_opts);

    torch::Tensor expected =
        torch::index(params, {index_a, no_index, index_b});
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_params = CopyToDevice(params, device);
      torch::Tensor lazy_index_a = CopyToDevice(index_a, device);
      torch::Tensor lazy_index_b = CopyToDevice(index_b, device);
      torch::Tensor lazy_result =
          torch::index(lazy_params, {lazy_index_a, no_index, lazy_index_b});
      AllEqual(expected, lazy_result);
    });
  }
}
5019 
// torch::index with three index slots where the LAST slot is a
// default-constructed (undefined) tensor and the first two are {2, 4, 3}
// kLong tensors.
TEST_F(LazyOpsTest, TestMultiIndexTailNull) {
  for (const torch::ScalarType dtype :
       {torch::kFloat,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const auto data_opts = torch::TensorOptions(dtype).device(DefaultDevice());
    const auto index_opts =
        torch::TensorOptions(torch::kLong).device(DefaultDevice());
    torch::Tensor params;
    if (isFloatingType(dtype)) {
      params = torch::rand({4, 3, 5, 6, 7}, data_opts);
    } else {
      params = torch::randint(100, {4, 3, 5, 6, 7}, data_opts);
    }
    torch::Tensor index_a = torch::randint(-3, 3, {2, 4, 3}, index_opts);
    torch::Tensor no_index;  // intentionally left undefined
    torch::Tensor index_b = torch::randint(-3, 3, {2, 4, 3}, index_opts);

    torch::Tensor expected =
        torch::index(params, {index_a, index_b, no_index});
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_params = CopyToDevice(params, device);
      torch::Tensor lazy_index_a = CopyToDevice(index_a, device);
      torch::Tensor lazy_index_b = CopyToDevice(index_b, device);
      torch::Tensor lazy_result =
          torch::index(lazy_params, {lazy_index_a, lazy_index_b, no_index});
      AllEqual(expected, lazy_result);
    });
  }
}
5059 
// torch::index with two kLong index tensors of shapes {2, 4, 3} and
// {2, 1, 3}; the second has size 1 in its middle dimension.
TEST_F(LazyOpsTest, TestMultiIndexMiddleBroadcast) {
  for (const torch::ScalarType dtype :
       {torch::kFloat,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const auto data_opts = torch::TensorOptions(dtype).device(DefaultDevice());
    const auto index_opts =
        torch::TensorOptions(torch::kLong).device(DefaultDevice());
    torch::Tensor params;
    if (isFloatingType(dtype)) {
      params = torch::rand({4, 3, 5, 6, 7}, data_opts);
    } else {
      params = torch::randint(100, {4, 3, 5, 6, 7}, data_opts);
    }
    torch::Tensor index_a = torch::randint(-3, 3, {2, 4, 3}, index_opts);
    torch::Tensor index_b = torch::randint(-3, 3, {2, 1, 3}, index_opts);

    torch::Tensor expected = torch::index(params, {index_a, index_b});
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_params = CopyToDevice(params, device);
      torch::Tensor lazy_index_a = CopyToDevice(index_a, device);
      torch::Tensor lazy_index_b = CopyToDevice(index_b, device);
      torch::Tensor lazy_result =
          torch::index(lazy_params, {lazy_index_a, lazy_index_b});
      AllEqual(expected, lazy_result);
    });
  }
}
5097 
// torch::index with two kLong index tensors of shapes {2, 1, 3} and {2, 1};
// the second has fewer dimensions than the first.
TEST_F(LazyOpsTest, TestMultiIndexTailBroadcast) {
  for (const torch::ScalarType dtype :
       {torch::kFloat,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const auto data_opts = torch::TensorOptions(dtype).device(DefaultDevice());
    const auto index_opts =
        torch::TensorOptions(torch::kLong).device(DefaultDevice());
    torch::Tensor params;
    if (isFloatingType(dtype)) {
      params = torch::rand({4, 3, 5, 6, 7}, data_opts);
    } else {
      params = torch::randint(100, {4, 3, 5, 6, 7}, data_opts);
    }
    torch::Tensor index_a = torch::randint(-3, 3, {2, 1, 3}, index_opts);
    torch::Tensor index_b = torch::randint(-3, 3, {2, 1}, index_opts);

    torch::Tensor expected = torch::index(params, {index_a, index_b});
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_params = CopyToDevice(params, device);
      torch::Tensor lazy_index_a = CopyToDevice(index_a, device);
      torch::Tensor lazy_index_b = CopyToDevice(index_b, device);
      torch::Tensor lazy_result =
          torch::index(lazy_params, {lazy_index_a, lazy_index_b});
      AllEqual(expected, lazy_result);
    });
  }
}
5135 
// torch::index with a single {2, 2} bool tensor used as the index into a
// {2, 2} tensor, for several element types.
TEST_F(LazyOpsTest, TestMaskIndex) {
  for (const torch::ScalarType dtype :
       {torch::kFloat,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const auto data_opts = torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor params;
    if (isFloatingType(dtype)) {
      params = torch::rand({2, 2}, data_opts);
    } else {
      params = torch::randint(100, {2, 2}, data_opts);
    }
    const auto mask_opts =
        torch::TensorOptions(torch::kBool).device(DefaultDevice());
    torch::Tensor mask = torch::randint(0, 2, {2, 2}, mask_opts);

    torch::Tensor expected = torch::index(params, {mask});
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_params = CopyToDevice(params, device);
      torch::Tensor lazy_mask = CopyToDevice(mask, device);
      torch::Tensor lazy_result = torch::index(lazy_params, {lazy_mask});
      AllEqual(expected, lazy_result);
    });
  }
}
5165 
// torch::index_put (out-of-place) with a single {2, 4, 3} kLong index tensor
// and {3, 5, 6, 7} values, for several element types and for both
// accumulate = false and accumulate = true.
TEST_F(LazyOpsTest, TestOneIndexPut) {
  for (const torch::ScalarType dtype :
       {torch::kFloat,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const bool is_float = isFloatingType(dtype);
    const auto data_opts = torch::TensorOptions(dtype).device(DefaultDevice());
    const auto index_opts =
        torch::TensorOptions(torch::kLong).device(DefaultDevice());

    torch::Tensor params;
    if (is_float) {
      params = torch::rand({4, 3, 5, 6, 7}, data_opts);
    } else {
      params = torch::randint(100, {4, 3, 5, 6, 7}, data_opts);
    }
    torch::Tensor indices = torch::randint(-3, 3, {2, 4, 3}, index_opts);
    torch::Tensor values;
    if (is_float) {
      values = torch::rand({3, 5, 6, 7}, data_opts);
    } else {
      values = torch::randint(100, {3, 5, 6, 7}, data_opts);
    }

    for (const bool accumulate : {false, true}) {
      // NOTE(review): GTEST_SKIP() leaves the test body, so when IsCuda() is
      // true the remaining element types are not exercised either - confirm
      // that this is intended.
      if (accumulate && IsCuda()) {
        GTEST_SKIP();
      }
      torch::Tensor expected =
          torch::index_put(params, {indices}, values, accumulate);
      ForEachDevice([&](const torch::Device& device) {
        torch::Tensor lazy_params = CopyToDevice(params, device);
        torch::Tensor lazy_indices = CopyToDevice(indices, device);
        torch::Tensor lazy_values = CopyToDevice(values, device);
        torch::Tensor lazy_result = torch::index_put(
            lazy_params, {lazy_indices}, lazy_values, accumulate);
        AllEqual(expected, lazy_result);
      });
    }
  }
}
5212 
// torch::index_put_ (in-place) with a single {2, 4, 3} kLong index tensor and
// all-ones {3, 5, 6, 7} values. Fresh params are drawn for every device; both
// the returned tensor and the mutated params are compared with AllEqual.
TEST_F(LazyOpsTest, TestOneIndexPutInPlace) {
  const auto index_opts =
      torch::TensorOptions(torch::kLong).device(DefaultDevice());
  torch::Tensor indices = torch::randint(-3, 3, {2, 4, 3}, index_opts);

  for (const torch::ScalarType dtype :
       {torch::kFloat,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const auto data_opts = torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor values = torch::ones({3, 5, 6, 7}, data_opts);

    for (const bool accumulate : {false, true}) {
      // NOTE(review): GTEST_SKIP() leaves the test body, so when IsCuda() is
      // true the remaining element types are not exercised either - confirm
      // that this is intended.
      if (accumulate && IsCuda()) {
        GTEST_SKIP();
      }
      ForEachDevice([&](const torch::Device& device) {
        torch::Tensor params;
        if (isFloatingType(dtype)) {
          params = torch::rand({4, 3, 5, 6, 7}, data_opts);
        } else {
          params = torch::randint(100, {4, 3, 5, 6, 7}, data_opts);
        }
        // Take the device copy before the reference params are modified.
        torch::Tensor lazy_params = CopyToDevice(params.clone(), device);
        torch::Tensor expected =
            torch::index_put_(params, {indices}, values, accumulate);
        torch::Tensor lazy_indices = CopyToDevice(indices, device);
        torch::Tensor lazy_values = CopyToDevice(values, device);
        torch::Tensor lazy_result = torch::index_put_(
            lazy_params, {lazy_indices}, lazy_values, accumulate);
        AllEqual(expected, lazy_result);
        AllEqual(params, lazy_params);
      });
    }
  }
}
5255 
// Out-of-place torch::index_put with a single index tensor where only the
// params and values are copied with CopyToDevice; the index tensor is passed
// to the device-side call exactly as it was created on DefaultDevice().
TEST_F(LazyOpsTest, TestOneIndexPutTransfer) {
  const torch::Tensor index_tensor = torch::randint(
      -3,
      3,
      {2, 4, 3},
      torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  for (const torch::ScalarType dtype : kDtypes) {
    const torch::TensorOptions target_opts =
        torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor target;
    if (isFloatingType(dtype)) {
      target = torch::rand({4, 3, 5, 6, 7}, target_opts);
    } else {
      target = torch::randint(100, {4, 3, 5, 6, 7}, target_opts);
    }
    const torch::Tensor updates = torch::ones(
        {3, 5, 6, 7}, torch::TensorOptions(dtype).device(DefaultDevice()));
    for (const bool accumulate : {false, true}) {
      // NOTE(review): GTEST_SKIP() returns from the test body, so once this
      // fires no further dtype/accumulate combinations run -- confirm that
      // is intended.
      if (accumulate && IsCuda()) {
        GTEST_SKIP();
      }
      torch::Tensor expected =
          torch::index_put(target, {index_tensor}, updates, accumulate);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_target = CopyToDevice(target, device);
        const torch::Tensor lazy_updates = CopyToDevice(updates, device);
        // `index_tensor` is used here without a CopyToDevice call.
        torch::Tensor lazy_out = torch::index_put(
            lazy_target, {index_tensor}, lazy_updates, accumulate);
        AllEqual(expected, lazy_out);
      });
    }
  }
}
5296 
// Out-of-place torch::index_put with two index tensors of identical shape
// {2, 4, 3}. The default-device result is compared with AllEqual against the
// result computed from CopyToDevice copies of every input.
TEST_F(LazyOpsTest, TestMultiIndexPut) {
  const torch::Tensor first_index = torch::randint(
      -3,
      3,
      {2, 4, 3},
      torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  const torch::Tensor second_index = torch::randint(
      -3,
      3,
      {2, 4, 3},
      torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  for (const torch::ScalarType dtype : kDtypes) {
    const torch::TensorOptions target_opts =
        torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor target;
    if (isFloatingType(dtype)) {
      target = torch::rand({4, 3, 5, 6, 7}, target_opts);
    } else {
      target = torch::randint(100, {4, 3, 5, 6, 7}, target_opts);
    }
    const torch::Tensor updates = torch::ones(
        {5, 6, 7}, torch::TensorOptions(dtype).device(DefaultDevice()));
    for (const bool accumulate : {false, true}) {
      // NOTE(review): GTEST_SKIP() returns from the test body, so once this
      // fires no further dtype/accumulate combinations run -- confirm that
      // is intended.
      if (accumulate && IsCuda()) {
        GTEST_SKIP();
      }
      torch::Tensor expected = torch::index_put(
          target, {first_index, second_index}, updates, accumulate);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_target = CopyToDevice(target, device);
        const torch::Tensor lazy_first = CopyToDevice(first_index, device);
        const torch::Tensor lazy_second = CopyToDevice(second_index, device);
        const torch::Tensor lazy_updates = CopyToDevice(updates, device);
        torch::Tensor lazy_out = torch::index_put(
            lazy_target, {lazy_first, lazy_second}, lazy_updates, accumulate);
        AllEqual(expected, lazy_out);
      });
    }
  }
}
5346 
// Out-of-place torch::index_put where the index list is
// {undefined, index, index}: the first entry is a default-constructed tensor
// and the other two are random index tensors of shape {2, 4, 3}.
TEST_F(LazyOpsTest, TestMultiIndexPutHeadNull) {
  const torch::Tensor first_index = torch::randint(
      -3,
      3,
      {2, 4, 3},
      torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  // Default-constructed tensor used as the leading entry of the index list.
  torch::Tensor no_index;
  const torch::Tensor second_index = torch::randint(
      -3,
      3,
      {2, 4, 3},
      torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  for (const torch::ScalarType dtype : kDtypes) {
    const torch::TensorOptions target_opts =
        torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor target;
    if (isFloatingType(dtype)) {
      target = torch::rand({4, 3, 3, 6, 7}, target_opts);
    } else {
      target = torch::randint(100, {4, 3, 3, 6, 7}, target_opts);
    }
    const torch::Tensor updates = torch::ones(
        {3, 6, 7}, torch::TensorOptions(dtype).device(DefaultDevice()));
    for (const bool accumulate : {false, true}) {
      // NOTE(review): GTEST_SKIP() returns from the test body, so once this
      // fires no further dtype/accumulate combinations run -- confirm that
      // is intended.
      if (accumulate && IsCuda()) {
        GTEST_SKIP();
      }
      torch::Tensor expected = torch::index_put(
          target, {no_index, first_index, second_index}, updates, accumulate);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_target = CopyToDevice(target, device);
        const torch::Tensor lazy_first = CopyToDevice(first_index, device);
        const torch::Tensor lazy_second = CopyToDevice(second_index, device);
        const torch::Tensor lazy_updates = CopyToDevice(updates, device);
        // The undefined tensor is passed through as-is; only the real index
        // tensors are copied to the device.
        torch::Tensor lazy_out = torch::index_put(
            lazy_target,
            {no_index, lazy_first, lazy_second},
            lazy_updates,
            accumulate);
        AllEqual(expected, lazy_out);
      });
    }
  }
}
5397 
// Out-of-place torch::index_put where the index list is
// {index, undefined, index}: the middle entry is a default-constructed tensor
// and the outer two are random index tensors of shape {2, 4, 3}.
TEST_F(LazyOpsTest, TestMultiIndexPutMiddleNull) {
  const torch::Tensor first_index = torch::randint(
      -3,
      3,
      {2, 4, 3},
      torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  // Default-constructed tensor used as the middle entry of the index list.
  torch::Tensor no_index;
  const torch::Tensor second_index = torch::randint(
      -3,
      3,
      {2, 4, 3},
      torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  for (const torch::ScalarType dtype : kDtypes) {
    const torch::TensorOptions target_opts =
        torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor target;
    if (isFloatingType(dtype)) {
      target = torch::rand({4, 3, 3, 6, 7}, target_opts);
    } else {
      target = torch::randint(100, {4, 3, 3, 6, 7}, target_opts);
    }
    const torch::Tensor updates = torch::ones(
        {3, 6, 7}, torch::TensorOptions(dtype).device(DefaultDevice()));
    for (const bool accumulate : {false, true}) {
      // NOTE(review): GTEST_SKIP() returns from the test body, so once this
      // fires no further dtype/accumulate combinations run -- confirm that
      // is intended.
      if (accumulate && IsCuda()) {
        GTEST_SKIP();
      }
      torch::Tensor expected = torch::index_put(
          target, {first_index, no_index, second_index}, updates, accumulate);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_target = CopyToDevice(target, device);
        const torch::Tensor lazy_first = CopyToDevice(first_index, device);
        const torch::Tensor lazy_second = CopyToDevice(second_index, device);
        const torch::Tensor lazy_updates = CopyToDevice(updates, device);
        // The undefined tensor is passed through as-is; only the real index
        // tensors are copied to the device.
        torch::Tensor lazy_out = torch::index_put(
            lazy_target,
            {lazy_first, no_index, lazy_second},
            lazy_updates,
            accumulate);
        AllEqual(expected, lazy_out);
      });
    }
  }
}
5448 
// Out-of-place torch::index_put where the index list is
// {index, index, undefined}: the last entry is a default-constructed tensor
// and the first two are random index tensors of shape {2, 4, 3}.
TEST_F(LazyOpsTest, TestMultiIndexPutTailNull) {
  const torch::Tensor first_index = torch::randint(
      -3,
      3,
      {2, 4, 3},
      torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  const torch::Tensor second_index = torch::randint(
      -3,
      3,
      {2, 4, 3},
      torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  // Default-constructed tensor used as the trailing entry of the index list.
  torch::Tensor no_index;
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  for (const torch::ScalarType dtype : kDtypes) {
    const torch::TensorOptions target_opts =
        torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor target;
    if (isFloatingType(dtype)) {
      target = torch::rand({4, 3, 3, 6, 7}, target_opts);
    } else {
      target = torch::randint(100, {4, 3, 3, 6, 7}, target_opts);
    }
    const torch::Tensor updates = torch::ones(
        {3, 6, 7}, torch::TensorOptions(dtype).device(DefaultDevice()));
    for (const bool accumulate : {false, true}) {
      // NOTE(review): GTEST_SKIP() returns from the test body, so once this
      // fires no further dtype/accumulate combinations run -- confirm that
      // is intended.
      if (accumulate && IsCuda()) {
        GTEST_SKIP();
      }
      torch::Tensor expected = torch::index_put(
          target, {first_index, second_index, no_index}, updates, accumulate);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_target = CopyToDevice(target, device);
        const torch::Tensor lazy_first = CopyToDevice(first_index, device);
        const torch::Tensor lazy_second = CopyToDevice(second_index, device);
        const torch::Tensor lazy_updates = CopyToDevice(updates, device);
        // The undefined tensor is passed through as-is; only the real index
        // tensors are copied to the device.
        torch::Tensor lazy_out = torch::index_put(
            lazy_target,
            {lazy_first, lazy_second, no_index},
            lazy_updates,
            accumulate);
        AllEqual(expected, lazy_out);
      });
    }
  }
}
5499 
// Out-of-place torch::index_put with two index tensors whose shapes differ in
// the middle dimension ({2, 4, 3} versus {2, 1, 3}). The default-device
// result is compared with AllEqual against the device-side result.
TEST_F(LazyOpsTest, TestMultiIndexPutMiddleBroadcast) {
  const torch::Tensor first_index = torch::randint(
      -3,
      3,
      {2, 4, 3},
      torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  // Size 1 in the middle dimension, unlike `first_index`.
  const torch::Tensor second_index = torch::randint(
      -3,
      3,
      {2, 1, 3},
      torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  for (const torch::ScalarType dtype : kDtypes) {
    const torch::TensorOptions target_opts =
        torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor target;
    if (isFloatingType(dtype)) {
      target = torch::rand({4, 3, 5, 6, 7}, target_opts);
    } else {
      target = torch::randint(100, {4, 3, 5, 6, 7}, target_opts);
    }
    const torch::Tensor updates = torch::ones(
        {5, 6, 7}, torch::TensorOptions(dtype).device(DefaultDevice()));
    for (const bool accumulate : {false, true}) {
      // NOTE(review): GTEST_SKIP() returns from the test body, so once this
      // fires no further dtype/accumulate combinations run -- confirm that
      // is intended.
      if (accumulate && IsCuda()) {
        GTEST_SKIP();
      }
      torch::Tensor expected = torch::index_put(
          target, {first_index, second_index}, updates, accumulate);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_target = CopyToDevice(target, device);
        const torch::Tensor lazy_first = CopyToDevice(first_index, device);
        const torch::Tensor lazy_second = CopyToDevice(second_index, device);
        const torch::Tensor lazy_updates = CopyToDevice(updates, device);
        torch::Tensor lazy_out = torch::index_put(
            lazy_target, {lazy_first, lazy_second}, lazy_updates, accumulate);
        AllEqual(expected, lazy_out);
      });
    }
  }
}
5549 
// Out-of-place torch::index_put with two index tensors of different rank
// ({2, 1, 3} and {2, 1}). The default-device result is compared with AllEqual
// against the device-side result.
TEST_F(LazyOpsTest, TestMultiIndexPutTailBroadcast) {
  const torch::Tensor first_index = torch::randint(
      -3,
      3,
      {2, 1, 3},
      torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  // Two-dimensional, unlike the three-dimensional `first_index`.
  const torch::Tensor second_index = torch::randint(
      -3,
      3,
      {2, 1},
      torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  for (const torch::ScalarType dtype : kDtypes) {
    const torch::TensorOptions target_opts =
        torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor target;
    if (isFloatingType(dtype)) {
      target = torch::rand({4, 3, 5, 6, 7}, target_opts);
    } else {
      target = torch::randint(100, {4, 3, 5, 6, 7}, target_opts);
    }
    const torch::Tensor updates = torch::ones(
        {5, 6, 7}, torch::TensorOptions(dtype).device(DefaultDevice()));
    for (const bool accumulate : {false, true}) {
      // NOTE(review): GTEST_SKIP() returns from the test body, so once this
      // fires no further dtype/accumulate combinations run -- confirm that
      // is intended.
      if (accumulate && IsCuda()) {
        GTEST_SKIP();
      }
      torch::Tensor expected = torch::index_put(
          target, {first_index, second_index}, updates, accumulate);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_target = CopyToDevice(target, device);
        const torch::Tensor lazy_first = CopyToDevice(first_index, device);
        const torch::Tensor lazy_second = CopyToDevice(second_index, device);
        const torch::Tensor lazy_updates = CopyToDevice(updates, device);
        torch::Tensor lazy_out = torch::index_put(
            lazy_target, {lazy_first, lazy_second}, lazy_updates, accumulate);
        AllEqual(expected, lazy_out);
      });
    }
  }
}
5599 
// Out-of-place torch::index_put driven by a boolean mask instead of integer
// indices. The mask is built as a kByte tensor {0, 1} and converted to kBool.
// Unlike the neighbouring index_put tests, there is no IsCuda()-based skip
// for the accumulate case here.
TEST_F(LazyOpsTest, TestMaskIndexPut) {
  const torch::Tensor byte_mask = torch::tensor(
      {0, 1}, torch::TensorOptions(torch::kByte).device(DefaultDevice()));
  const torch::Tensor mask = byte_mask.to(torch::kBool);
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  for (const torch::ScalarType dtype : kDtypes) {
    const torch::TensorOptions target_opts =
        torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor target;
    if (isFloatingType(dtype)) {
      target = torch::rand({2, 2}, target_opts);
    } else {
      target = torch::randint(100, {2, 2}, target_opts);
    }
    const torch::Tensor updates =
        torch::ones({2}, torch::TensorOptions(dtype).device(DefaultDevice()));
    for (const bool accumulate : {false, true}) {
      torch::Tensor expected =
          torch::index_put(target, {mask}, updates, accumulate);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_target = CopyToDevice(target, device);
        const torch::Tensor lazy_mask = CopyToDevice(mask, device);
        const torch::Tensor lazy_updates = CopyToDevice(updates, device);
        torch::Tensor lazy_out =
            torch::index_put(lazy_target, {lazy_mask}, lazy_updates, accumulate);
        AllEqual(expected, lazy_out);
      });
    }
  }
}
5635 
// Exercises torch::_index_put_impl_ (called with unsafe=true) with a single
// index tensor. The op mutates its first argument, so both the returned
// tensors and the mutated inputs are compared with AllEqual.
TEST_F(LazyOpsTest, TestIndexPutImpl) {
  const torch::Tensor index_tensor = torch::randint(
      -3,
      3,
      {2, 4, 3},
      torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  for (const torch::ScalarType dtype : kDtypes) {
    const torch::Tensor updates = torch::ones(
        {3, 5, 6, 7}, torch::TensorOptions(dtype).device(DefaultDevice()));
    for (const bool accumulate : {false, true}) {
      // NOTE(review): GTEST_SKIP() returns from the test body, so once this
      // fires no further dtype/accumulate combinations run -- confirm that
      // is intended.
      if (accumulate && IsCuda()) {
        GTEST_SKIP();
      }
      ForEachDevice([&](const torch::Device& device) {
        // A fresh random target is created on every callback invocation; the
        // eager op below mutates it in place.
        const torch::TensorOptions target_opts =
            torch::TensorOptions(dtype).device(DefaultDevice());
        torch::Tensor target;
        if (isFloatingType(dtype)) {
          target = torch::rand({4, 3, 5, 6, 7}, target_opts);
        } else {
          target = torch::randint(100, {4, 3, 5, 6, 7}, target_opts);
        }
        // Clone before the eager update so the device copy starts from the
        // unmodified values.
        torch::Tensor lazy_target = CopyToDevice(target.clone(), device);
        torch::Tensor eager_out = torch::_index_put_impl_(
            target, {index_tensor}, updates, accumulate, /*unsafe=*/true);
        const torch::Tensor lazy_index = CopyToDevice(index_tensor, device);
        const torch::Tensor lazy_updates = CopyToDevice(updates, device);
        torch::Tensor lazy_out = torch::_index_put_impl_(
            lazy_target,
            {lazy_index},
            lazy_updates,
            accumulate,
            /*unsafe=*/true);
        AllEqual(eager_out, lazy_out);
        AllEqual(target, lazy_target);
      });
    }
  }
}
5682 
// Out-of-place torch::index_fill with a Scalar fill value (42) and the index
// tensor {0, 2}. Every dim in [-rank, rank) is tried for each dtype, and the
// default-device result is compared with AllEqual against the device result.
TEST_F(LazyOpsTest, TestIndexFillWithScalar) {
  const torch::Tensor positions = torch::tensor(
      {0, 2}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  const torch::Scalar fill_value = 42;
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  for (const torch::ScalarType dtype : kDtypes) {
    const torch::TensorOptions source_opts =
        torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor source;
    if (isFloatingType(dtype)) {
      source = torch::rand({3, 4, 5}, source_opts);
    } else {
      source = torch::randint(100, {3, 4, 5}, source_opts);
    }
    const int rank = source.dim();
    // Covers negative as well as non-negative dim values.
    for (int dim = -rank; dim < rank; ++dim) {
      torch::Tensor expected =
          torch::index_fill(source, dim, positions, fill_value);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_source = CopyToDevice(source, device);
        const torch::Tensor lazy_positions = CopyToDevice(positions, device);
        torch::Tensor lazy_out =
            torch::index_fill(lazy_source, dim, lazy_positions, fill_value);
        AllEqual(expected, lazy_out);
      });
    }
  }
}
5715 
// In-place Tensor::index_fill_ with a Scalar fill value (42) and the index
// tensor {0, 2}. Both the returned tensors and the mutated inputs are
// compared with AllEqual.
TEST_F(LazyOpsTest, TestIndexFillWithScalarInPlace) {
  const torch::Tensor positions = torch::tensor(
      {0, 2}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  const torch::Scalar fill_value = 42;
  // Hard-coded to match the three-dimensional shape {3, 4, 5} created inside
  // the callback below.
  const int rank = 3;
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  for (const torch::ScalarType dtype : kDtypes) {
    for (int dim = -rank; dim < rank; ++dim) {
      ForEachDevice([&](const torch::Device& device) {
        const torch::TensorOptions source_opts =
            torch::TensorOptions(dtype).device(DefaultDevice());
        torch::Tensor source;
        if (isFloatingType(dtype)) {
          source = torch::rand({3, 4, 5}, source_opts);
        } else {
          source = torch::randint(100, {3, 4, 5}, source_opts);
        }
        // Clone before the eager update so the device copy starts from the
        // unmodified values.
        torch::Tensor lazy_source = CopyToDevice(source.clone(), device);
        torch::Tensor eager_out =
            source.index_fill_(dim, positions, fill_value);
        const torch::Tensor lazy_positions = CopyToDevice(positions, device);
        torch::Tensor lazy_out =
            lazy_source.index_fill_(dim, lazy_positions, fill_value);
        AllEqual(eager_out, lazy_out);
        AllEqual(source, lazy_source);
      });
    }
  }
}
5749 
// Out-of-place torch::index_fill where the fill value is a tensor created by
// torch::scalar_tensor(42) in the dtype under test. Every dim in
// [-rank, rank) is tried with the index tensor {0, 2}.
TEST_F(LazyOpsTest, TestIndexFillWithTensor) {
  const torch::Tensor positions = torch::tensor(
      {0, 2}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  for (const torch::ScalarType dtype : kDtypes) {
    const torch::TensorOptions source_opts =
        torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor source;
    if (isFloatingType(dtype)) {
      source = torch::rand({3, 4, 5}, source_opts);
    } else {
      source = torch::randint(100, {3, 4, 5}, source_opts);
    }
    const torch::Tensor fill_tensor = torch::scalar_tensor(
        42, torch::TensorOptions(dtype).device(DefaultDevice()));
    const int rank = source.dim();
    // Covers negative as well as non-negative dim values.
    for (int dim = -rank; dim < rank; ++dim) {
      torch::Tensor expected =
          torch::index_fill(source, dim, positions, fill_tensor);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_source = CopyToDevice(source, device);
        const torch::Tensor lazy_positions = CopyToDevice(positions, device);
        const torch::Tensor lazy_fill = CopyToDevice(fill_tensor, device);
        torch::Tensor lazy_out =
            torch::index_fill(lazy_source, dim, lazy_positions, lazy_fill);
        AllEqual(expected, lazy_out);
      });
    }
  }
}
5784 
// In-place Tensor::index_fill_ where the fill value is a tensor created by
// torch::scalar_tensor(42) in the dtype under test. Both the returned
// tensors and the mutated inputs are compared with AllEqual.
TEST_F(LazyOpsTest, TestIndexFillWithTensorInPlace) {
  const torch::Tensor positions = torch::tensor(
      {0, 2}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  for (const torch::ScalarType dtype : kDtypes) {
    const torch::Tensor fill_tensor = torch::scalar_tensor(
        42, torch::TensorOptions(dtype).device(DefaultDevice()));
    // Hard-coded to match the three-dimensional shape {3, 4, 5} created
    // inside the callback below.
    const int rank = 3;
    for (int dim = -rank; dim < rank; ++dim) {
      ForEachDevice([&](const torch::Device& device) {
        const torch::TensorOptions source_opts =
            torch::TensorOptions(dtype).device(DefaultDevice());
        torch::Tensor source;
        if (isFloatingType(dtype)) {
          source = torch::rand({3, 4, 5}, source_opts);
        } else {
          source = torch::randint(100, {3, 4, 5}, source_opts);
        }
        // Clone before the eager update so the device copy starts from the
        // unmodified values.
        torch::Tensor lazy_source = CopyToDevice(source.clone(), device);
        torch::Tensor eager_out =
            source.index_fill_(dim, positions, fill_tensor);
        const torch::Tensor lazy_positions = CopyToDevice(positions, device);
        const torch::Tensor lazy_fill = CopyToDevice(fill_tensor, device);
        torch::Tensor lazy_out =
            lazy_source.index_fill_(dim, lazy_positions, lazy_fill);
        AllEqual(eager_out, lazy_out);
        AllEqual(source, lazy_source);
      });
    }
  }
}
5820 
// Out-of-place torch::index_fill where the index is produced by
// torch::scalar_tensor(2) rather than a list of positions, and the fill value
// is torch::scalar_tensor(42) in the dtype under test.
TEST_F(LazyOpsTest, TestIndexFillRank0) {
  const torch::Tensor position = torch::scalar_tensor(
      2, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  for (const torch::ScalarType dtype : kDtypes) {
    const torch::TensorOptions source_opts =
        torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor source;
    if (isFloatingType(dtype)) {
      source = torch::rand({3, 4, 5}, source_opts);
    } else {
      source = torch::randint(100, {3, 4, 5}, source_opts);
    }
    const torch::Tensor fill_tensor = torch::scalar_tensor(
        42, torch::TensorOptions(dtype).device(DefaultDevice()));
    const int rank = source.dim();
    // Covers negative as well as non-negative dim values.
    for (int dim = -rank; dim < rank; ++dim) {
      torch::Tensor expected =
          torch::index_fill(source, dim, position, fill_tensor);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_source = CopyToDevice(source, device);
        const torch::Tensor lazy_position = CopyToDevice(position, device);
        const torch::Tensor lazy_fill = CopyToDevice(fill_tensor, device);
        torch::Tensor lazy_out =
            torch::index_fill(lazy_source, dim, lazy_position, lazy_fill);
        AllEqual(expected, lazy_out);
      });
    }
  }
}
5855 
// Out-of-place torch::index_add over every dtype, every dim in
// [-rank, rank), and both kInt and kLong index dtypes. The source tensor has
// the base's shape with the indexed dimension replaced by the index length.
// Results are compared with AllClose.
TEST_F(LazyOpsTest, TestIndexAdd) {
  const int index_count = 10;
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  const torch::ScalarType kIndexDtypes[] = {torch::kInt, torch::kLong};
  for (const torch::ScalarType dtype : kDtypes) {
    const torch::TensorOptions base_opts =
        torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor base;
    if (isFloatingType(dtype)) {
      base = torch::rand({5, 3, 7}, base_opts);
    } else {
      base = torch::randint(100, {5, 3, 7}, base_opts);
    }
    const int rank = base.dim();
    for (int dim = -rank; dim < rank; ++dim) {
      for (const torch::ScalarType index_dtype : kIndexDtypes) {
        // Indices are drawn from [0, base.size(dim)).
        const torch::Tensor positions = torch::randint(
            0,
            base.size(dim),
            {index_count},
            torch::TensorOptions(index_dtype).device(DefaultDevice()));
        std::vector<int64_t> addend_shape(
            base.sizes().begin(), base.sizes().end());
        // Map a negative dim onto its non-negative equivalent before using
        // it as a vector subscript.
        int axis = dim;
        if (axis < 0) {
          axis += rank;
        }
        addend_shape[axis] = index_count;
        const torch::TensorOptions addend_opts =
            torch::TensorOptions(dtype).device(DefaultDevice());
        torch::Tensor addend;
        if (isFloatingType(dtype)) {
          addend = torch::rand(addend_shape, addend_opts);
        } else {
          addend = torch::randint(100, addend_shape, addend_opts);
        }
        torch::Tensor expected = torch::index_add(base, dim, positions, addend);
        ForEachDevice([&](const torch::Device& device) {
          const torch::Tensor lazy_base = CopyToDevice(base, device);
          const torch::Tensor lazy_positions = CopyToDevice(positions, device);
          const torch::Tensor lazy_addend = CopyToDevice(addend, device);
          torch::Tensor lazy_out =
              torch::index_add(lazy_base, dim, lazy_positions, lazy_addend);
          AllClose(expected, lazy_out);
        });
      }
    }
  }
}
5906 
// In-place Tensor::index_add_ over every dtype and every dim in
// [-rank, rank), using kLong indices. Both the returned tensors and the
// mutated inputs are compared with AllClose.
TEST_F(LazyOpsTest, TestIndexAddInPlace) {
  const int index_count = 10;
  // Hard-coded to match the three-dimensional shape {5, 3, 7} created inside
  // the callback below.
  const int rank = 3;
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  for (const torch::ScalarType dtype : kDtypes) {
    for (int dim = -rank; dim < rank; ++dim) {
      ForEachDevice([&](const torch::Device& device) {
        const torch::TensorOptions base_opts =
            torch::TensorOptions(dtype).device(DefaultDevice());
        torch::Tensor base;
        if (isFloatingType(dtype)) {
          base = torch::rand({5, 3, 7}, base_opts);
        } else {
          base = torch::randint(100, {5, 3, 7}, base_opts);
        }
        // Indices are drawn from [0, base.size(dim)).
        const torch::Tensor positions = torch::randint(
            0,
            base.size(dim),
            {index_count},
            torch::TensorOptions(torch::kLong).device(DefaultDevice()));
        std::vector<int64_t> addend_shape(
            base.sizes().begin(), base.sizes().end());
        // Map a negative dim onto its non-negative equivalent before using
        // it as a vector subscript.
        int axis = dim;
        if (axis < 0) {
          axis += rank;
        }
        addend_shape[axis] = index_count;
        const torch::TensorOptions addend_opts =
            torch::TensorOptions(dtype).device(DefaultDevice());
        torch::Tensor addend;
        if (isFloatingType(dtype)) {
          addend = torch::rand(addend_shape, addend_opts);
        } else {
          addend = torch::randint(100, addend_shape, addend_opts);
        }
        // Clone before the eager update so the device copy starts from the
        // unmodified values.
        torch::Tensor lazy_base = CopyToDevice(base.clone(), device);
        torch::Tensor eager_out = base.index_add_(dim, positions, addend);
        const torch::Tensor lazy_positions = CopyToDevice(positions, device);
        const torch::Tensor lazy_addend = CopyToDevice(addend, device);
        torch::Tensor lazy_out =
            lazy_base.index_add_(dim, lazy_positions, lazy_addend);
        AllClose(eager_out, lazy_out);
        AllClose(base, lazy_base);
      });
    }
  }
}
5956 
// Out-of-place torch::index_add where the index tensor is created with an
// empty size list (at::IntArrayRef{}) and the source tensor has size 1 along
// the indexed dimension. Results are compared with AllEqual.
TEST_F(LazyOpsTest, TestIndexAddRank0) {
  const torch::ScalarType kDtypes[] = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  for (const torch::ScalarType dtype : kDtypes) {
    const torch::TensorOptions base_opts =
        torch::TensorOptions(dtype).device(DefaultDevice());
    torch::Tensor base;
    if (isFloatingType(dtype)) {
      base = torch::rand({5, 3, 7}, base_opts);
    } else {
      base = torch::randint(100, {5, 3, 7}, base_opts);
    }
    const int rank = base.dim();
    for (int dim = -rank; dim < rank; ++dim) {
      // The index value is drawn from [0, base.size(dim)).
      const torch::Tensor position = torch::randint(
          0,
          base.size(dim),
          at::IntArrayRef{},
          torch::TensorOptions(torch::kLong).device(DefaultDevice()));
      std::vector<int64_t> addend_shape(
          base.sizes().begin(), base.sizes().end());
      // Map a negative dim onto its non-negative equivalent before using it
      // as a vector subscript.
      int axis = dim;
      if (axis < 0) {
        axis += rank;
      }
      addend_shape[axis] = 1;
      const torch::TensorOptions addend_opts =
          torch::TensorOptions(dtype).device(DefaultDevice());
      torch::Tensor addend;
      if (isFloatingType(dtype)) {
        addend = torch::rand(addend_shape, addend_opts);
      } else {
        addend = torch::randint(100, addend_shape, addend_opts);
      }
      torch::Tensor expected = torch::index_add(base, dim, position, addend);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor lazy_base = CopyToDevice(base, device);
        const torch::Tensor lazy_position = CopyToDevice(position, device);
        const torch::Tensor lazy_addend = CopyToDevice(addend, device);
        torch::Tensor lazy_out =
            torch::index_add(lazy_base, dim, lazy_position, lazy_addend);
        AllEqual(expected, lazy_out);
      });
    }
  }
}
6004 
// Exercises torch::index_copy with a randperm index covering the whole indexed
// dimension. For each listed dtype and each dim in [-rank, rank), the result on
// every device visited by ForEachDevice must equal the DefaultDevice() result.
TEST_F(LazyOpsTest, TestIndexCopy) {
  for (torch::ScalarType dtype :
       {torch::kFloat,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const auto data_options =
        torch::TensorOptions(dtype).device(DefaultDevice());
    // Floating dtypes are filled by rand, the others by randint(100, ...).
    const auto random_tensor = [&](at::IntArrayRef sizes) -> torch::Tensor {
      if (isFloatingType(dtype)) {
        return torch::rand(sizes, data_options);
      }
      return torch::randint(100, sizes, data_options);
    };
    const torch::Tensor base = random_tensor({5, 3, 7});
    const int rank = base.dim();
    for (int dim = -rank; dim < rank; ++dim) {
      const torch::Tensor index = torch::randperm(
          base.size(dim),
          torch::TensorOptions(torch::kLong).device(DefaultDevice()));
      // The source tensor has exactly the shape of base.
      const torch::Tensor value = random_tensor(base.sizes());
      const torch::Tensor expected =
          torch::index_copy(base, dim, index, value);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor device_base = CopyToDevice(base, device);
        const torch::Tensor device_index = CopyToDevice(index, device);
        const torch::Tensor device_value = CopyToDevice(value, device);
        const torch::Tensor actual =
            torch::index_copy(device_base, dim, device_index, device_value);
        AllEqual(expected, actual);
      });
    }
  }
}
6046 
// Exercises Tensor::index_copy_ (in-place) with a 10-entry random index.
// Skipped when IsCuda() is true.
// NOTE(review): the index has more entries than any base dimension, so it
// necessarily repeats values; presumably that is why CUDA is skipped - confirm.
TEST_F(LazyOpsTest, TestIndexCopyInPlace) {
  if (IsCuda()) {
    GTEST_SKIP();
  }
  const int index_size = 10;
  const int rank = 3;
  for (torch::ScalarType dtype :
       {torch::kFloat,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const auto data_options =
        torch::TensorOptions(dtype).device(DefaultDevice());
    // Floating dtypes are filled by rand, the others by randint(100, ...).
    const auto random_tensor = [&](at::IntArrayRef sizes) -> torch::Tensor {
      if (isFloatingType(dtype)) {
        return torch::rand(sizes, data_options);
      }
      return torch::randint(100, sizes, data_options);
    };
    for (int dim = -rank; dim < rank; ++dim) {
      ForEachDevice([&](const torch::Device& device) {
        // Inputs are regenerated per device because base is mutated below.
        torch::Tensor base = random_tensor({5, 3, 7});
        const torch::Tensor index = torch::randint(
            0,
            base.size(dim),
            {index_size},
            torch::TensorOptions(torch::kLong).device(DefaultDevice()));
        // value matches base except for extent index_size along dim.
        std::vector<int64_t> value_sizes = base.sizes().vec();
        const int canonical_dim = dim < 0 ? dim + rank : dim;
        value_sizes[canonical_dim] = index_size;
        const torch::Tensor value = random_tensor(value_sizes);
        // Snapshot base onto the device before the reference copy mutates it.
        torch::Tensor device_base = CopyToDevice(base.clone(), device);
        const torch::Tensor expected = base.index_copy_(dim, index, value);
        const torch::Tensor device_index = CopyToDevice(index, device);
        const torch::Tensor device_value = CopyToDevice(value, device);
        const torch::Tensor actual =
            device_base.index_copy_(dim, device_index, device_value);
        AllEqual(expected, actual);
        AllEqual(base, device_base);
      });
    }
  }
}
6099 
// Exercises torch::index_copy with a rank-0 (zero-dimensional) index tensor.
// For each listed dtype and each dim in [-rank, rank) of a {5, 3, 7} base, the
// result on every device visited by ForEachDevice must equal the result
// computed on DefaultDevice().
TEST_F(LazyOpsTest, TestIndexCopyRank0) {
  for (torch::ScalarType dtype :
       {torch::kFloat,
        torch::kByte,
        torch::kChar,
        torch::kShort,
        torch::kInt,
        torch::kLong}) {
    const auto data_options =
        torch::TensorOptions(dtype).device(DefaultDevice());
    // Floating dtypes are filled by rand, the others by randint(100, ...).
    const auto random_tensor = [&](at::IntArrayRef sizes) -> torch::Tensor {
      if (isFloatingType(dtype)) {
        return torch::rand(sizes, data_options);
      }
      return torch::randint(100, sizes, data_options);
    };
    const torch::Tensor base = random_tensor({5, 3, 7});
    const int rank = base.dim();
    for (int dim = -rank; dim < rank; ++dim) {
      // An empty size list makes the index a single scalar entry.
      const torch::Tensor index = torch::randint(
          0,
          base.size(dim),
          at::IntArrayRef{},
          torch::TensorOptions(torch::kLong).device(DefaultDevice()));
      // value has base's shape except for extent 1 along the indexed dim.
      std::vector<int64_t> value_sizes = base.sizes().vec();
      const int canonical_dim = dim < 0 ? dim + rank : dim;
      value_sizes[canonical_dim] = 1;
      const torch::Tensor value = random_tensor(value_sizes);
      const torch::Tensor expected =
          torch::index_copy(base, dim, index, value);
      ForEachDevice([&](const torch::Device& device) {
        const torch::Tensor device_base = CopyToDevice(base, device);
        const torch::Tensor device_index = CopyToDevice(index, device);
        const torch::Tensor device_value = CopyToDevice(value, device);
        const torch::Tensor actual =
            torch::index_copy(device_base, dim, device_index, device_value);
        AllEqual(expected, actual);
      });
    }
  }
}
6147 
// Checks torch::relu on every device visited by ForEachDevice against the
// result computed on DefaultDevice().
TEST_F(LazyOpsTest, TestRelu) {
  // randn (not rand) so the input contains negative entries: rand only yields
  // non-negative values, for which relu is the identity and the clamp to zero
  // would never be exercised.
  torch::Tensor input = torch::randn(
      {2, 1, 4, 6},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor output = torch::relu(input);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor lazy_output = torch::relu(lazy_input);
    AllClose(output, lazy_output);
  });
}
6159 
// Checks in-place torch::relu_ on every device visited by ForEachDevice: both
// the returned tensor and the mutated input must match the reference.
TEST_F(LazyOpsTest, TestReluInPlace) {
  ForEachDevice([&](const torch::Device& device) {
    // The input is created per device because relu_ mutates it; a tensor shared
    // across iterations would already be rectified on later devices.
    // randn (not rand) so negative entries exist and the clamp is exercised.
    torch::Tensor input = torch::randn(
        {2, 1, 4, 6},
        torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
    // Copy to the device before the reference input is modified in place.
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor output = torch::relu_(input);
    torch::Tensor lazy_output = torch::relu_(lazy_input);
    AllClose(output, lazy_output);
    AllClose(input, lazy_input);
  });
}
6172 
// Checks torch::hardshrink (default arguments) on every device visited by
// ForEachDevice against the result computed on DefaultDevice().
TEST_F(LazyOpsTest, TestHardshrink) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor reference_input = torch::randn({10}, options);
  const torch::Tensor expected = torch::hardshrink(reference_input);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_input = CopyToDevice(reference_input, device);
    AllClose(expected, torch::hardshrink(device_input));
  };
  ForEachDevice(check_on_device);
}
6183 
// Checks torch::hardsigmoid on every device visited by ForEachDevice against
// the result computed on DefaultDevice().
TEST_F(LazyOpsTest, TestHardSigmoid) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor reference_input = torch::randn({10}, options);
  const torch::Tensor expected = torch::hardsigmoid(reference_input);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_input = CopyToDevice(reference_input, device);
    AllClose(expected, torch::hardsigmoid(device_input));
  };
  ForEachDevice(check_on_device);
}
6194 
// Checks in-place torch::hardsigmoid_ on every device visited by
// ForEachDevice: the mutated input and the returned tensor must both match.
TEST_F(LazyOpsTest, TestHardSigmoidInPlace) {
  const auto check_on_device = [&](const torch::Device& device) {
    const auto options =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    torch::Tensor reference = torch::randn({10}, options);
    // Copy to the device before the reference is modified in place.
    torch::Tensor on_device = CopyToDevice(reference, device);
    torch::Tensor reference_result = torch::hardsigmoid_(reference);
    torch::Tensor device_result = torch::hardsigmoid_(on_device);
    AllClose(reference, on_device);
    AllClose(reference_result, device_result);
  };
  ForEachDevice(check_on_device);
}
6206 
// Runs TestBackward for torch::hardsigmoid on a 10-element randn input that
// requires grad, once per device visited by ForEachDevice.
TEST_F(LazyOpsTest, TestHardSigmoidBackward) {
  const auto forward =
      [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
    return torch::hardsigmoid(inputs.front());
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto options = torch::TensorOptions(torch::kFloat)
                             .device(DefaultDevice())
                             .requires_grad(true);
    const std::vector<torch::Tensor> inputs = {torch::randn({10}, options)};
    TestBackward(inputs, device, forward);
  });
}
6222 
// Checks torch::softshrink (default arguments) on every device visited by
// ForEachDevice against the result computed on DefaultDevice().
TEST_F(LazyOpsTest, TestSoftshrink) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor reference_input = torch::randn({10}, options);
  const torch::Tensor expected = torch::softshrink(reference_input);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_input = CopyToDevice(reference_input, device);
    AllClose(expected, torch::softshrink(device_input));
  };
  ForEachDevice(check_on_device);
}
6233 
// Checks torch::hardtanh (default arguments) on every device visited by
// ForEachDevice against the result computed on DefaultDevice().
TEST_F(LazyOpsTest, TestHardtanh) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor reference_input = torch::randn({10}, options);
  const torch::Tensor expected = torch::hardtanh(reference_input);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_input = CopyToDevice(reference_input, device);
    AllClose(expected, torch::hardtanh(device_input));
  };
  ForEachDevice(check_on_device);
}
6244 
// Checks in-place torch::hardtanh_ on every device visited by ForEachDevice:
// both the returned tensor and the mutated input must match the reference.
TEST_F(LazyOpsTest, TestHardtanhInPlace) {
  ForEachDevice([&](const torch::Device& device) {
    // The input is created per device (as TestHardSigmoidInPlace does) because
    // hardtanh_ mutates it; a tensor shared across iterations would already be
    // clamped on later devices, turning the op into a no-op there.
    torch::Tensor input = torch::randn(
        {10}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
    // Copy to the device before the reference input is modified in place.
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor output = torch::hardtanh_(input);
    torch::Tensor lazy_output = torch::hardtanh_(lazy_input);
    AllClose(output, lazy_output);
    AllClose(input, lazy_input);
  });
}
6256 
// Checks torch::leaky_relu with negative_slope = 0.01 on every device visited
// by ForEachDevice against the result computed on DefaultDevice().
TEST_F(LazyOpsTest, TestLeakyRelu) {
  // randn (not rand) so the input contains negative entries: rand only yields
  // non-negative values, for which negative_slope is never applied.
  torch::Tensor input = torch::randn(
      {2, 1, 4, 6},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  double negative_slope = 0.01;
  torch::Tensor output = torch::leaky_relu(input, negative_slope);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor lazy_output = torch::leaky_relu(lazy_input, negative_slope);
    AllClose(output, lazy_output);
  });
}
6269 
// Checks in-place torch::leaky_relu_ with negative_slope = 0.01 on every
// device visited by ForEachDevice: both the returned tensor and the mutated
// input must match the reference.
TEST_F(LazyOpsTest, TestLeakyReluInPlace) {
  double negative_slope = 0.01;
  ForEachDevice([&](const torch::Device& device) {
    // The input is created per device because leaky_relu_ mutates it, so a
    // shared tensor would carry one device's result into the next iteration.
    // randn (not rand) so negative entries exist and negative_slope is used.
    torch::Tensor input = torch::randn(
        {2, 1, 4, 6},
        torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
    // Copy to the device before the reference input is modified in place.
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor output = torch::leaky_relu_(input, negative_slope);
    torch::Tensor lazy_output = torch::leaky_relu_(lazy_input, negative_slope);
    AllClose(output, lazy_output);
    AllClose(input, lazy_input);
  });
}
6283 
// Checks torch::exp on every device visited by ForEachDevice against the
// DefaultDevice() result, within rtol 1e-3 / atol 1e-5.
TEST_F(LazyOpsTest, TestExp) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({2, 2}, options);
  const torch::Tensor expected = torch::exp(input);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_input = CopyToDevice(input, device);
    AllClose(
        expected, torch::exp(device_input), /*rtol=*/1e-3, /*atol=*/1e-5);
  };
  ForEachDevice(check_on_device);
}
6294 
// Checks torch::expm1 on every device visited by ForEachDevice against the
// DefaultDevice() result, within rtol 1e-3 / atol 1e-5.
TEST_F(LazyOpsTest, TestExpm1) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({2, 2}, options);
  const torch::Tensor expected = torch::expm1(input);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_input = CopyToDevice(input, device);
    AllClose(
        expected, torch::expm1(device_input), /*rtol=*/1e-3, /*atol=*/1e-5);
  };
  ForEachDevice(check_on_device);
}
6305 
// Checks torch::log on every device visited by ForEachDevice against the
// DefaultDevice() result, within rtol 1e-3 / atol 1e-5.
TEST_F(LazyOpsTest, TestLog) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({2, 2}, options);
  const torch::Tensor expected = torch::log(input);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_input = CopyToDevice(input, device);
    AllClose(
        expected, torch::log(device_input), /*rtol=*/1e-3, /*atol=*/1e-5);
  };
  ForEachDevice(check_on_device);
}
6316 
// Checks torch::log2 on every device visited by ForEachDevice against the
// DefaultDevice() result, within rtol 1e-3 / atol 1e-5.
TEST_F(LazyOpsTest, TestLog2) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({2, 2}, options);
  const torch::Tensor expected = torch::log2(input);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_input = CopyToDevice(input, device);
    AllClose(
        expected, torch::log2(device_input), /*rtol=*/1e-3, /*atol=*/1e-5);
  };
  ForEachDevice(check_on_device);
}
6327 
// Checks torch::log10 on every device visited by ForEachDevice against the
// DefaultDevice() result, within rtol 1e-3 / atol 1e-5.
TEST_F(LazyOpsTest, TestLog10) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({2, 2}, options);
  const torch::Tensor expected = torch::log10(input);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_input = CopyToDevice(input, device);
    AllClose(
        expected, torch::log10(device_input), /*rtol=*/1e-3, /*atol=*/1e-5);
  };
  ForEachDevice(check_on_device);
}
6338 
// Checks torch::log1p on every device visited by ForEachDevice against the
// DefaultDevice() result, within rtol 1e-3 / atol 1e-5.
TEST_F(LazyOpsTest, TestLog1p) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({2, 2}, options);
  const torch::Tensor expected = torch::log1p(input);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_input = CopyToDevice(input, device);
    AllClose(
        expected, torch::log1p(device_input), /*rtol=*/1e-3, /*atol=*/1e-5);
  };
  ForEachDevice(check_on_device);
}
6349 
// Checks torch::erf on every device visited by ForEachDevice against the
// DefaultDevice() result, within rtol 1e-3 / atol 1e-5.
TEST_F(LazyOpsTest, TestErf) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::randn({2, 2}, options);
  const torch::Tensor expected = torch::erf(input);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_input = CopyToDevice(input, device);
    AllClose(
        expected, torch::erf(device_input), /*rtol=*/1e-3, /*atol=*/1e-5);
  };
  ForEachDevice(check_on_device);
}
6360 
// Checks torch::erfc on every device visited by ForEachDevice against the
// DefaultDevice() result, within rtol 1e-3 / atol 1e-5.
TEST_F(LazyOpsTest, TestErfc) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::randn({2, 2}, options);
  const torch::Tensor expected = torch::erfc(input);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_input = CopyToDevice(input, device);
    AllClose(
        expected, torch::erfc(device_input), /*rtol=*/1e-3, /*atol=*/1e-5);
  };
  ForEachDevice(check_on_device);
}
6371 
// Checks torch::erfinv on every device visited by ForEachDevice against the
// DefaultDevice() result, within rtol 1e-3 / atol 1e-5.
TEST_F(LazyOpsTest, TestErfinv) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::rand({2, 2}, options);
  const torch::Tensor expected = torch::erfinv(input);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_input = CopyToDevice(input, device);
    AllClose(
        expected, torch::erfinv(device_input), /*rtol=*/1e-3, /*atol=*/1e-5);
  };
  ForEachDevice(check_on_device);
}
6382 
// Checks torch::sqrt on every device visited by ForEachDevice against the
// DefaultDevice() result, within rtol 1e-3 / atol 1e-5. The input is
// abs(rand(...)).
TEST_F(LazyOpsTest, TestSqrt) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::abs(torch::rand({2, 2}, options));
  const torch::Tensor expected = torch::sqrt(input);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_input = CopyToDevice(input, device);
    AllClose(
        expected, torch::sqrt(device_input), /*rtol=*/1e-3, /*atol=*/1e-5);
  };
  ForEachDevice(check_on_device);
}
6393 
// Checks torch::rsqrt on every device visited by ForEachDevice against the
// DefaultDevice() result, within rtol 1e-3 / atol 1e-5. The input is
// abs(rand(...)).
TEST_F(LazyOpsTest, TestRsqrt) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::abs(torch::rand({2, 2}, options));
  const torch::Tensor expected = torch::rsqrt(input);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_input = CopyToDevice(input, device);
    AllClose(
        expected, torch::rsqrt(device_input), /*rtol=*/1e-3, /*atol=*/1e-5);
  };
  ForEachDevice(check_on_device);
}
6404 
// Checks torch::reciprocal on every device visited by ForEachDevice against
// the DefaultDevice() result, within rtol 1e-3 / atol 1e-5.
TEST_F(LazyOpsTest, TestReciprocal) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor input = torch::randn({2, 2}, options);
  const torch::Tensor expected = torch::reciprocal(input);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_input = CopyToDevice(input, device);
    AllClose(
        expected,
        torch::reciprocal(device_input),
        /*rtol=*/1e-3,
        /*atol=*/1e-5);
  };
  ForEachDevice(check_on_device);
}
6415 
// Checks torch::pow(Tensor, Scalar) with exponent 4.09 on every device visited
// by ForEachDevice against the DefaultDevice() result (rtol 1e-3, atol 1e-5).
TEST_F(LazyOpsTest, TestPowTensorScalar) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor reference_base = torch::rand({2, 2}, options);
  const torch::Scalar power = 4.09;
  const torch::Tensor expected = torch::pow(reference_base, power);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_base = CopyToDevice(reference_base, device);
    AllClose(
        expected,
        torch::pow(device_base, power),
        /*rtol=*/1e-3,
        /*atol=*/1e-5);
  };
  ForEachDevice(check_on_device);
}
6427 
// Checks in-place Tensor::pow_(Scalar) with exponent 4.09 on every device
// visited by ForEachDevice: both the returned tensor and the mutated base must
// match the reference (rtol 1e-3, atol 1e-5).
TEST_F(LazyOpsTest, TestPowTensorScalarInPlace) {
  torch::Scalar exponent = 4.09;
  ForEachDevice([&](const torch::Device& device) {
    // base is created per device because pow_ mutates it. A tensor shared
    // across iterations would be raised to the exponent again on every device,
    // driving the values toward zero where atol makes the check near-vacuous.
    torch::Tensor base = torch::rand(
        {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
    // Copy to the device before the reference base is modified in place.
    torch::Tensor lazy_base = CopyToDevice(base.clone(), device);
    torch::Tensor result = base.pow_(exponent);
    torch::Tensor lazy_result = lazy_base.pow_(exponent);
    AllClose(result, lazy_result, /*rtol=*/1e-3, /*atol=*/1e-5);
    AllClose(base, lazy_base, /*rtol=*/1e-3, /*atol=*/1e-5);
  });
}
6440 
// Checks torch::pow(Tensor, Tensor) on same-shaped {4, 2} operands on every
// device visited by ForEachDevice against the DefaultDevice() result
// (rtol 1e-3, atol 1e-5).
TEST_F(LazyOpsTest, TestPowTensorTensor) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor reference_base = torch::abs(torch::rand({4, 2}, options));
  const torch::Tensor reference_exp = torch::rand({4, 2}, options);
  const torch::Tensor expected = torch::pow(reference_base, reference_exp);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_base = CopyToDevice(reference_base, device);
    const torch::Tensor device_exp = CopyToDevice(reference_exp, device);
    const torch::Tensor actual = torch::pow(device_base, device_exp);
    AllClose(expected, actual, /*rtol=*/1e-3, /*atol=*/1e-5);
  };
  ForEachDevice(check_on_device);
}
6454 
// Checks in-place Tensor::pow_(Tensor) on every device visited by
// ForEachDevice: both the returned tensor and the mutated base must match the
// reference (rtol 1e-3, atol 1e-5).
TEST_F(LazyOpsTest, TestPowTensorTensorInPlace) {
  // The exponent is never mutated, so it can be shared by all devices.
  torch::Tensor exponent = torch::rand(
      {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  ForEachDevice([&](const torch::Device& device) {
    // base is created per device because pow_ mutates it; sharing it would
    // feed one device's result into the next device's input.
    torch::Tensor base = torch::abs(torch::rand(
        {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())));
    // Copy to the device before the reference base is modified in place.
    torch::Tensor lazy_base = CopyToDevice(base.clone(), device);
    torch::Tensor result = base.pow_(exponent);
    torch::Tensor lazy_exponent = CopyToDevice(exponent, device);
    torch::Tensor lazy_result = lazy_base.pow_(lazy_exponent);
    AllClose(result, lazy_result, /*rtol=*/1e-3, /*atol=*/1e-5);
    AllClose(base, lazy_base, /*rtol=*/1e-3, /*atol=*/1e-5);
  });
}
6469 
// Checks torch::pow(Tensor, Tensor) with a {4, 2} base and a {4, 1} exponent
// on every device visited by ForEachDevice against the DefaultDevice() result
// (rtol 1e-3, atol 1e-5).
TEST_F(LazyOpsTest, TestPowTensorTensorBroadcast) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor reference_base = torch::abs(torch::rand({4, 2}, options));
  const torch::Tensor reference_exp = torch::rand({4, 1}, options);
  const torch::Tensor expected = torch::pow(reference_base, reference_exp);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_base = CopyToDevice(reference_base, device);
    const torch::Tensor device_exp = CopyToDevice(reference_exp, device);
    const torch::Tensor actual = torch::pow(device_base, device_exp);
    AllClose(expected, actual, /*rtol=*/1e-3, /*atol=*/1e-5);
  };
  ForEachDevice(check_on_device);
}
6483 
// Checks torch::pow(Scalar, Tensor) with base 3.5 on every device visited by
// ForEachDevice against the DefaultDevice() result (rtol 1e-3, atol 1e-5).
TEST_F(LazyOpsTest, TestPowScalarTensor) {
  torch::Scalar base = 3.5;
  // Build the exponent with explicit float options on DefaultDevice(), as the
  // sibling pow tests do, so the reference is not tied to the global default
  // tensor options.
  torch::Tensor exponent = torch::rand(
      {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor result = torch::pow(base, exponent);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_exponent = CopyToDevice(exponent, device);
    torch::Tensor lazy_result = torch::pow(base, lazy_exponent);
    AllClose(result, lazy_result, /*rtol=*/1e-3, /*atol=*/1e-5);
  });
}
6494 
// Checks torch::pow(Tensor, Scalar) with the integer exponent 3 on every
// device visited by ForEachDevice against the DefaultDevice() result
// (rtol 1e-3, atol 1e-5).
TEST_F(LazyOpsTest, TestPowIntExponent) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor reference_base = torch::abs(torch::rand({4, 2}, options));
  const torch::Scalar power = 3;
  const torch::Tensor expected = torch::pow(reference_base, power);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_base = CopyToDevice(reference_base, device);
    AllClose(
        expected,
        torch::pow(device_base, power),
        /*rtol=*/1e-3,
        /*atol=*/1e-5);
  };
  ForEachDevice(check_on_device);
}
6506 
// Checks torch::fmod(Tensor, Scalar) with divisor 2.0 on every device visited
// by ForEachDevice against the DefaultDevice() result. The dividend is
// rand(...) scaled by 100.
TEST_F(LazyOpsTest, TestFmodScalar) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor dividend = torch::rand({2, 2}, options) * 100.0;
  const torch::Scalar divisor = 2.0;
  const torch::Tensor expected = torch::fmod(dividend, divisor);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_dividend = CopyToDevice(dividend, device);
    AllClose(expected, torch::fmod(device_dividend, divisor));
  };
  ForEachDevice(check_on_device);
}
6520 
// Checks in-place Tensor::fmod_(Scalar) with divisor 2.0 on every device
// visited by ForEachDevice: both the returned tensor and the mutated dividend
// must match the reference.
TEST_F(LazyOpsTest, TestFmodScalarInPlace) {
  const torch::Scalar divisor = 2.0;
  const auto check_on_device = [&](const torch::Device& device) {
    const auto options =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    torch::Tensor dividend = torch::rand({2, 2}, options) * 100.0;
    // Copy to the device before the reference dividend is modified in place.
    torch::Tensor device_dividend = CopyToDevice(dividend, device);
    const torch::Tensor expected = dividend.fmod_(divisor);
    const torch::Tensor actual = device_dividend.fmod_(divisor);
    AllClose(expected, actual);
    AllClose(dividend, device_dividend);
  };
  ForEachDevice(check_on_device);
}
6536 
// Checks torch::fmod(Tensor, Tensor) on every device visited by ForEachDevice
// against the DefaultDevice() result. The dividend is rand(...) * 100 and the
// divisor is rand(...) * 10.
TEST_F(LazyOpsTest, TestFmodTensor) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor dividend = torch::rand({2, 2}, options) * 100.0;
  const torch::Tensor divisor = torch::rand({2, 2}, options) * 10.0;
  const torch::Tensor expected = torch::fmod(dividend, divisor);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_dividend = CopyToDevice(dividend, device);
    const torch::Tensor device_divisor = CopyToDevice(divisor, device);
    AllClose(expected, torch::fmod(device_dividend, device_divisor));
  };
  ForEachDevice(check_on_device);
}
6554 
// Checks in-place Tensor::fmod_(Tensor) on every device visited by
// ForEachDevice: both the returned tensor and the mutated dividend must match
// the reference.
TEST_F(LazyOpsTest, TestFmodTensorInPlace) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor divisor = torch::rand({2, 2}, options) * 10.0;
  const auto check_on_device = [&](const torch::Device& device) {
    torch::Tensor dividend = torch::rand({2, 2}, options) * 100.0;
    // Copy to the device before the reference dividend is modified in place.
    torch::Tensor device_dividend = CopyToDevice(dividend, device);
    const torch::Tensor expected = dividend.fmod_(divisor);
    const torch::Tensor device_divisor = CopyToDevice(divisor, device);
    const torch::Tensor actual = device_dividend.fmod_(device_divisor);
    AllClose(expected, actual);
    AllClose(dividend, device_dividend);
  };
  ForEachDevice(check_on_device);
}
6574 
// Checks torch::remainder(Tensor, Scalar) with divisor -2.0 on every device
// visited by ForEachDevice against the DefaultDevice() result. The dividend is
// randn(...) scaled by 100.
TEST_F(LazyOpsTest, TestRemainderScalar) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor dividend = torch::randn({2, 2}, options) * 100.0;
  const torch::Scalar divisor = -2.0;
  const torch::Tensor expected = torch::remainder(dividend, divisor);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_dividend = CopyToDevice(dividend, device);
    AllClose(expected, torch::remainder(device_dividend, divisor));
  };
  ForEachDevice(check_on_device);
}
6588 
// Checks in-place Tensor::remainder_(Scalar) with divisor -2.0 on every device
// visited by ForEachDevice: both the returned tensor and the mutated dividend
// must match the reference.
TEST_F(LazyOpsTest, TestRemainderScalarInPlace) {
  const torch::Scalar divisor = -2.0;
  const auto check_on_device = [&](const torch::Device& device) {
    const auto options =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    torch::Tensor dividend = torch::randn({2, 2}, options) * 100.0;
    // Copy to the device before the reference dividend is modified in place.
    torch::Tensor device_dividend = CopyToDevice(dividend, device);
    const torch::Tensor expected = dividend.remainder_(divisor);
    const torch::Tensor actual = device_dividend.remainder_(divisor);
    AllClose(expected, actual);
    AllClose(dividend, device_dividend);
  };
  ForEachDevice(check_on_device);
}
6604 
// Checks torch::remainder(Tensor, Tensor) on every device visited by
// ForEachDevice against the DefaultDevice() result (rtol 1e-4, atol 1e-6).
// The dividend is randn(...) * 100 and the divisor is randn(...) * 10.
TEST_F(LazyOpsTest, TestRemainderTensor) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor dividend = torch::randn({2, 2}, options) * 100.0;
  const torch::Tensor divisor = torch::randn({2, 2}, options) * 10.0;
  const torch::Tensor expected = torch::remainder(dividend, divisor);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_dividend = CopyToDevice(dividend, device);
    const torch::Tensor device_divisor = CopyToDevice(divisor, device);
    const torch::Tensor actual =
        torch::remainder(device_dividend, device_divisor);
    AllClose(expected, actual, /*rtol=*/1e-4, /*atol=*/1e-6);
  };
  ForEachDevice(check_on_device);
}
6622 
// Checks in-place Tensor::remainder_(Tensor) on every device visited by
// ForEachDevice: both the returned tensor and the mutated dividend must match
// the reference (rtol 1e-4, atol 1e-6).
TEST_F(LazyOpsTest, TestRemainderTensorInPlace) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor divisor = torch::randn({2, 2}, options) * 10.0;
  const auto check_on_device = [&](const torch::Device& device) {
    torch::Tensor dividend = torch::randn({2, 2}, options) * 100.0;
    // Copy to the device before the reference dividend is modified in place.
    torch::Tensor device_dividend = CopyToDevice(dividend, device);
    const torch::Tensor expected = dividend.remainder_(divisor);
    const torch::Tensor device_divisor = CopyToDevice(divisor, device);
    const torch::Tensor actual = device_dividend.remainder_(device_divisor);
    AllClose(expected, actual, /*rtol=*/1e-4, /*atol=*/1e-6);
    AllClose(dividend, device_dividend, /*rtol=*/1e-4, /*atol=*/1e-6);
  };
  ForEachDevice(check_on_device);
}
6642 
// Checks torch::where with a {3, 3} kByte mask that is set on the diagonal and
// clear elsewhere, selecting between two {3, 3} float tensors, on every device
// visited by ForEachDevice against the DefaultDevice() result.
TEST_F(LazyOpsTest, TestWhere) {
  const auto float_options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor when_true = torch::rand({3, 3}, float_options);
  const torch::Tensor when_false = torch::rand({3, 3}, float_options);
  torch::Tensor mask = torch::empty(
      {3, 3}, torch::TensorOptions(torch::kByte).device(DefaultDevice()));
  // Every element is written, so the uninitialized contents never leak through.
  for (int row = 0; row < 3; ++row) {
    for (int col = 0; col < 3; ++col) {
      mask[row][col] = row == col;
    }
  }
  const torch::Tensor expected = torch::where(mask, when_true, when_false);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor device_true = CopyToDevice(when_true, device);
    const torch::Tensor device_false = CopyToDevice(when_false, device);
    const torch::Tensor device_mask = CopyToDevice(mask, device);
    const torch::Tensor actual =
        torch::where(device_mask, device_true, device_false);
    AllClose(expected, actual);
  });
}
6664 
// Checks torch::where when the "false" operand is a zero-dimensional zeros
// tensor and the other operands are {3, 3}. The kByte mask is set on the
// diagonal and clear elsewhere. Results on every device visited by
// ForEachDevice must match the DefaultDevice() result.
TEST_F(LazyOpsTest, TestWhereBroadcast) {
  const auto float_options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor when_true = torch::rand({3, 3}, float_options);
  const torch::Tensor when_false = torch::zeros({}, float_options);
  torch::Tensor mask = torch::empty(
      {3, 3}, torch::TensorOptions(torch::kByte).device(DefaultDevice()));
  // Every element is written, so the uninitialized contents never leak through.
  for (int row = 0; row < 3; ++row) {
    for (int col = 0; col < 3; ++col) {
      mask[row][col] = row == col;
    }
  }
  const torch::Tensor expected = torch::where(mask, when_true, when_false);
  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor device_true = CopyToDevice(when_true, device);
    const torch::Tensor device_false = CopyToDevice(when_false, device);
    const torch::Tensor device_mask = CopyToDevice(mask, device);
    const torch::Tensor actual =
        torch::where(device_mask, device_true, device_false);
    AllClose(expected, actual);
  });
}
6686 
// Checks torch::threshold (threshold 0.4, replacement value 20) on every
// device visited by ForEachDevice against the DefaultDevice() result.
TEST_F(LazyOpsTest, TestThreshold) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor reference_input = torch::rand({2, 1, 4, 6}, options);
  const float cutoff = 0.4;
  const float replacement = 20;
  const torch::Tensor expected =
      torch::threshold(reference_input, cutoff, replacement);
  const auto check_on_device = [&](const torch::Device& device) {
    const torch::Tensor device_input = CopyToDevice(reference_input, device);
    AllClose(expected, torch::threshold(device_input, cutoff, replacement));
  };
  ForEachDevice(check_on_device);
}
6700 
// Runs TestBackward for torch::threshold (threshold 0.4, replacement value 20)
// on a {2, 1, 4, 6} rand input that requires grad, once per device visited by
// ForEachDevice.
TEST_F(LazyOpsTest, TestThresholdBackward) {
  const float cutoff = 0.4;
  const float replacement = 20;

  const auto forward =
      [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
    return torch::threshold(inputs.front(), cutoff, replacement);
  };

  ForEachDevice([&](const torch::Device& device) {
    const auto options = torch::TensorOptions(torch::kFloat)
                             .device(DefaultDevice())
                             .requires_grad(true);
    const std::vector<torch::Tensor> inputs = {
        torch::rand({2, 1, 4, 6}, options)};
    TestBackward(inputs, device, forward);
  });
}
6721 
// Checks in-place torch::threshold_ (threshold 0.4, replacement value 20).
// The reference is computed on a clone, so the pristine input can be copied to
// each device visited by ForEachDevice and thresholded there.
TEST_F(LazyOpsTest, TestThresholdInPlace) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const torch::Tensor pristine_input = torch::rand({2, 1, 4, 6}, options);
  torch::Tensor expected = pristine_input.clone();
  const float cutoff = 0.4;
  const float replacement = 20;
  torch::threshold_(expected, cutoff, replacement);
  const auto check_on_device = [&](const torch::Device& device) {
    torch::Tensor on_device = CopyToDevice(pristine_input, device);
    torch::threshold_(on_device, cutoff, replacement);
    AllClose(expected, on_device);
  };
  ForEachDevice(check_on_device);
}
6736 
// Checks torch::elu with explicit alpha, scale and input_scale on lazy
// tensors against the eager result.
TEST_F(LazyOpsTest, TestElu) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({2, 1, 4, 6}, options);
  const torch::Scalar alpha = 0.5;
  const torch::Scalar scale = 2.5;
  const torch::Scalar input_scale = 1.5;
  torch::Tensor expected = torch::elu(eager_in, alpha, scale, input_scale);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    torch::Tensor actual = torch::elu(lazy_in, alpha, scale, input_scale);
    AllClose(expected, actual);
  });
}
6752 
// Checks the in-place torch::elu_ on lazy tensors: both the returned tensor
// and the mutated operand must match their eager counterparts.
TEST_F(LazyOpsTest, TestEluInPlace) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({2, 1, 4, 6}, options);
  const torch::Scalar alpha = 0.5;
  const torch::Scalar scale = 2.5;
  const torch::Scalar input_scale = 1.5;
  ForEachDevice([&](const torch::Device& device) {
    // The device copy must be taken before the eager tensor is mutated.
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    torch::Tensor expected = torch::elu_(eager_in, alpha, scale, input_scale);
    torch::Tensor actual = torch::elu_(lazy_in, alpha, scale, input_scale);
    AllClose(expected, actual);
    AllClose(eager_in, lazy_in);
  });
}
6769 
// Checks torch::selu on lazy tensors against the eager result.
TEST_F(LazyOpsTest, TestSelu) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({2, 1, 4, 6}, options);
  torch::Tensor expected = torch::selu(eager_in);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    torch::Tensor actual = torch::selu(lazy_in);
    AllClose(expected, actual);
  });
}
6781 
// Checks the in-place torch::selu_ on lazy tensors: the returned tensor and
// the mutated operand must both match the eager ones.
TEST_F(LazyOpsTest, TestSeluInPlace) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({2, 1, 4, 6}, options);
  ForEachDevice([&](const torch::Device& device) {
    // The device copy must be taken before the eager tensor is mutated.
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    torch::Tensor expected = torch::selu_(eager_in);
    torch::Tensor actual = torch::selu_(lazy_in);
    AllClose(expected, actual);
    AllClose(eager_in, lazy_in);
  });
}
6794 
// Checks torch::celu with an explicit alpha on lazy tensors against the eager
// result.
TEST_F(LazyOpsTest, TestCelu) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({2, 1, 4, 6}, options);
  const torch::Scalar alpha = 2.5;
  torch::Tensor expected = torch::celu(eager_in, alpha);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    torch::Tensor actual = torch::celu(lazy_in, alpha);
    AllClose(expected, actual);
  });
}
6807 
// Checks the in-place torch::celu_ on lazy tensors: the returned tensor and
// the mutated operand must both match the eager ones.
TEST_F(LazyOpsTest, TestCeluInPlace) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({2, 1, 4, 6}, options);
  const torch::Scalar alpha = 2.5;
  ForEachDevice([&](const torch::Device& device) {
    // The device copy must be taken before the eager tensor is mutated.
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    torch::Tensor expected = torch::celu_(eager_in, alpha);
    torch::Tensor actual = torch::celu_(lazy_in, alpha);
    AllClose(expected, actual);
    AllClose(eager_in, lazy_in);
  });
}
6821 
// Checks torch::gelu (default arguments) on lazy tensors against the eager
// result.
TEST_F(LazyOpsTest, TestGelu) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({2, 3}, options);
  torch::Tensor expected = torch::gelu(eager_in);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    torch::Tensor actual = torch::gelu(lazy_in);
    AllClose(expected, actual);
  });
}
6832 
// Checks torch::addmm (bias + input @ weight) on lazy tensors for beta = 1
// and beta = 2.
TEST_F(LazyOpsTest, TestAddMatMul) {
  const int in_channels = 32;
  const int out_channels = 320;
  const int labels = 50;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({in_channels, out_channels}, options);
  torch::Tensor eager_weight = torch::rand({out_channels, labels}, options);
  torch::Tensor eager_bias = torch::rand({labels}, options);
  // Test beta != 1. through the CPU interop.
  for (double beta : {1., 2.}) {
    torch::Tensor expected =
        torch::addmm(eager_bias, eager_in, eager_weight, /*beta=*/beta);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_in = CopyToDevice(eager_in, device);
      torch::Tensor lazy_weight = CopyToDevice(eager_weight, device);
      torch::Tensor lazy_bias = CopyToDevice(eager_bias, device);
      torch::Tensor actual =
          torch::addmm(lazy_bias, lazy_in, lazy_weight, /*beta=*/beta);
      AllClose(expected, actual);
    });
  }
}
6858 
// Checks torch::embedding on lazy tensors against the eager result, using a
// 32-row table and a {3, 4} tensor of random row indices.
TEST_F(LazyOpsTest, TestEmbedding) {
  torch::Tensor a = torch::rand(
      {32, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  // torch::randint's upper bound is exclusive, so the bound is the row count
  // itself; a bound of 31 would never produce an index for the last row.
  torch::Tensor i = torch::randint(
      0,
      a.size(0),
      {3, 4},
      torch::TensorOptions(torch::kLong).device(DefaultDevice()));
  // Eager reference lookup.
  torch::Tensor b = torch::embedding(
      a,
      i,
      /*padding_idx=*/0,
      /*scale_grad_by_freq=*/false,
      /*sparse=*/false);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_a = CopyToDevice(a, device);
    torch::Tensor lazy_i = CopyToDevice(i, device);
    torch::Tensor lazy_b = torch::embedding(
        lazy_a,
        lazy_i,
        /*padding_idx=*/0,
        /*scale_grad_by_freq=*/false,
        /*sparse=*/false);
    AllClose(b, lazy_b);
  });
}
6885 
// Checks torch::one_hot on lazy tensors; the result is compared exactly with
// AllEqual.
TEST_F(LazyOpsTest, TestOneHot) {
  const int num_classes = 5;
  const auto options =
      torch::TensorOptions(torch::kLong).device(DefaultDevice());
  // Ten random class labels in [0, num_classes).
  torch::Tensor labels = torch::randint(0, num_classes, {10}, options);
  torch::Tensor expected = torch::one_hot(labels, num_classes);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_labels = CopyToDevice(labels, device);
    torch::Tensor actual = torch::one_hot(lazy_labels, num_classes);
    AllEqual(expected, actual);
  });
}
6900 
// Checks torch::t (2-D transpose) on lazy tensors against the eager result.
TEST_F(LazyOpsTest, TestTranspose) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({2, 3}, options);
  torch::Tensor expected = torch::t(eager_in);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    torch::Tensor actual = torch::t(lazy_in);
    AllClose(expected, actual);
  });
}
6911 
// Checks the in-place transpose t_() on lazy tensors: sizes, returned tensor
// and mutated operand must all match the eager ones.
TEST_F(LazyOpsTest, TestTransposeInPlace) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({2, 3}, options);
  ForEachDevice([&](const torch::Device& device) {
    // The device copy must be taken before the eager tensor is transposed.
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    torch::Tensor expected = eager_in.t_();
    torch::Tensor actual = lazy_in.t_();
    EXPECT_EQ(actual.sizes(), expected.sizes());
    AllClose(expected, actual);
    AllClose(eager_in, lazy_in);
  });
}
6924 
// Checks torch::reshape with an inferred (-1) leading dimension on lazy
// tensors against the eager result.
TEST_F(LazyOpsTest, TestReshape) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({32, 20, 4, 4}, options);
  torch::Tensor expected = torch::reshape(eager_in, {-1, 320});
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    torch::Tensor actual = torch::reshape(lazy_in, {-1, 320});
    AllClose(expected, actual);
  });
}
6936 
// Checks resize_() to a smaller element count ({2, 2, 4} -> {3, 3}) on lazy
// tensors against the eager result.
TEST_F(LazyOpsTest, TestResize) {
  // Growing a tensor with resize_() cannot be compared: we fill the new
  // elements with zeros, while pytorch leaves random garbage there.
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor resized = torch::rand({2, 2, 4}, options);
  // Snapshot of the pre-resize contents, used as the source of device copies.
  torch::Tensor original = resized.clone();
  resized.resize_({3, 3});
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_resized = CopyToDevice(original, device);
    lazy_resized.resize_({3, 3});
    AllClose(resized, lazy_resized);
  });
}
6950 
// Checks resize_() applied to a view: both the base tensor and the resized
// view must match their eager counterparts.
TEST_F(LazyOpsTest, TestViewResize) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor base = torch::zeros({8, 2}, options);
  // Snapshot taken before the view is created, used for the device copies.
  torch::Tensor base_snapshot = base.clone();
  torch::Tensor view = base.view({4, 4});
  view.resize_({3, 3});
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_base = CopyToDevice(base_snapshot, device);
    torch::Tensor lazy_view = lazy_base.view({4, 4});
    lazy_view.resize_({3, 3});
    AllClose(base, lazy_base);
    AllClose(view, lazy_view);
  });
}
6965 
// Checks Tensor::view with an inferred (-1) leading dimension on lazy tensors
// against the eager result.
TEST_F(LazyOpsTest, TestView) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({32, 20, 4, 4}, options);
  torch::Tensor expected = eager_in.view({-1, 320});
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    torch::Tensor actual = lazy_in.view({-1, 320});
    AllClose(expected, actual);
  });
}
6977 
// Checks in-place adds applied first to a view and then to its base: both
// the lazy view and the lazy base must match the eager ones.
TEST_F(LazyOpsTest, TestViewMod) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor base = torch::zeros({32, 20, 4, 4}, options);
  torch::Tensor one = torch::tensor(1.0, options);
  torch::Tensor view = base.view({-1, 320});
  // Eager reference: add through the view, then through the base.
  view.add_(one, 1.0);
  base.add_(one, 1.0);
  ForEachDevice([&](const torch::Device& device) {
    // Start from fresh zeros, since the eager base has already been mutated.
    torch::Tensor fresh = torch::zeros({32, 20, 4, 4}, options);
    torch::Tensor lazy_base = CopyToDevice(fresh, device);
    torch::Tensor lazy_one = CopyToDevice(one, device);
    torch::Tensor lazy_view = lazy_base.view({-1, 320});
    lazy_view.add_(lazy_one, 1.0);
    lazy_base.add_(lazy_one, 1.0);
    AllClose(view, lazy_view);
    AllClose(base, lazy_base);
  });
}
7000 
// Checks in-place adds through two differently shaped views of the same
// base: each lazy view must match its eager counterpart.
TEST_F(LazyOpsTest, TestViewModComplex) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor base = torch::zeros({32, 20, 4, 4}, options);
  torch::Tensor one = torch::tensor(1.0, options);
  // Eager reference: both views are taken from the base.
  torch::Tensor wide_view = base.view({-1, 320});
  wide_view.add_(one, 1.0);
  torch::Tensor narrow_view = base.view({-1, 160});
  narrow_view.add_(one, 1.0);
  ForEachDevice([&](const torch::Device& device) {
    // Start from fresh zeros, since the eager base has already been mutated.
    torch::Tensor fresh = torch::zeros({32, 20, 4, 4}, options);
    torch::Tensor lazy_base = CopyToDevice(fresh, device);
    torch::Tensor lazy_one = CopyToDevice(one, device);
    torch::Tensor lazy_wide_view = lazy_base.view({-1, 320});
    lazy_wide_view.add_(lazy_one, 1.0);
    torch::Tensor lazy_narrow_view = lazy_base.view({-1, 160});
    lazy_narrow_view.add_(lazy_one, 1.0);
    AllClose(wide_view, lazy_wide_view);
    AllClose(narrow_view, lazy_narrow_view);
  });
}
7025 
// Checks in-place adds through a view and through a view of that view: each
// lazy view must match its eager counterpart.
TEST_F(LazyOpsTest, TestViewOfViewMod) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor base = torch::zeros({32, 20, 4, 4}, options);
  torch::Tensor one = torch::tensor(1.0, options);
  // Eager reference: the second view is taken from the first view.
  torch::Tensor outer_view = base.view({-1, 320});
  outer_view.add_(one, 1.0);
  torch::Tensor inner_view = outer_view.view({-1, 160});
  inner_view.add_(one, 1.0);
  ForEachDevice([&](const torch::Device& device) {
    // Start from fresh zeros, since the eager base has already been mutated.
    torch::Tensor fresh = torch::zeros({32, 20, 4, 4}, options);
    torch::Tensor lazy_base = CopyToDevice(fresh, device);
    torch::Tensor lazy_one = CopyToDevice(one, device);
    torch::Tensor lazy_outer_view = lazy_base.view({-1, 320});
    lazy_outer_view.add_(lazy_one, 1.0);
    torch::Tensor lazy_inner_view = lazy_outer_view.view({-1, 160});
    lazy_inner_view.add_(lazy_one, 1.0);
    AllClose(outer_view, lazy_outer_view);
    AllClose(inner_view, lazy_inner_view);
  });
}
7050 
// Checks view -> squeeze_ -> add_ on lazy tensors: both the resulting view
// and the base tensor must match the eager ones.
TEST_F(LazyOpsTest, TestViewSqueezeAddInPlace) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor base = torch::zeros({2, 3, 1}, options);
  const std::vector<int64_t> view_size = {2, 3, 1, 1};
  const int squeeze_dim = 2;
  torch::Tensor one = torch::tensor(1.0, options);
  ForEachDevice([&](const torch::Device& device) {
    // The device copy must be taken before the eager base is modified through
    // its view.
    torch::Tensor lazy_base = CopyToDevice(base, device);
    torch::Tensor view = base.view(view_size);
    view.squeeze_(squeeze_dim);
    view.add_(one, 1.0);
    torch::Tensor lazy_one = CopyToDevice(one, device);
    torch::Tensor lazy_view = lazy_base.view(view_size);
    lazy_view.squeeze_(squeeze_dim);
    lazy_view.add_(lazy_one, 1.0);
    AllClose(view, lazy_view);
    AllClose(base, lazy_base);
  });
}
7071 
// Checks torch::_unsafe_view on lazy tensors against the eager result.
TEST_F(LazyOpsTest, TestUnsafeView) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({32, 20, 4, 4}, options);
  torch::Tensor expected = torch::_unsafe_view(eager_in, {-1, 320});
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    torch::Tensor actual = torch::_unsafe_view(lazy_in, {-1, 320});
    AllClose(expected, actual);
  });
}
7083 
// Checks Tensor::narrow on lazy tensors for positive and negative values of
// both the dimension and the start offset.
TEST_F(LazyOpsTest, TestNarrow) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({8, 10, 4, 4}, options);
  for (int64_t dim : {1, -3}) {
    for (int64_t start : {2, -8}) {
      torch::Tensor expected = eager_in.narrow(dim, start, 6);
      ForEachDevice([&](const torch::Device& device) {
        torch::Tensor lazy_in = CopyToDevice(eager_in, device);
        torch::Tensor actual = lazy_in.narrow(dim, start, 6);
        AllClose(expected, actual);
      });
    }
  }
}
7099 
// Checks an in-place add on a narrowed slice: the lazy slice must match the
// eager slice after the update.
TEST_F(LazyOpsTest, TestNarrowUpdate) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  for (int64_t dim : {1, -2}) {
    for (int64_t start : {2, -6}) {
      torch::Tensor base = torch::rand({3, 8, 3}, options);
      // Untouched copy of the base, used as the source of device copies.
      torch::Tensor base_snapshot = base.clone();
      torch::Tensor addend = torch::rand({3, 4, 3}, options);
      torch::Tensor slice = base.narrow(dim, start, 4);
      slice.add_(addend, 1.0);
      ForEachDevice([&](const torch::Device& device) {
        torch::Tensor lazy_base = CopyToDevice(base_snapshot, device);
        torch::Tensor lazy_addend = CopyToDevice(addend, device);
        torch::Tensor lazy_slice = lazy_base.narrow(dim, start, 4);
        lazy_slice.add_(lazy_addend, 1.0);
        AllClose(slice, lazy_slice);
      });
    }
  }
}
7122 
// Checks that an in-place add on a narrowed slice is reflected in the base
// tensor: the lazy base must match the eager base after the update.
TEST_F(LazyOpsTest, TestNarrowUpdateBaseCheck) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  for (int64_t dim : {0, -2}) {
    for (int64_t start : {2, -6}) {
      torch::Tensor base = torch::zeros({8, 3}, options);
      // Untouched copy of the base, used as the source of device copies.
      torch::Tensor base_snapshot = base.clone();
      torch::Tensor addend = torch::ones({4, 3}, options);
      torch::Tensor slice = base.narrow(dim, start, 4);
      slice.add_(addend, 1.0);
      ForEachDevice([&](const torch::Device& device) {
        torch::Tensor lazy_base = CopyToDevice(base_snapshot, device);
        torch::Tensor lazy_addend = CopyToDevice(addend, device);
        torch::Tensor lazy_slice = lazy_base.narrow(dim, start, 4);
        lazy_slice.add_(lazy_addend, 1.0);
        // Only the base is compared here, not the slice.
        AllClose(base, lazy_base);
      });
    }
  }
}
7143 
// Checks in-place adds on two narrowed slices of the same base: both lazy
// slices and the lazy base must match the eager ones.
TEST_F(LazyOpsTest, TestNarrowUpdateTwoSlices) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  for (int64_t dim : {0, -2}) {
    for (int64_t start0 : {2, -6}) {
      for (int64_t start1 : {6, -2}) {
        torch::Tensor base = torch::zeros({8, 3}, options);
        // Untouched copy of the base, used as the source of device copies.
        torch::Tensor base_snapshot = base.clone();
        torch::Tensor ones = torch::ones({2, 3}, options);
        torch::Tensor twos = ones + 1;
        torch::Tensor first_slice = base.narrow(dim, start0, 2);
        torch::Tensor second_slice = base.narrow(dim, start1, 2);
        first_slice.add_(ones, 1.0);
        second_slice.add_(twos, 1.0);
        ForEachDevice([&](const torch::Device& device) {
          torch::Tensor lazy_base = CopyToDevice(base_snapshot, device);
          torch::Tensor lazy_ones = CopyToDevice(ones, device);
          torch::Tensor lazy_twos = CopyToDevice(twos, device);
          torch::Tensor lazy_first = lazy_base.narrow(dim, start0, 2);
          torch::Tensor lazy_second = lazy_base.narrow(dim, start1, 2);
          lazy_first.add_(lazy_ones, 1.0);
          lazy_second.add_(lazy_twos, 1.0);
          AllClose(first_slice, lazy_first);
          AllClose(second_slice, lazy_second);
          AllClose(base, lazy_base);
        });
      }
    }
  }
}
7176 
// Checks an in-place add through a view of a narrowed slice: the lazy view
// must match the eager view after the update.
TEST_F(LazyOpsTest, TestNarrowUpdateView) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  for (int64_t dim : {0, -3}) {
    for (int64_t start : {2, -6}) {
      torch::Tensor base = torch::rand({8, 2, 3}, options);
      // Untouched copy of the base, used as the source of device copies.
      torch::Tensor base_snapshot = base.clone();
      torch::Tensor addend = torch::rand({4, 6}, options);
      torch::Tensor slice = base.narrow(dim, start, 4);
      torch::Tensor flat = slice.view({4, 6});
      flat.add_(addend, 1.0);
      ForEachDevice([&](const torch::Device& device) {
        torch::Tensor lazy_base = CopyToDevice(base_snapshot, device);
        torch::Tensor lazy_addend = CopyToDevice(addend, device);
        torch::Tensor lazy_slice = lazy_base.narrow(dim, start, 4);
        torch::Tensor lazy_flat = lazy_slice.view({4, 6});
        lazy_flat.add_(lazy_addend, 1.0);
        AllClose(flat, lazy_flat);
      });
    }
  }
}
7200 
// Checks an in-place add through a narrow of a narrow: the lazy base must
// match the eager base after the update.
TEST_F(LazyOpsTest, TestNarrowInNarrowUpdate) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  for (int64_t dim : {1, -2}) {
    for (int64_t start0 : {1, -7}) {
      for (int64_t start1 : {1, -5}) {
        torch::Tensor base = torch::rand({3, 8, 3}, options);
        // Untouched copy of the base, used as the source of device copies.
        torch::Tensor base_snapshot = base.clone();
        torch::Tensor addend = torch::rand({3, 2, 3}, options);
        torch::Tensor outer = base.narrow(dim, start0, 6);
        torch::Tensor inner = outer.narrow(dim, start1, 2);
        inner.add_(addend, 1.0);
        ForEachDevice([&](const torch::Device& device) {
          torch::Tensor lazy_base = CopyToDevice(base_snapshot, device);
          torch::Tensor lazy_addend = CopyToDevice(addend, device);
          torch::Tensor lazy_outer = lazy_base.narrow(dim, start0, 6);
          torch::Tensor lazy_inner = lazy_outer.narrow(dim, start1, 2);
          lazy_inner.add_(lazy_addend, 1.0);
          // Only the base is compared here, not the slices.
          AllClose(base, lazy_base);
        });
      }
    }
  }
}
7227 
// Checks Tensor::narrow_copy on lazy tensors. The source is modified in place
// after each copy is taken, and the copies are then compared.
TEST_F(LazyOpsTest, TestNarrowCopy) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  for (int64_t dim : {1, -3}) {
    for (int64_t start : {2, -8}) {
      ForEachDevice([&](const torch::Device& device) {
        // A fresh random source is drawn for every device and parameter pair.
        torch::Tensor source = torch::rand({8, 10, 4, 4}, options);
        torch::Tensor lazy_source = CopyToDevice(source, device);
        torch::Tensor expected = source.narrow_copy(dim, start, 6);
        source.add_(1);
        torch::Tensor actual = lazy_source.narrow_copy(dim, start, 6);
        lazy_source.add_(1);
        AllClose(expected, actual);
      });
    }
  }
}
7245 
// Checks Tensor::view_as on lazy tensors, using a {32, 320} tensor as the
// shape template.
TEST_F(LazyOpsTest, TestViewAs) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({32, 20, 4, 4}, options);
  // Created with torch::empty's default options.
  torch::Tensor shape_like = torch::empty({32, 320});
  torch::Tensor expected = eager_in.view_as(shape_like);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    torch::Tensor lazy_shape_like = CopyToDevice(shape_like, device);
    torch::Tensor actual = lazy_in.view_as(lazy_shape_like);
    AllClose(expected, actual);
  });
}
7259 
// Checks torch::log_softmax on lazy tensors along every dimension, including
// the negative aliases, with a relaxed relative tolerance.
TEST_F(LazyOpsTest, TestLogSoftmax) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({5, 3, 4, 2}, options);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    const int rank = static_cast<int>(eager_in.dim());
    // Sweep dim over [-rank, rank).
    for (int dim = -rank; dim < rank; ++dim) {
      torch::Tensor expected = torch::log_softmax(eager_in, dim);
      torch::Tensor actual = torch::log_softmax(lazy_in, dim);
      AllClose(expected, actual, /*rtol=*/1e-3);
    }
  });
}
7274 
// Checks torch::log_softmax with an explicit kDouble dtype on lazy tensors
// along every dimension, including the negative aliases.
TEST_F(LazyOpsTest, TestLogSoftmaxCast) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({5, 3, 4, 2}, options);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    const int rank = static_cast<int>(eager_in.dim());
    // Sweep dim over [-rank, rank).
    for (int dim = -rank; dim < rank; ++dim) {
      torch::Tensor expected =
          torch::log_softmax(eager_in, dim, torch::kDouble);
      torch::Tensor actual = torch::log_softmax(lazy_in, dim, torch::kDouble);
      AllClose(expected, actual, /*rtol=*/1e-3);
    }
  });
}
7290 
// Checks torch::_log_softmax (half_to_float = false) on lazy tensors along
// every dimension, including the negative aliases.
TEST_F(LazyOpsTest, TestLogSoftmaxWrapper) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({10, 2, 6, 4}, options);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    const int rank = static_cast<int>(eager_in.dim());
    // Sweep dim over [-rank, rank).
    for (int dim = -rank; dim < rank; ++dim) {
      torch::Tensor expected =
          torch::_log_softmax(eager_in, dim, /*half_to_float=*/false);
      torch::Tensor actual =
          torch::_log_softmax(lazy_in, dim, /*half_to_float=*/false);
      AllClose(expected, actual, /*rtol=*/1e-3);
    }
  });
}
7307 
// Checks torch::softmax on lazy tensors along every dimension, including the
// negative aliases, with a relaxed relative tolerance.
TEST_F(LazyOpsTest, TestSoftmax) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({10, 2, 6, 4}, options);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    const int rank = static_cast<int>(eager_in.dim());
    // Sweep dim over [-rank, rank).
    for (int dim = -rank; dim < rank; ++dim) {
      torch::Tensor expected = torch::softmax(eager_in, dim);
      torch::Tensor actual = torch::softmax(lazy_in, dim);
      AllClose(expected, actual, /*rtol=*/1e-3);
    }
  });
}
7322 
// Checks torch::softmax with an explicit kDouble dtype on lazy tensors along
// every dimension, including the negative aliases.
TEST_F(LazyOpsTest, TestSoftmaxCast) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({10, 2, 6, 4}, options);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    const int rank = static_cast<int>(eager_in.dim());
    // Sweep dim over [-rank, rank).
    for (int dim = -rank; dim < rank; ++dim) {
      torch::Tensor expected = torch::softmax(eager_in, dim, torch::kDouble);
      torch::Tensor actual = torch::softmax(lazy_in, dim, torch::kDouble);
      AllClose(expected, actual, /*rtol=*/1e-3);
    }
  });
}
7338 
// Checks torch::_softmax (half_to_float = false) on lazy tensors along every
// dimension, including the negative aliases.
TEST_F(LazyOpsTest, TestSoftmaxWrapper) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({10, 2, 6, 4}, options);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    const int rank = static_cast<int>(eager_in.dim());
    // Sweep dim over [-rank, rank).
    for (int dim = -rank; dim < rank; ++dim) {
      torch::Tensor expected =
          torch::_softmax(eager_in, dim, /*half_to_float=*/false);
      torch::Tensor actual =
          torch::_softmax(lazy_in, dim, /*half_to_float=*/false);
      AllClose(expected, actual, /*rtol=*/1e-3);
    }
  });
}
7355 
// Checks torch::softplus (default arguments) on lazy tensors against the
// eager result, with a relative tolerance of 1e-4.
TEST_F(LazyOpsTest, TestSoftplus) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({2, 1, 4, 6}, options);
  torch::Tensor expected = torch::softplus(eager_in);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_in = CopyToDevice(eager_in, device);
    torch::Tensor actual = torch::softplus(lazy_in);
    AllClose(expected, actual, /*rtol=*/1e-4);
  });
}
7367 
// Checks torch::max_pool1d on lazy tensors over a grid of stride, padding,
// ceil_mode and dilation settings.
TEST_F(LazyOpsTest, TestMaxPool1D) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({1, 16, 56}, options);
  const int kernel_size = 3;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      // Test ceil_mode=true through the CPU interop.
      for (bool ceil_mode : {false, true}) {
        // Test dilation through the CPU interop.
        for (int dilation = 1; dilation <= 2; ++dilation) {
          // Pooling arguments shared by the eager and the lazy call.
          const std::vector<int64_t> kernel_arg = {kernel_size};
          const std::vector<int64_t> stride_arg = {stride};
          const std::vector<int64_t> padding_arg = {padding};
          const std::vector<int64_t> dilation_arg = {dilation};
          torch::Tensor expected = torch::max_pool1d(
              eager_in,
              kernel_arg,
              stride_arg,
              padding_arg,
              dilation_arg,
              ceil_mode);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor lazy_in = CopyToDevice(eager_in, device);
            torch::Tensor actual = torch::max_pool1d(
                lazy_in,
                kernel_arg,
                stride_arg,
                padding_arg,
                dilation_arg,
                ceil_mode);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
7401 
// Checks torch::max_pool2d with square kernel/stride/padding/dilation on lazy
// tensors over a grid of stride, padding, ceil_mode and dilation settings.
TEST_F(LazyOpsTest, TestMaxPool2D) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor eager_in = torch::rand({1, 4, 14, 14}, options);
  const int kernel_size = 3;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      // Test ceil_mode=true through the CPU interop.
      for (bool ceil_mode : {false, true}) {
        // Test dilation through the CPU interop.
        for (int dilation = 1; dilation <= 2; ++dilation) {
          // Pooling arguments shared by the eager and the lazy call.
          const std::vector<int64_t> kernel_arg = {kernel_size, kernel_size};
          const std::vector<int64_t> stride_arg = {stride, stride};
          const std::vector<int64_t> padding_arg = {padding, padding};
          const std::vector<int64_t> dilation_arg = {dilation, dilation};
          torch::Tensor expected = torch::max_pool2d(
              eager_in,
              kernel_arg,
              stride_arg,
              padding_arg,
              dilation_arg,
              ceil_mode);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor lazy_in = CopyToDevice(eager_in, device);
            torch::Tensor actual = torch::max_pool2d(
                lazy_in,
                kernel_arg,
                stride_arg,
                padding_arg,
                dilation_arg,
                ceil_mode);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
7436 
// Checks torch::max_pool2d_with_indices on lazy tensors over a grid of
// stride, padding, ceil_mode and dilation settings. Both elements of the
// returned tuple (pooled values and indices) are compared with the eager ones.
TEST_F(LazyOpsTest, TestMaxPool2DWithIndices) {
  torch::Tensor input = torch::rand(
      {1, 4, 14, 14},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  int kernel_size = 3;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      // Test ceil_mode=true through the CPU interop.
      for (bool ceil_mode : {false, true}) {
        // Test dilation through the CPU interop.
        for (int dilation = 1; dilation <= 2; ++dilation) {
          auto outputs = torch::max_pool2d_with_indices(
              input,
              /*kernel_size=*/{kernel_size, kernel_size},
              /*stride=*/{stride, stride},
              /*padding=*/{padding, padding},
              /*dilation=*/{dilation, dilation},
              /*ceil_mode=*/ceil_mode);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor lazy_input = CopyToDevice(input, device);
            auto lazy_outputs = torch::max_pool2d_with_indices(
                lazy_input,
                /*kernel_size=*/{kernel_size, kernel_size},
                /*stride=*/{stride, stride},
                /*padding=*/{padding, padding},
                /*dilation=*/{dilation, dilation},
                /*ceil_mode=*/ceil_mode);
            // Pooled values are floating point: compare with tolerance.
            AllClose(std::get<0>(outputs), std::get<0>(lazy_outputs));
            // Indices are integers, so they are compared exactly rather than
            // with a floating-point tolerance.
            AllEqual(std::get<1>(outputs), std::get<1>(lazy_outputs));
          });
        }
      }
    }
  }
}
7472 
// Compares max_pool2d between the default device and every device visited by
// ForEachDevice when kernel, stride and padding differ between the two
// spatial dimensions.
TEST_F(LazyOpsTest, TestMaxPool2DNonSquare) {
  torch::Tensor input = torch::rand(
      {1, 4, 14, 14},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  int kernel_size = 4;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      // ceil_mode=true is covered through the CPU interop.
      for (bool ceil_mode : {false, true}) {
        // dilation is covered through the CPU interop.
        for (int dilation = 1; dilation <= 2; ++dilation) {
          // The second dimension gets kernel/stride/padding one larger than
          // the first.
          auto pool = [&](const torch::Tensor& src) {
            return torch::max_pool2d(
                src,
                /*kernel_size=*/{kernel_size, kernel_size + 1},
                /*stride=*/{stride, stride + 1},
                /*padding=*/{padding, padding + 1},
                /*dilation=*/{dilation, dilation},
                /*ceil_mode=*/ceil_mode);
          };
          torch::Tensor expected = pool(input);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_input = CopyToDevice(input, device);
            torch::Tensor actual = pool(device_input);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
7507 
// Compares max_pool3d between the default device and every device visited by
// ForEachDevice, sweeping stride, padding, ceil_mode and dilation for a cubic
// 3x3x3 kernel.
TEST_F(LazyOpsTest, TestMaxPool3D) {
  torch::Tensor input = torch::rand(
      {1, 1, 8, 8, 8},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  int kernel_size = 3;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      // ceil_mode=true is covered through the CPU interop.
      for (bool ceil_mode : {false, true}) {
        // dilation is covered through the CPU interop.
        for (int dilation = 1; dilation <= 2; ++dilation) {
          // Shared op call for the reference and the per-device runs.
          auto pool = [&](const torch::Tensor& src) {
            return torch::max_pool3d(
                src,
                /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
                /*stride=*/{stride, stride, stride},
                /*padding=*/{padding, padding, padding},
                /*dilation=*/{dilation, dilation, dilation},
                /*ceil_mode=*/ceil_mode);
          };
          torch::Tensor expected = pool(input);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_input = CopyToDevice(input, device);
            torch::Tensor actual = pool(device_input);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
7542 
// Compares max_pool3d_with_indices (pooled values and indices) between the
// default device and every device visited by ForEachDevice, sweeping stride,
// padding, ceil_mode and dilation for a cubic 3x3x3 kernel.
TEST_F(LazyOpsTest, TestMaxPool3DWithIndices) {
  torch::Tensor input = torch::rand(
      {1, 1, 8, 8, 8},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  int kernel_size = 3;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      // ceil_mode=true is covered through the CPU interop.
      for (bool ceil_mode : {false, true}) {
        // dilation is covered through the CPU interop.
        for (int dilation = 1; dilation <= 2; ++dilation) {
          // Shared op call for the reference and the per-device runs.
          auto pool = [&](const torch::Tensor& src) {
            return torch::max_pool3d_with_indices(
                src,
                /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
                /*stride=*/{stride, stride, stride},
                /*padding=*/{padding, padding, padding},
                /*dilation=*/{dilation, dilation, dilation},
                /*ceil_mode=*/ceil_mode);
          };
          auto expected = pool(input);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_input = CopyToDevice(input, device);
            auto actual = pool(device_input);

            AllClose(std::get<0>(expected), std::get<0>(actual));
            AllClose(std::get<1>(expected), std::get<1>(actual));
          });
        }
      }
    }
  }
}
7579 
// Compares max_pool3d between the default device and every device visited by
// ForEachDevice when stride is passed as an empty list and padding as a
// single-element list.
TEST_F(LazyOpsTest, TestMaxPool3DIncompleteAttributes) {
  torch::Tensor input = torch::rand(
      {1, 1, 8, 8, 8},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  int kernel_size = 3;
  // NOTE: `stride` only controls how many times the sweep repeats; the op
  // itself is always called with an empty stride list.
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      // ceil_mode=true is covered through the CPU interop.
      for (bool ceil_mode : {false, true}) {
        // dilation is covered through the CPU interop.
        for (int dilation = 1; dilation <= 2; ++dilation) {
          // Shared op call for the reference and the per-device runs.
          auto pool = [&](const torch::Tensor& src) {
            return torch::max_pool3d(
                src,
                /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
                /*stride=*/{},
                /*padding=*/{padding},
                /*dilation=*/{dilation, dilation, dilation},
                /*ceil_mode=*/ceil_mode);
          };
          torch::Tensor expected = pool(input);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_input = CopyToDevice(input, device);
            torch::Tensor actual = pool(device_input);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
7614 
// Compares max_pool3d between the default device and every device visited by
// ForEachDevice when the middle spatial dimension uses a larger kernel,
// stride and padding than the other two.
TEST_F(LazyOpsTest, TestMaxPool3DNonSquare) {
  torch::Tensor input = torch::rand(
      {1, 1, 8, 8, 8},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  int kernel_size = 4;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      // ceil_mode=true is covered through the CPU interop.
      for (bool ceil_mode : {false, true}) {
        // dilation is covered through the CPU interop.
        for (int dilation = 1; dilation <= 2; ++dilation) {
          // Shared op call for the reference and the per-device runs.
          auto pool = [&](const torch::Tensor& src) {
            return torch::max_pool3d(
                src,
                /*kernel_size=*/{kernel_size, kernel_size + 1, kernel_size},
                /*stride=*/{stride, stride + 1, stride},
                /*padding=*/{padding, padding + 1, padding},
                /*dilation=*/{dilation, dilation, dilation},
                /*ceil_mode=*/ceil_mode);
          };
          torch::Tensor expected = pool(input);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_input = CopyToDevice(input, device);
            torch::Tensor actual = pool(device_input);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
7649 
// Compares max_pool2d between the default device and every device visited by
// ForEachDevice for a 3-D input (no leading batch dimension).
TEST_F(LazyOpsTest, TestMaxPool2DNoBatch) {
  torch::Tensor input = torch::rand(
      {4, 14, 14}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  int kernel_size = 3;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      // ceil_mode=true is covered through the CPU interop.
      for (bool ceil_mode : {false, true}) {
        // dilation is covered through the CPU interop.
        for (int dilation = 1; dilation <= 2; ++dilation) {
          // Shared op call for the reference and the per-device runs.
          auto pool = [&](const torch::Tensor& src) {
            return torch::max_pool2d(
                src,
                /*kernel_size=*/{kernel_size, kernel_size},
                /*stride=*/{stride, stride},
                /*padding=*/{padding, padding},
                /*dilation=*/{dilation, dilation},
                /*ceil_mode=*/ceil_mode);
          };
          torch::Tensor expected = pool(input);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_input = CopyToDevice(input, device);
            torch::Tensor actual = pool(device_input);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
7683 
// Compares max_pool3d between the default device and every device visited by
// ForEachDevice for a 4-D input (no leading batch dimension).
TEST_F(LazyOpsTest, TestMaxPool3DNoBatch) {
  torch::Tensor input = torch::rand(
      {1, 8, 8, 8},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  int kernel_size = 3;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      // ceil_mode=true is covered through the CPU interop.
      for (bool ceil_mode : {false, true}) {
        // dilation is covered through the CPU interop.
        for (int dilation = 1; dilation <= 2; ++dilation) {
          // Shared op call for the reference and the per-device runs.
          auto pool = [&](const torch::Tensor& src) {
            return torch::max_pool3d(
                src,
                /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
                /*stride=*/{stride, stride, stride},
                /*padding=*/{padding, padding, padding},
                /*dilation=*/{dilation, dilation, dilation},
                /*ceil_mode=*/ceil_mode);
          };
          torch::Tensor expected = pool(input);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_input = CopyToDevice(input, device);
            torch::Tensor actual = pool(device_input);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
7718 
// Compares avg_pool1d between the default device and every device visited by
// ForEachDevice, sweeping stride, padding, count_include_pad and ceil_mode
// for a kernel of size 2.
TEST_F(LazyOpsTest, TestAvgPool1D) {
  torch::Tensor input = torch::rand(
      {4, 1, 28}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  int kernel_size = 2;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      for (bool count_include_pad : {true, false}) {
        // ceil_mode=true is covered through the CPU interop.
        for (bool ceil_mode : {false, true}) {
          // Shared op call for the reference and the per-device runs.
          auto pool = [&](const torch::Tensor& src) {
            return torch::avg_pool1d(
                src,
                /*kernel_size=*/{kernel_size},
                /*stride=*/{stride},
                /*padding=*/{padding},
                /*ceil_mode=*/ceil_mode,
                /*count_include_pad=*/count_include_pad);
          };
          torch::Tensor expected = pool(input);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_input = CopyToDevice(input, device);
            torch::Tensor actual = pool(device_input);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
7751 
// Compares avg_pool2d between the default device and every device visited by
// ForEachDevice, sweeping stride, padding, count_include_pad and ceil_mode
// for a 2x2 kernel.
TEST_F(LazyOpsTest, TestAvgPool2D) {
  torch::Tensor input = torch::rand(
      {2, 1, 14, 14},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  int kernel_size = 2;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      for (bool count_include_pad : {true, false}) {
        // ceil_mode=true is covered through the CPU interop.
        for (bool ceil_mode : {false, true}) {
          // Shared op call for the reference and the per-device runs.
          auto pool = [&](const torch::Tensor& src) {
            return torch::avg_pool2d(
                src,
                /*kernel_size=*/{kernel_size, kernel_size},
                /*stride=*/{stride, stride},
                /*padding=*/{padding, padding},
                /*ceil_mode=*/ceil_mode,
                /*count_include_pad=*/count_include_pad);
          };
          torch::Tensor expected = pool(input);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_input = CopyToDevice(input, device);
            torch::Tensor actual = pool(device_input);
            // NOTE(review): unlike the neighbouring pooling tests, the result
            // is explicitly moved to CPU before the comparison; kept as-is.
            AllClose(expected, actual.to(torch::kCPU));
          });
        }
      }
    }
  }
}
7786 
// Compares avg_pool2d between the default device and every device visited by
// ForEachDevice when kernel, stride and padding differ between the two
// spatial dimensions.
TEST_F(LazyOpsTest, TestAvgPool2DNonSquare) {
  torch::Tensor input = torch::rand(
      {2, 1, 14, 14},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  int kernel_size = 4;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      for (bool count_include_pad : {true, false}) {
        // ceil_mode=true is covered through the CPU interop.
        for (bool ceil_mode : {false, true}) {
          // The second dimension gets kernel/stride/padding one larger than
          // the first.
          auto pool = [&](const torch::Tensor& src) {
            return torch::avg_pool2d(
                src,
                /*kernel_size=*/{kernel_size, kernel_size + 1},
                /*stride=*/{stride, stride + 1},
                /*padding=*/{padding, padding + 1},
                /*ceil_mode=*/ceil_mode,
                /*count_include_pad=*/count_include_pad);
          };
          torch::Tensor expected = pool(input);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_input = CopyToDevice(input, device);
            torch::Tensor actual = pool(device_input);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
7820 
// Compares avg_pool3d between the default device and every device visited by
// ForEachDevice, sweeping stride, padding, count_include_pad and ceil_mode
// for a cubic 2x2x2 kernel.
TEST_F(LazyOpsTest, TestAvgPool3D) {
  torch::Tensor input = torch::rand(
      {1, 1, 7, 7, 7},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  int kernel_size = 2;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      for (bool count_include_pad : {true, false}) {
        // ceil_mode=true is covered through the CPU interop.
        for (bool ceil_mode : {false, true}) {
          // Shared op call for the reference and the per-device runs.
          auto pool = [&](const torch::Tensor& src) {
            return torch::avg_pool3d(
                src,
                /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
                /*stride=*/{stride, stride, stride},
                /*padding=*/{padding, padding, padding},
                /*ceil_mode=*/ceil_mode,
                /*count_include_pad=*/count_include_pad);
          };
          torch::Tensor expected = pool(input);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_input = CopyToDevice(input, device);
            torch::Tensor actual = pool(device_input);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
7854 
// Compares avg_pool3d between the default device and every device visited by
// ForEachDevice when stride is passed as an empty list.
TEST_F(LazyOpsTest, TestAvgPool3DIncompleteAttributes) {
  torch::Tensor input = torch::rand(
      {1, 1, 7, 7, 7},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  int kernel_size = 2;
  // NOTE: `stride` only controls how many times the sweep repeats; the op
  // itself is always called with an empty stride list.
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      for (bool count_include_pad : {true, false}) {
        // ceil_mode=true is covered through the CPU interop.
        for (bool ceil_mode : {false, true}) {
          // Shared op call for the reference and the per-device runs.
          auto pool = [&](const torch::Tensor& src) {
            return torch::avg_pool3d(
                src,
                /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
                /*stride=*/{},
                /*padding=*/{padding, padding, padding},
                /*ceil_mode=*/ceil_mode,
                /*count_include_pad=*/count_include_pad);
          };
          torch::Tensor expected = pool(input);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_input = CopyToDevice(input, device);
            torch::Tensor actual = pool(device_input);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
7888 
// Compares avg_pool3d between the default device and every device visited by
// ForEachDevice when the middle spatial dimension uses a larger kernel,
// stride and padding than the other two.
TEST_F(LazyOpsTest, TestAvgPool3DNonSquare) {
  torch::Tensor input = torch::rand(
      {1, 1, 7, 7, 7},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  int kernel_size = 4;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      for (bool count_include_pad : {true, false}) {
        // ceil_mode=true is covered through the CPU interop.
        for (bool ceil_mode : {false, true}) {
          // Shared op call for the reference and the per-device runs.
          auto pool = [&](const torch::Tensor& src) {
            return torch::avg_pool3d(
                src,
                /*kernel_size=*/{kernel_size, kernel_size + 1, kernel_size},
                /*stride=*/{stride, stride + 1, stride},
                /*padding=*/{padding, padding + 1, padding},
                /*ceil_mode=*/ceil_mode,
                /*count_include_pad=*/count_include_pad);
          };
          torch::Tensor expected = pool(input);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_input = CopyToDevice(input, device);
            torch::Tensor actual = pool(device_input);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
7922 
// Compares avg_pool2d between the default device and every device visited by
// ForEachDevice for a 3-D input (no leading batch dimension).
TEST_F(LazyOpsTest, TestAvgPool2DNoBatch) {
  torch::Tensor input = torch::rand(
      {1, 7, 7}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  int kernel_size = 2;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      for (bool count_include_pad : {true, false}) {
        // ceil_mode=true is covered through the CPU interop.
        for (bool ceil_mode : {false, true}) {
          // Shared op call for the reference and the per-device runs.
          auto pool = [&](const torch::Tensor& src) {
            return torch::avg_pool2d(
                src,
                /*kernel_size=*/{kernel_size, kernel_size},
                /*stride=*/{stride, stride},
                /*padding=*/{padding, padding},
                /*ceil_mode=*/ceil_mode,
                /*count_include_pad=*/count_include_pad);
          };
          torch::Tensor expected = pool(input);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_input = CopyToDevice(input, device);
            torch::Tensor actual = pool(device_input);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
7955 
// Compares avg_pool3d between the default device and every device visited by
// ForEachDevice for a 4-D input (no leading batch dimension).
TEST_F(LazyOpsTest, TestAvgPool3DNoBatch) {
  torch::Tensor input = torch::rand(
      {1, 7, 7, 7},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  int kernel_size = 2;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      for (bool count_include_pad : {true, false}) {
        // ceil_mode=true is covered through the CPU interop.
        for (bool ceil_mode : {false, true}) {
          // Shared op call for the reference and the per-device runs.
          auto pool = [&](const torch::Tensor& src) {
            return torch::avg_pool3d(
                src,
                /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
                /*stride=*/{stride, stride, stride},
                /*padding=*/{padding, padding, padding},
                /*ceil_mode=*/ceil_mode,
                /*count_include_pad=*/count_include_pad);
          };
          torch::Tensor expected = pool(input);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_input = CopyToDevice(input, device);
            torch::Tensor actual = pool(device_input);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
7989 
// Compares adaptive_avg_pool2d between the default device and every device
// visited by ForEachDevice for square output sizes 7 and 4.
TEST_F(LazyOpsTest, TestAdaptiveAvgPool2D) {
  torch::Tensor input = torch::rand(
      {4, 1, 28, 28},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (int64_t output_size : {7, 4}) {
    // Shared op call for the reference and the per-device runs.
    auto pool = [&](const torch::Tensor& src) {
      return torch::adaptive_avg_pool2d(src, {output_size, output_size});
    };
    torch::Tensor expected = pool(input);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor device_input = CopyToDevice(input, device);
      torch::Tensor actual = pool(device_input);
      AllClose(expected, actual);
    });
  }
}
8005 
// Compares adaptive_avg_pool3d between the default device and every device
// visited by ForEachDevice for cubic output sizes 7 and 4.
TEST_F(LazyOpsTest, TestAdaptiveAvgPool3D) {
  torch::Tensor input = torch::rand(
      {9, 4, 56, 28, 28},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (int64_t output_size : {7, 4}) {
    // Shared op call for the reference and the per-device runs.
    auto pool = [&](const torch::Tensor& src) {
      return torch::adaptive_avg_pool3d(
          src, {output_size, output_size, output_size});
    };
    torch::Tensor expected = pool(input);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor device_input = CopyToDevice(input, device);
      torch::Tensor actual = pool(device_input);
      AllClose(expected, actual);
    });
  }
}
8021 
// Compares adaptive_avg_pool3d between the default device and every device
// visited by ForEachDevice for a 4-D input (no leading batch dimension).
TEST_F(LazyOpsTest, TestAdaptiveAvgPool3DNoBatch) {
  torch::Tensor input = torch::rand(
      {3, 56, 28, 28},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (int64_t output_size : {7, 4}) {
    // Shared op call for the reference and the per-device runs.
    auto pool = [&](const torch::Tensor& src) {
      return torch::adaptive_avg_pool3d(
          src, {output_size, output_size, output_size});
    };
    torch::Tensor expected = pool(input);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor device_input = CopyToDevice(input, device);
      torch::Tensor actual = pool(device_input);
      AllClose(expected, actual);
    });
  }
}
8037 
// Compares adaptive_avg_pool2d between the default device and every device
// visited by ForEachDevice for a 3-D input (no leading batch dimension) and
// square output sizes 7 and 8.
TEST_F(LazyOpsTest, TestAdaptiveAvgPool2DNoBatch) {
  torch::Tensor input = torch::rand(
      {1, 56, 56}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (int64_t output_size : {7, 8}) {
    // Shared op call for the reference and the per-device runs.
    auto pool = [&](const torch::Tensor& src) {
      return torch::adaptive_avg_pool2d(src, {output_size, output_size});
    };
    torch::Tensor expected = pool(input);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor device_input = CopyToDevice(input, device);
      torch::Tensor actual = pool(device_input);
      AllClose(expected, actual);
    });
  }
}
8052 
// Runs max_pool2d_with_indices on the default device, then compares
// max_unpool2d of the pooled values/indices between the default device and
// every device visited by ForEachDevice.
TEST_F(LazyOpsTest, TestMaxUnpool2D) {
  int kernel_size = 2;
  torch::Tensor input = torch::rand(
      {2, 2, 8, 8},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      // ceil_mode=true is covered through the CPU interop.
      for (bool ceil_mode : {false, true}) {
        // dilation is covered through the CPU interop.
        for (int dilation = 1; dilation <= 2; ++dilation) {
          auto pooled = torch::max_pool2d_with_indices(
              input,
              /*kernel_size=*/{kernel_size, kernel_size},
              /*stride=*/{stride, stride},
              /*padding=*/{padding, padding},
              /*dilation=*/{dilation, dilation},
              /*ceil_mode=*/ceil_mode);
          torch::Tensor output = std::get<0>(pooled);
          torch::Tensor indices = std::get<1>(pooled);

          // Unpool back to the spatial extent of the original input.
          std::vector<int64_t> output_size{input.size(2), input.size(3)};
          auto unpool = [&](const torch::Tensor& values,
                            const torch::Tensor& positions) {
            return torch::max_unpool2d(values, positions, output_size);
          };
          at::Tensor expected = unpool(output, indices);

          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_output = CopyToDevice(output, device);
            torch::Tensor device_indices = CopyToDevice(indices, device);
            at::Tensor actual = unpool(device_output, device_indices);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
8090 
// Runs max_pool3d_with_indices on the default device, then compares
// max_unpool3d of the pooled values/indices between the default device and
// every device visited by ForEachDevice.
TEST_F(LazyOpsTest, TestMaxUnpool3D) {
  int kernel_size = 2;
  torch::Tensor input = torch::rand(
      {1, 1, 4, 4, 4},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      // ceil_mode=true is covered through the CPU interop.
      for (bool ceil_mode : {false, true}) {
        // dilation is covered through the CPU interop.
        for (int dilation = 1; dilation <= 2; ++dilation) {
          auto pooled = torch::max_pool3d_with_indices(
              input,
              /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
              /*stride=*/{stride, stride, stride},
              /*padding=*/{padding, padding, padding},
              /*dilation=*/{dilation, dilation, dilation},
              /*ceil_mode=*/ceil_mode);
          torch::Tensor output = std::get<0>(pooled);
          torch::Tensor indices = std::get<1>(pooled);

          // Unpool back to the spatial extent of the original input.
          std::vector<int64_t> output_size{
              input.size(2), input.size(3), input.size(4)};
          auto unpool = [&](const torch::Tensor& values,
                            const torch::Tensor& positions) {
            return torch::max_unpool3d(
                values,
                positions,
                output_size,
                /*stride=*/{stride, stride, stride},
                /*padding=*/{padding, padding, padding});
          };
          at::Tensor expected = unpool(output, indices);

          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_output = CopyToDevice(output, device);
            torch::Tensor device_indices = CopyToDevice(indices, device);
            at::Tensor actual = unpool(device_output, device_indices);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
8137 
// Compares nll_loss between the default device and every device visited by
// ForEachDevice across ignore_index values, with and without a class weight
// tensor, and for Mean/Sum/None reductions. Currently skipped (see TODO).
TEST_F(LazyOpsTest, TestNllLoss) {
  // TODO(whc) debug divide-by-zero failure under ASAN
  GTEST_SKIP();

  int batch = 6;
  int classes = 2;
  // TODO(asuhan): Fix the torch::kDouble case.
  for (auto dtype : {torch::kFloat}) {
    for (int ignore_index : {-1, 0, 1, 5}) {
      for (bool def_weight : {false, true}) {
        torch::Tensor input = torch::rand(
            {batch, classes},
            torch::TensorOptions(dtype).device(DefaultDevice()));
        // The lower bound drops to ignore_index when it is negative, so the
        // sampled targets can include the ignored value.
        torch::Tensor target = torch::randint(
            std::min(ignore_index, 0),
            classes,
            {batch},
            torch::TensorOptions(torch::kLong).device(DefaultDevice()));
        // Stays default-constructed unless def_weight is set.
        torch::Tensor weight;
        if (def_weight) {
          weight = torch::rand(
              {classes}, torch::TensorOptions(dtype).device(DefaultDevice()));
        }
        for (torch::Reduction::Reduction reduction :
             {torch::Reduction::Mean,
              torch::Reduction::Sum,
              torch::Reduction::None}) {
          // Shared op call for the reference and the per-device runs.
          auto loss = [&](const torch::Tensor& in,
                          const torch::Tensor& tgt,
                          const torch::Tensor& cls_weight) {
            return torch::nll_loss(
                /*self=*/in,
                /*target=*/tgt,
                /*weight=*/cls_weight,
                /*reduction=*/reduction,
                /*ignore_index=*/ignore_index);
          };
          torch::Tensor expected = loss(input, target, weight);

          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_input = CopyToDevice(input, device);
            torch::Tensor device_target = CopyToDevice(target, device);
            // Only copy the weight when one was actually created.
            torch::Tensor device_weight;
            if (def_weight) {
              device_weight = CopyToDevice(weight, device);
            }
            torch::Tensor actual =
                loss(device_input, device_target, device_weight);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
8190 
// Compares torch::nll_loss2d on device copies of the inputs against the result
// computed on DefaultDevice(), over every ignore_index / optional weight /
// reduction combination.
TEST_F(LazyOpsTest, TestNllLoss2d) {
  const int batch = 6;
  const int classes = 2;
  const int height = 3;
  const int width = 3;
  // TODO(asuhan): Fix the torch::kDouble case.
  for (auto dtype : {torch::kFloat}) {
    for (int ignore_index : {-1, 0, 1, 5}) {
      for (bool def_weight : {false, true}) {
        const auto value_options =
            torch::TensorOptions(dtype).device(DefaultDevice());
        const auto index_options =
            torch::TensorOptions(torch::kLong).device(DefaultDevice());
        torch::Tensor input =
            torch::rand({batch, classes, height, width}, value_options);
        // The lower sampling bound drops below zero for a negative
        // ignore_index.
        torch::Tensor target = torch::randint(
            std::min(ignore_index, 0),
            classes,
            {batch, height, width},
            index_options);
        // Stays undefined unless def_weight is set.
        torch::Tensor weight;
        if (def_weight) {
          weight = torch::rand({classes}, value_options);
        }
        for (torch::Reduction::Reduction reduction :
             {torch::Reduction::Mean,
              torch::Reduction::Sum,
              torch::Reduction::None}) {
          torch::Tensor expected = torch::nll_loss2d(
              /*self=*/input,
              /*target=*/target,
              /*weight=*/weight,
              /*reduction=*/reduction,
              /*ignore_index=*/ignore_index);

          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor device_input = CopyToDevice(input, device);
            torch::Tensor device_target = CopyToDevice(target, device);
            torch::Tensor device_weight;
            if (def_weight) {
              device_weight = CopyToDevice(weight, device);
            }
            torch::Tensor actual = torch::nll_loss2d(
                /*self=*/device_input,
                /*target=*/device_target,
                /*weight=*/device_weight,
                /*reduction=*/reduction,
                /*ignore_index=*/ignore_index);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
8242 
// Compares torch::smooth_l1_loss on device copies against the DefaultDevice()
// result for each reduction mode and two beta values.
TEST_F(LazyOpsTest, TestSmoothL1Loss) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::randn({2, 4}, options);
  torch::Tensor target = torch::randn({2, 4}, options);
  for (torch::Reduction::Reduction reduction :
       {torch::Reduction::None,
        torch::Reduction::Mean,
        torch::Reduction::Sum}) {
    for (double beta : {0.25, 1.}) {
      torch::Tensor expected =
          torch::smooth_l1_loss(input, target, reduction, beta);
      ForEachDevice([&](const torch::Device& device) {
        torch::Tensor device_input = CopyToDevice(input, device);
        torch::Tensor device_target = CopyToDevice(target, device);
        torch::Tensor actual =
            torch::smooth_l1_loss(device_input, device_target, reduction, beta);
        AllClose(expected, actual);
      });
    }
  }
}
8265 
// Compares torch::l1_loss on device copies against the DefaultDevice() result
// for each reduction mode.
TEST_F(LazyOpsTest, TestL1Loss) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::randn({2, 4}, options);
  torch::Tensor target = torch::randn({2, 4}, options);
  for (torch::Reduction::Reduction reduction :
       {torch::Reduction::None,
        torch::Reduction::Mean,
        torch::Reduction::Sum}) {
    torch::Tensor expected = torch::l1_loss(input, target, reduction);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor device_input = CopyToDevice(input, device);
      torch::Tensor device_target = CopyToDevice(target, device);
      torch::Tensor actual =
          torch::l1_loss(device_input, device_target, reduction);
      AllClose(expected, actual);
    });
  }
}
8285 
// Runs TestBackward on torch::l1_loss for each reduction mode. Only the first
// input (the prediction) is created with requires_grad set.
TEST_F(LazyOpsTest, TestL1LossBackward) {
  for (torch::Reduction::Reduction reduction :
       {torch::Reduction::None,
        torch::Reduction::Mean,
        torch::Reduction::Sum}) {
    auto loss_fn =
        [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
      return torch::l1_loss(args[0], args[1], reduction);
    };
    ForEachDevice([&](const torch::Device& device) {
      const auto options =
          torch::TensorOptions(torch::kFloat).device(DefaultDevice());
      torch::Tensor input = torch::rand({2, 4}, options.requires_grad(true));
      torch::Tensor target = torch::rand({2, 4}, options);
      TestBackward({input, target}, device, loss_fn);
    });
  }
}
8310 
// Compares torch::mse_loss on device copies against the DefaultDevice() result
// for each reduction mode.
TEST_F(LazyOpsTest, TestMseLoss) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::randn({2, 4}, options);
  torch::Tensor target = torch::randn({2, 4}, options);
  for (torch::Reduction::Reduction reduction :
       {torch::Reduction::None,
        torch::Reduction::Mean,
        torch::Reduction::Sum}) {
    torch::Tensor expected = torch::mse_loss(input, target, reduction);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor device_input = CopyToDevice(input, device);
      torch::Tensor device_target = CopyToDevice(target, device);
      torch::Tensor actual =
          torch::mse_loss(device_input, device_target, reduction);
      AllClose(expected, actual);
    });
  }
}
8330 
// Runs TestBackward on torch::mse_loss for each reduction mode. Only the first
// input (the prediction) is created with requires_grad set.
TEST_F(LazyOpsTest, TestMseLossBackward) {
  for (torch::Reduction::Reduction reduction :
       {torch::Reduction::None,
        torch::Reduction::Mean,
        torch::Reduction::Sum}) {
    auto loss_fn =
        [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
      return torch::mse_loss(args[0], args[1], reduction);
    };
    ForEachDevice([&](const torch::Device& device) {
      const auto options =
          torch::TensorOptions(torch::kFloat).device(DefaultDevice());
      torch::Tensor input = torch::rand({2, 4}, options.requires_grad(true));
      torch::Tensor target = torch::rand({2, 4}, options);
      TestBackward({input, target}, device, loss_fn);
    });
  }
}
8355 
// Compares torch::batch_norm on a {2, num_features, 4} input between
// DefaultDevice() and device copies, in training and eval mode, with and
// without weight/bias tensors.
TEST_F(LazyOpsTest, TestBatchNorm1D) {
  const int num_features = 3;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, num_features, 4}, options);
  torch::Tensor weight = torch::rand({num_features}, options);
  torch::Tensor bias = torch::rand({num_features}, options);
  torch::Tensor running_mean = torch::zeros({num_features}, options);
  torch::Tensor running_var = torch::ones({num_features}, options);
  const double momentum = 0.1;
  const double eps = 0.5;
  // Default-constructed tensor passed where weight/bias are omitted.
  torch::Tensor undef;
  for (bool training : {true, false}) {
    for (bool undef_weight_bias : {false, true}) {
      torch::Tensor expected = torch::batch_norm(
          /*input=*/input,
          /*weight=*/undef_weight_bias ? undef : weight,
          /*bias=*/undef_weight_bias ? undef : bias,
          /*running_mean=*/running_mean,
          /*running_var=*/running_var,
          /*training=*/training,
          /*momentum=*/momentum,
          /*eps=*/eps,
          /*cudnn_enabled=*/false);
      ForEachDevice([&](const torch::Device& device) {
        torch::Tensor device_input = CopyToDevice(input, device);
        // Left default-constructed when weight/bias are omitted.
        torch::Tensor device_weight;
        torch::Tensor device_bias;
        if (!undef_weight_bias) {
          device_weight = CopyToDevice(weight, device);
          device_bias = CopyToDevice(bias, device);
        }
        torch::Tensor device_running_mean = CopyToDevice(running_mean, device);
        torch::Tensor device_running_var = CopyToDevice(running_var, device);
        torch::Tensor actual = torch::batch_norm(
            /*input=*/device_input,
            /*weight=*/device_weight,
            /*bias=*/device_bias,
            /*running_mean=*/device_running_mean,
            /*running_var=*/device_running_var,
            /*training=*/training,
            /*momentum=*/momentum,
            /*eps=*/eps,
            /*cudnn_enabled=*/false);
        AllClose(expected, actual, /*rtol=*/1e-3, /*atol=*/1e-5);
      });
    }
  }
}
8411 
// Compares torch::batch_norm on a {2, num_features, 4, 4} input between
// DefaultDevice() and device copies, in training and eval mode, with and
// without weight/bias tensors.
TEST_F(LazyOpsTest, TestBatchNorm2D) {
  const int num_features = 3;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, num_features, 4, 4}, options);
  torch::Tensor weight = torch::rand({num_features}, options);
  torch::Tensor bias = torch::rand({num_features}, options);
  torch::Tensor running_mean = torch::zeros({num_features}, options);
  torch::Tensor running_var = torch::ones({num_features}, options);
  const double momentum = 0.1;
  const double eps = 0.5;
  // Default-constructed tensor passed where weight/bias are omitted.
  torch::Tensor undef;
  for (bool training : {true, false}) {
    for (bool undef_weight_bias : {false, true}) {
      torch::Tensor expected = torch::batch_norm(
          /*input=*/input,
          /*weight=*/undef_weight_bias ? undef : weight,
          /*bias=*/undef_weight_bias ? undef : bias,
          /*running_mean=*/running_mean,
          /*running_var=*/running_var,
          /*training=*/training,
          /*momentum=*/momentum,
          /*eps=*/eps,
          /*cudnn_enabled=*/false);
      ForEachDevice([&](const torch::Device& device) {
        torch::Tensor device_input = CopyToDevice(input, device);
        // Left default-constructed when weight/bias are omitted.
        torch::Tensor device_weight;
        torch::Tensor device_bias;
        if (!undef_weight_bias) {
          device_weight = CopyToDevice(weight, device);
          device_bias = CopyToDevice(bias, device);
        }
        torch::Tensor device_running_mean = CopyToDevice(running_mean, device);
        torch::Tensor device_running_var = CopyToDevice(running_var, device);
        torch::Tensor actual = torch::batch_norm(
            /*input=*/device_input,
            /*weight=*/device_weight,
            /*bias=*/device_bias,
            /*running_mean=*/device_running_mean,
            /*running_var=*/device_running_var,
            /*training=*/training,
            /*momentum=*/momentum,
            /*eps=*/eps,
            /*cudnn_enabled=*/false);
        AllClose(expected, actual, /*rtol=*/1e-3, /*atol=*/1e-5);
      });
    }
  }
}
8467 
// Checks that a device copy reports the same dim() as the source tensor.
TEST_F(LazyOpsTest, TestDim) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 3}, options);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor device_input = CopyToDevice(input, device);
    EXPECT_EQ(input.dim(), device_input.dim());
  });
}
8476 
// Compares torch::native::contiguous on a device copy against the
// DefaultDevice() result.
TEST_F(LazyOpsTest, TestContiguous) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 3}, options);
  torch::Tensor expected = torch::native::contiguous(input);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor device_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::native::contiguous(device_input);
    AllClose(expected, actual);
  });
}
8487 
// Compares torch::squeeze (no dim argument) on a {2, 1, 3, 1} tensor between
// DefaultDevice() and a device copy.
TEST_F(LazyOpsTest, TestSqueezeAll) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 1, 3, 1}, options);
  torch::Tensor expected = torch::squeeze(input);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor device_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::squeeze(device_input);
    AllClose(expected, actual);
  });
}
8499 
// Applies squeeze_() to a fresh {2, 1, 3, 1} tensor and to its device copy,
// then checks the returned tensors, the mutated tensors, and every dimension
// size agree.
TEST_F(LazyOpsTest, TestSqueezeAllInPlace) {
  ForEachDevice([&](const torch::Device& device) {
    const auto options =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    torch::Tensor input = torch::rand({2, 1, 3, 1}, options);
    // The copy is taken before the source tensor is squeezed.
    torch::Tensor device_input = CopyToDevice(input, device);
    torch::Tensor expected = input.squeeze_();
    torch::Tensor actual = device_input.squeeze_();
    AllClose(expected, actual);
    AllClose(input, device_input);
    ASSERT_EQ(input.dim(), device_input.dim());
    for (int64_t axis = 0; axis < input.dim(); ++axis) {
      ASSERT_EQ(input.size(axis), device_input.size(axis));
    }
  });
}
8516 
// Compares torch::squeeze(input, dim) between DefaultDevice() and a device
// copy for every dim in [-rank, rank).
TEST_F(LazyOpsTest, TestSqueezeOne) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 1, 3, 1}, options);
  const int rank = input.dim();
  for (int dim = -rank; dim < rank; ++dim) {
    torch::Tensor expected = torch::squeeze(input, dim);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor device_input = CopyToDevice(input, device);
      torch::Tensor actual = torch::squeeze(device_input, dim);
      AllClose(expected, actual);
    });
  }
}
8531 
// Applies squeeze_(dim) to a fresh {2, 1, 3, 1} tensor and to its device copy
// for every dim in [-4, 4), then checks values and dimension sizes agree.
TEST_F(LazyOpsTest, TestSqueezeOneInPlace) {
  // Matches the number of dimensions of the tensor created below.
  const int rank = 4;
  for (int dim = -rank; dim < rank; ++dim) {
    ForEachDevice([&](const torch::Device& device) {
      const auto options =
          torch::TensorOptions(torch::kFloat).device(DefaultDevice());
      torch::Tensor input = torch::rand({2, 1, 3, 1}, options);
      // The copy is taken before the source tensor is squeezed.
      torch::Tensor device_input = CopyToDevice(input, device);
      torch::Tensor expected = input.squeeze_(dim);
      torch::Tensor actual = device_input.squeeze_(dim);
      AllClose(expected, actual);
      AllClose(input, device_input);
      ASSERT_EQ(input.dim(), device_input.dim());
      for (int64_t axis = 0; axis < input.dim(); ++axis) {
        ASSERT_EQ(input.size(axis), device_input.size(axis));
      }
    });
  }
}
8551 
// Compares torch::unsqueeze(input, dim) between DefaultDevice() and a device
// copy for every dim in [-(input.dim() + 1), input.dim() + 1).
TEST_F(LazyOpsTest, TestUnsqueeze) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 3}, options);
  // One more than the input rank: the result has an extra dimension.
  const int rank = input.dim() + 1;
  for (int dim = -rank; dim < rank; ++dim) {
    torch::Tensor expected = torch::unsqueeze(input, dim);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor device_input = CopyToDevice(input, device);
      torch::Tensor actual = torch::unsqueeze(device_input, dim);
      AllClose(expected, actual);
    });
  }
}
8565 
// Applies unsqueeze_(dim) to a fresh {2, 3} tensor and to its device copy for
// every dim in [-3, 3), then checks values and dimension sizes agree.
//
// The input is created inside the per-device lambda (as in
// TestSqueezeOneInPlace). unsqueeze_ mutates its receiver, so a tensor shared
// across iterations would gain one dimension per iteration and later
// iterations would no longer exercise the intended 2-D input.
TEST_F(LazyOpsTest, TestUnsqueezeInPlace) {
  const std::vector<int64_t> input_sizes = {2, 3};
  // One more than the input rank: the result has an extra dimension.
  const int rank = static_cast<int>(input_sizes.size()) + 1;
  for (int dim = -rank; dim < rank; ++dim) {
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor input = torch::rand(
          input_sizes,
          torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
      // The copy is taken before the source tensor is modified.
      torch::Tensor lazy_input = CopyToDevice(input, device);
      torch::Tensor output = input.unsqueeze_(dim);
      torch::Tensor lazy_output = lazy_input.unsqueeze_(dim);
      AllClose(output, lazy_output);
      AllClose(input, lazy_input);
      ASSERT_EQ(input.dim(), lazy_input.dim());
      for (int64_t dim_idx = 0; dim_idx < input.dim(); ++dim_idx) {
        ASSERT_EQ(input.size(dim_idx), lazy_input.size(dim_idx));
      }
    });
  }
}
8584 
// Compares torch::masked_fill with a same-shape boolean mask between
// DefaultDevice() and device copies.
TEST_F(LazyOpsTest, TestMaskedFill) {
  const auto float_options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const auto bool_options =
      torch::TensorOptions(torch::kBool).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 3}, float_options);
  torch::Tensor mask = torch::randint(0, 2, {2, 3}, bool_options);
  torch::Scalar fill_value(42);
  torch::Tensor expected = torch::masked_fill(input, mask, fill_value);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor device_input = CopyToDevice(input, device);
    torch::Tensor device_mask = CopyToDevice(mask, device);
    torch::Tensor actual =
        torch::masked_fill(device_input, device_mask, fill_value);
    AllClose(expected, actual);
  });
}
8600 
// Applies masked_fill_ to a fresh tensor and to its device copy, then checks
// both the returned tensors and the mutated tensors agree.
TEST_F(LazyOpsTest, TestMaskedFillInPlace) {
  torch::Scalar fill_value(42);
  const auto bool_options =
      torch::TensorOptions(torch::kBool).device(DefaultDevice());
  torch::Tensor mask = torch::randint(0, 2, {2, 3}, bool_options);
  ForEachDevice([&](const torch::Device& device) {
    const auto float_options =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    torch::Tensor input = torch::rand({2, 3}, float_options);
    // The copy is taken before the source tensor is modified.
    torch::Tensor device_input = CopyToDevice(input, device);
    torch::Tensor device_mask = CopyToDevice(mask, device);
    torch::Tensor expected = input.masked_fill_(mask, fill_value);
    torch::Tensor actual = device_input.masked_fill_(device_mask, fill_value);
    AllClose(expected, actual);
    AllClose(input, device_input);
  });
}
8616 
// Compares torch::masked_fill between DefaultDevice() and device copies when
// the mask ({4, 1}) has a different shape from the input ({2, 5, 4, 3}).
TEST_F(LazyOpsTest, TestMaskedFillBroadcast) {
  const auto float_options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  const auto bool_options =
      torch::TensorOptions(torch::kBool).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 5, 4, 3}, float_options);
  torch::Tensor mask = torch::randint(0, 2, {4, 1}, bool_options);
  torch::Scalar fill_value(42);
  torch::Tensor expected = torch::masked_fill(input, mask, fill_value);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor device_input = CopyToDevice(input, device);
    torch::Tensor device_mask = CopyToDevice(mask, device);
    torch::Tensor actual =
        torch::masked_fill(device_input, device_mask, fill_value);
    AllClose(expected, actual);
  });
}
8633 
// Applies torch::fill_ with a Scalar to a fresh tensor and to its device copy,
// then checks both the returned tensors and the mutated tensors agree.
TEST_F(LazyOpsTest, TestFill) {
  torch::Scalar fill_value(42);
  ForEachDevice([&](const torch::Device& device) {
    const auto options =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    torch::Tensor input = torch::empty({2, 3}, options);
    torch::Tensor device_input = CopyToDevice(input, device);
    torch::Tensor expected = torch::fill_(input, fill_value);
    torch::Tensor actual = torch::fill_(device_input, fill_value);
    AllClose(expected, actual);
    AllClose(input, device_input);
  });
}
8646 
// Applies torch::fill_ with a rank-0 tensor value to a fresh tensor and to its
// device copy, then checks both the returned tensors and the mutated tensors
// agree.
//
// The device-side fill uses the device copy of the value tensor. The copy was
// previously created but unused, so the device-side call received the
// original `value` tensor instead.
TEST_F(LazyOpsTest, TestFillWithRank0) {
  torch::Tensor value = torch::scalar_tensor(42);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor input = torch::empty(
        {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor result = torch::fill_(input, value);
    torch::Tensor lazy_value = CopyToDevice(value, device);
    torch::Tensor lazy_result = torch::fill_(lazy_input, lazy_value);
    AllClose(result, lazy_result);
    AllClose(input, lazy_input);
  });
}
8660 
// Compares Tensor::permute between DefaultDevice() and a device copy for all
// six orderings of three dimensions, using both non-negative and negative
// dimension indices.
TEST_F(LazyOpsTest, TestPermute) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 3, 4}, options);
  std::vector<std::vector<int64_t>> all_orderings = {
      {0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}};
  const int rank = input.dim();
  // Iterate by value: the second pass rewrites the indices in place.
  for (std::vector<int64_t> ordering : all_orderings) {
    for (bool negative_dims : {false, true}) {
      if (negative_dims) {
        // Shift each index by -rank to get its negative form.
        for (int64_t& dim : ordering) {
          dim -= rank;
        }
      }
      torch::Tensor expected = input.permute(ordering);
      ForEachDevice([&](const torch::Device& device) {
        torch::Tensor device_input = CopyToDevice(input, device);
        torch::Tensor actual = device_input.permute(ordering);
        AllClose(expected, actual);
      });
    }
  }
}
8684 
// Permutes a zero tensor, then adds one in place to both the permuted result
// and the original, on DefaultDevice() and on device copies, and checks both
// pairs of tensors agree.
TEST_F(LazyOpsTest, TestPermuteMod) {
  std::vector<std::vector<int64_t>> all_orderings = {
      {0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}};
  std::vector<int64_t> input_sizes = {2, 3, 4};
  const int rank = input_sizes.size();
  // Iterate by value: the second pass rewrites the indices in place.
  for (std::vector<int64_t> ordering : all_orderings) {
    for (bool negative_dims : {false, true}) {
      if (negative_dims) {
        // Shift each index by -rank to get its negative form.
        for (int64_t& dim : ordering) {
          dim -= rank;
        }
      }
      const auto options =
          torch::TensorOptions(torch::kFloat).device(DefaultDevice());
      torch::Tensor input = torch::zeros(input_sizes, options);
      torch::Tensor one = torch::tensor(1.0, options);
      torch::Tensor expected = input.permute(ordering);
      expected.add_(one, 1.0);
      input.add_(one, 1.0);
      ForEachDevice([&](const torch::Device& device) {
        // Start from a fresh zero tensor; `input` has already been modified.
        torch::Tensor fresh_input = torch::zeros(
            input_sizes,
            torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
        torch::Tensor device_input = CopyToDevice(fresh_input, device);
        torch::Tensor device_one = CopyToDevice(one, device);
        torch::Tensor actual = device_input.permute(ordering);
        actual.add_(device_one, 1.0);
        device_input.add_(device_one, 1.0);
        AllClose(expected, actual);
        AllClose(input, device_input);
      });
    }
  }
}
8721 
// Compares torch::flip between DefaultDevice() and a device copy for several
// sets of dimensions, using both non-negative and negative dimension indices.
TEST_F(LazyOpsTest, TestFlip) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 3, 4}, options);
  std::vector<std::vector<int64_t>> dim_sets = {
      {0}, {1}, {2}, {0, 1}, {1, 2}, {2, 0}, {0, 1, 2}};
  // Iterate by value: the second pass rewrites the indices in place.
  for (std::vector<int64_t> flip_dims : dim_sets) {
    for (bool negative_dims : {false, true}) {
      if (negative_dims) {
        // Subtract 3 (the input's number of dimensions) from each index.
        for (int64_t& dim : flip_dims) {
          dim -= 3;
        }
      }
      torch::Tensor expected = torch::flip(input, flip_dims);
      ForEachDevice([&](const torch::Device& device) {
        torch::Tensor device_input = CopyToDevice(input, device);
        torch::Tensor actual = torch::flip(device_input, flip_dims);
        AllClose(expected, actual);
      });
    }
  }
}
8742 
// Compares torch::pixel_shuffle with upscale factor 3 on a {5, 18, 4, 4}
// tensor between DefaultDevice() and a device copy.
TEST_F(LazyOpsTest, TestPixelShuffle) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({5, 18, 4, 4}, options);
  const int upscale_factor = 3;
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor device_input = CopyToDevice(input, device);
    torch::Tensor expected = torch::pixel_shuffle(input, upscale_factor);
    torch::Tensor actual = torch::pixel_shuffle(device_input, upscale_factor);
    AllClose(expected, actual);
  });
}
8756 
// Compares Tensor::sum_to_size from {4, 6, 3, 7} to {4, 1, 1, 7} between
// DefaultDevice() and a device copy.
TEST_F(LazyOpsTest, TestSumToSize) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({4, 6, 3, 7}, options);
  std::vector<int64_t> target_size = {4, 1, 1, 7};
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor device_input = CopyToDevice(input, device);
    torch::Tensor expected = input.sum_to_size(target_size);
    torch::Tensor actual = device_input.sum_to_size(target_size);
    AllClose(expected, actual);
  });
}
8769 
// Compares torch::transpose over dims 0 and 2 between DefaultDevice() and a
// device copy.
TEST_F(LazyOpsTest, TestTransposeDims) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 3, 4}, options);
  const int dim0 = 0;
  const int dim1 = 2;
  torch::Tensor expected = torch::transpose(input, dim0, dim1);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor device_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::transpose(device_input, dim0, dim1);
    AllClose(expected, actual);
  });
}
8782 
// Transposes a zero tensor, then adds one in place to both the transposed
// result and the original, on DefaultDevice() and on device copies, and checks
// both pairs of tensors agree.
TEST_F(LazyOpsTest, TestTransposeDimsMod) {
  std::vector<int64_t> input_sizes = {2, 3, 4};
  const int dim0 = 0;
  const int dim1 = 2;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::zeros(input_sizes, options);
  torch::Tensor one = torch::tensor(1.0, options);
  torch::Tensor expected = torch::transpose(input, dim0, dim1);
  expected.add_(one, 1.0);
  input.add_(one, 1.0);
  ForEachDevice([&](const torch::Device& device) {
    // Start from a fresh zero tensor; `input` has already been modified.
    torch::Tensor fresh_input = torch::zeros(
        input_sizes,
        torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
    torch::Tensor device_input = CopyToDevice(fresh_input, device);
    torch::Tensor device_one = CopyToDevice(one, device);
    torch::Tensor actual = torch::transpose(device_input, dim0, dim1);
    actual.add_(device_one, 1.0);
    device_input.add_(device_one, 1.0);
    AllClose(expected, actual);
    AllClose(input, device_input);
  });
}
8807 
// Applies transpose_(0, 2) to a fresh {2, 3, 4} tensor and to its device copy,
// then checks both the returned tensors and the mutated tensors agree.
//
// The input is created inside the per-device lambda, as in the other in-place
// tests. transpose_ mutates its receiver, so a tensor shared across lambda
// invocations would already be transposed on any invocation after the first.
TEST_F(LazyOpsTest, TestTransposeDimsInPlace) {
  int dim0 = 0;
  int dim1 = 2;
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor input = torch::rand(
        {2, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
    // The copy is taken before the source tensor is modified.
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor output = input.transpose_(dim0, dim1);
    torch::Tensor lazy_output = lazy_input.transpose_(dim0, dim1);
    AllClose(output, lazy_output);
    AllClose(input, lazy_input);
  });
}
8821 
// Compares torch::split between DefaultDevice() and a device copy for split
// sizes 2 and 3 along every dim in [-rank, rank), piece by piece.
TEST_F(LazyOpsTest, TestSplit) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({7, 8, 9}, options);
  const int rank = input.dim();
  for (int split_size : {2, 3}) {
    for (int dim = -rank; dim < rank; ++dim) {
      std::vector<torch::Tensor> expected_pieces =
          torch::split(input, split_size, dim);
      ForEachDevice([&](const torch::Device& device) {
        torch::Tensor device_input = CopyToDevice(input, device);
        std::vector<torch::Tensor> actual_pieces =
            torch::split(device_input, split_size, dim);
        ASSERT_EQ(expected_pieces.size(), actual_pieces.size());
        for (size_t piece = 0; piece < expected_pieces.size(); ++piece) {
          AllClose(expected_pieces[piece], actual_pieces[piece]);
        }
      });
    }
  }
}
8841 
// Compares torch::split with split size 0 on a zero-element 1-D tensor between
// DefaultDevice() and a device copy.
TEST_F(LazyOpsTest, TestSplitEmpty) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({0}, options);
  const int split_size = 0;
  const int dim = 0;
  std::vector<torch::Tensor> expected_pieces =
      torch::split(input, split_size, dim);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor device_input = CopyToDevice(input, device);
    std::vector<torch::Tensor> actual_pieces =
        torch::split(device_input, split_size, dim);
    ASSERT_EQ(expected_pieces.size(), actual_pieces.size());
    for (size_t piece = 0; piece < expected_pieces.size(); ++piece) {
      AllClose(expected_pieces[piece], actual_pieces[piece]);
    }
  });
}
8858 
// Compares torch::split_with_sizes using sizes {4, 5, 6} on a {15, 15, 15}
// tensor between DefaultDevice() and a device copy, along every dim in
// [-rank, rank).
TEST_F(LazyOpsTest, TestSplitWithSizes) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({15, 15, 15}, options);
  const int rank = input.dim();
  for (int dim = -rank; dim < rank; ++dim) {
    std::vector<torch::Tensor> expected_pieces =
        torch::split_with_sizes(input, {4, 5, 6}, dim);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor device_input = CopyToDevice(input, device);
      std::vector<torch::Tensor> actual_pieces =
          torch::split_with_sizes(device_input, {4, 5, 6}, dim);
      ASSERT_EQ(expected_pieces.size(), actual_pieces.size());
      for (size_t piece = 0; piece < expected_pieces.size(); ++piece) {
        AllClose(expected_pieces[piece], actual_pieces[piece]);
      }
    });
  }
}
8878 
// Checks torch::cross without an explicit dim argument for three shapes in
// which the size-3 dimension sits at a different position each time.
TEST_F(LazyOpsTest, TestCrossImplicitDim) {
  std::vector<std::vector<int64_t>> dim_sizes = {
      {4, 5, 3}, {4, 3, 5}, {3, 4, 5}};
  // Bind by const reference: the shapes are only read, so copying each
  // inner vector per iteration is unnecessary.
  for (const auto& dim_size : dim_sizes) {
    torch::Tensor input = torch::rand(
        dim_size, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
    torch::Tensor other = torch::rand(
        dim_size, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
    torch::Tensor result = torch::cross(input, other);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_input = CopyToDevice(input, device);
      torch::Tensor lazy_other = CopyToDevice(other, device);
      torch::Tensor lazy_result = torch::cross(lazy_input, lazy_other);
      AllClose(result, lazy_result);
    });
  }
}
8896 
// Checks torch::cross on two 3x3 tensors with an explicit dim argument,
// sweeping every dimension index in [-rank, rank).
TEST_F(LazyOpsTest, TestCrossExplicitDim) {
  const std::vector<int64_t> dim_size = {3, 3};
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor lhs = torch::rand(dim_size, options);
  torch::Tensor rhs = torch::rand(dim_size, options);
  const int rank = static_cast<int>(dim_size.size());
  for (int dim = -rank; dim < rank; ++dim) {
    torch::Tensor expected = torch::cross(lhs, rhs, dim);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor dev_lhs = CopyToDevice(lhs, device);
      torch::Tensor dev_rhs = CopyToDevice(rhs, device);
      torch::Tensor actual = torch::cross(dev_lhs, dev_rhs, dim);
      AllClose(expected, actual);
    });
  }
}
8914 
// Checks torch::cross of a tensor with itself when the shape {0, 1, 3, 0}
// contains zero-sized dimensions.
TEST_F(LazyOpsTest, TestCrossZeroDim) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({0, 1, 3, 0}, options);
  torch::Tensor expected = torch::cross(input, input);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor dev_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::cross(dev_input, dev_input);
    AllClose(expected, actual);
  });
}
8926 
// Checks torch::triu on a square matrix for every diagonal offset in
// [-size, size], which includes the offsets just past the last diagonal.
TEST_F(LazyOpsTest, TestTriu) {
  const int size = 5;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({size, size}, options);
  for (int diagonal = -size; diagonal <= size; ++diagonal) {
    torch::Tensor expected = torch::triu(input, diagonal);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor dev_input = CopyToDevice(input, device);
      torch::Tensor actual = torch::triu(dev_input, diagonal);
      AllClose(expected, actual);
    });
  }
}
8942 
// Checks torch::triu on a size x (size + 1) matrix for every diagonal
// offset in [-size, size].
TEST_F(LazyOpsTest, TestTriuNonSquare) {
  const int size = 5;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({size, size + 1}, options);
  for (int diagonal = -size; diagonal <= size; ++diagonal) {
    torch::Tensor expected = torch::triu(input, diagonal);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor dev_input = CopyToDevice(input, device);
      torch::Tensor actual = torch::triu(dev_input, diagonal);
      AllClose(expected, actual);
    });
  }
}
8958 
// Checks torch::triu on a batch of square matrices (shape
// {batch_size, size, size}) for every diagonal offset in [-size, size].
TEST_F(LazyOpsTest, TestTriuBatch) {
  const int size = 5;
  const int batch_size = 3;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({batch_size, size, size}, options);
  for (int diagonal = -size; diagonal <= size; ++diagonal) {
    torch::Tensor expected = torch::triu(input, diagonal);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor dev_input = CopyToDevice(input, device);
      torch::Tensor actual = torch::triu(dev_input, diagonal);
      AllClose(expected, actual);
    });
  }
}
8975 
// Checks torch::tril on a square matrix for every diagonal offset in
// [-size, size], which includes the offsets just past the last diagonal.
TEST_F(LazyOpsTest, TestTril) {
  const int size = 5;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({size, size}, options);
  for (int diagonal = -size; diagonal <= size; ++diagonal) {
    torch::Tensor expected = torch::tril(input, diagonal);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor dev_input = CopyToDevice(input, device);
      torch::Tensor actual = torch::tril(dev_input, diagonal);
      AllClose(expected, actual);
    });
  }
}
8991 
// Checks torch::tril on a size x (size + 1) matrix for every diagonal
// offset in [-size, size].
TEST_F(LazyOpsTest, TestTrilNonSquare) {
  const int size = 5;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({size, size + 1}, options);
  for (int diagonal = -size; diagonal <= size; ++diagonal) {
    torch::Tensor expected = torch::tril(input, diagonal);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor dev_input = CopyToDevice(input, device);
      torch::Tensor actual = torch::tril(dev_input, diagonal);
      AllClose(expected, actual);
    });
  }
}
9007 
// Checks torch::tril on a batch of square matrices (shape
// {batch_size, size, size}) for every diagonal offset in [-size, size].
TEST_F(LazyOpsTest, TestTrilBatch) {
  const int size = 5;
  const int batch_size = 3;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({batch_size, size, size}, options);
  for (int diagonal = -size; diagonal <= size; ++diagonal) {
    torch::Tensor expected = torch::tril(input, diagonal);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor dev_input = CopyToDevice(input, device);
      torch::Tensor actual = torch::tril(dev_input, diagonal);
      AllClose(expected, actual);
    });
  }
}
9024 
// Checks the in-place triu_ on a square matrix for every diagonal offset in
// [-size, size]. Both the returned tensor and the mutated input are
// compared with their per-device counterparts.
TEST_F(LazyOpsTest, TestTriuInPlace) {
  const int size = 5;
  for (int diagonal = -size; diagonal <= size; ++diagonal) {
    ForEachDevice([&](const torch::Device& device) {
      const auto options =
          torch::TensorOptions(torch::kFloat).device(DefaultDevice());
      torch::Tensor input = torch::rand({size, size}, options);
      // Take the device copy before the reference input is mutated.
      torch::Tensor dev_input = CopyToDevice(input, device);
      torch::Tensor expected = input.triu_(diagonal);
      torch::Tensor actual = dev_input.triu_(diagonal);
      AllClose(expected, actual);
      AllClose(input, dev_input);
    });
  }
}
9041 
// Checks the in-place tril_ on a square matrix for every diagonal offset in
// [-size, size]. Both the returned tensor and the mutated input are
// compared with their per-device counterparts.
TEST_F(LazyOpsTest, TestTrilInPlace) {
  const int size = 5;
  for (int diagonal = -size; diagonal <= size; ++diagonal) {
    ForEachDevice([&](const torch::Device& device) {
      const auto options =
          torch::TensorOptions(torch::kFloat).device(DefaultDevice());
      torch::Tensor input = torch::rand({size, size}, options);
      // Take the device copy before the reference input is mutated.
      torch::Tensor dev_input = CopyToDevice(input, device);
      torch::Tensor expected = input.tril_(diagonal);
      torch::Tensor actual = dev_input.tril_(diagonal);
      AllClose(expected, actual);
      AllClose(input, dev_input);
    });
  }
}
9058 
// Checks torch::trace on an n x n matrix.
TEST_F(LazyOpsTest, TestTrace) {
  const int n = 5;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({n, n}, options);
  torch::Tensor expected = torch::trace(input);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor dev_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::trace(dev_input);
    AllClose(expected, actual);
  });
}
9070 
// Checks torch::trace on a matrix with more columns (5) than rows (3).
TEST_F(LazyOpsTest, TestTraceWide) {
  const int lines = 3;
  const int cols = 5;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({lines, cols}, options);
  torch::Tensor expected = torch::trace(input);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor dev_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::trace(dev_input);
    AllClose(expected, actual);
  });
}
9084 
// Checks torch::trace on a matrix with more rows (5) than columns (3).
TEST_F(LazyOpsTest, TestTraceNarrow) {
  const int lines = 5;
  const int cols = 3;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({lines, cols}, options);
  torch::Tensor expected = torch::trace(input);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor dev_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::trace(dev_input);
    AllClose(expected, actual);
  });
}
9098 
// Checks torch::diag on a 1-D tensor of length `size` for every diagonal
// offset in [-2 * size, 2 * size].
TEST_F(LazyOpsTest, TestDiagRank1) {
  const int size = 7;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({size}, options);
  const int limit = 2 * size;
  for (int diagonal = -limit; diagonal <= limit; ++diagonal) {
    torch::Tensor expected = torch::diag(input, diagonal);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor dev_input = CopyToDevice(input, device);
      torch::Tensor actual = torch::diag(dev_input, diagonal);
      AllClose(expected, actual);
    });
  }
}
9113 
// Checks torch::diag on a square matrix for every diagonal offset in
// [-size, size], which includes the offsets just past the last diagonal.
TEST_F(LazyOpsTest, TestDiagRank2) {
  const int size = 7;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({size, size}, options);
  for (int diagonal = -size; diagonal <= size; ++diagonal) {
    torch::Tensor expected = torch::diag(input, diagonal);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor dev_input = CopyToDevice(input, device);
      torch::Tensor actual = torch::diag(dev_input, diagonal);
      AllClose(expected, actual);
    });
  }
}
9129 
// Checks torch::diagflat on a 4-D tensor for every diagonal offset in
// [-10, 10).
TEST_F(LazyOpsTest, TestDiagFlat) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({4, 3, 6, 7}, options);
  for (int diagonal = -10; diagonal < 10; ++diagonal) {
    torch::Tensor expected = torch::diagflat(input, diagonal);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor dev_input = CopyToDevice(input, device);
      torch::Tensor actual = torch::diagflat(dev_input, diagonal);
      AllClose(expected, actual);
    });
  }
}
9143 
// Checks torch::diagonal on a square matrix for every diagonal offset in
// [-size, size], which includes the offsets just past the last diagonal.
TEST_F(LazyOpsTest, TestDiagonal) {
  const int size = 5;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({size, size}, options);
  for (int diagonal = -size; diagonal <= size; ++diagonal) {
    torch::Tensor expected = torch::diagonal(input, diagonal);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor dev_input = CopyToDevice(input, device);
      torch::Tensor actual = torch::diagonal(dev_input, diagonal);
      AllClose(expected, actual);
    });
  }
}
9159 
// Takes torch::diagonal of a square matrix, adds 1 to it in place, and
// checks that both the diagonal result and the source tensor agree with
// their per-device counterparts, for every offset in [-size, size].
TEST_F(LazyOpsTest, TestDiagonalUpdate) {
  const int size = 5;
  for (int diagonal = -size; diagonal <= size; ++diagonal) {
    const auto options =
        torch::TensorOptions(torch::kFloat).device(DefaultDevice());
    auto input = torch::rand({size, size}, options);
    // Snapshot taken before add_ runs on the diagonal of `input`; the
    // device copy below starts from this snapshot.
    auto pristine = input.clone();
    auto expected = torch::diagonal(input, diagonal);
    expected.add_(1);

    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor dev_input = CopyToDevice(pristine, device);
      torch::Tensor actual = torch::diagonal(dev_input, diagonal);
      actual.add_(1);

      AllClose(expected, actual);
      AllClose(input, dev_input);
    });
  }
}
9181 
// Checks torch::diagonal on a size x (size + 1) matrix for every diagonal
// offset in [-size, size].
TEST_F(LazyOpsTest, TestDiagonalNonSquare) {
  const int size = 5;
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({size, size + 1}, options);
  for (int diagonal = -size; diagonal <= size; ++diagonal) {
    torch::Tensor expected = torch::diagonal(input, diagonal);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor dev_input = CopyToDevice(input, device);
      torch::Tensor actual = torch::diagonal(dev_input, diagonal);
      AllClose(expected, actual);
    });
  }
}
9197 
// Checks torch::diagonal on a batched input of shape
// {batch_size, size, size}, taking diagonals over dimensions 1 and 2, for
// every diagonal offset in [-size, size].
TEST_F(LazyOpsTest, TestDiagonalBatch) {
  int size = 5;
  int batch_size = 3;
  int dim1 = 1;
  int dim2 = 2;
  torch::Tensor input = torch::rand(
      {batch_size, size, size},
      torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  // The sweep includes -size and size, one step past the last diagonal.
  for (int diagonal = -size; diagonal <= size; ++diagonal) {
    torch::Tensor output =
        torch::diagonal(input, diagonal, /*dim1=*/dim1, /*dim2=*/dim2);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor lazy_input = CopyToDevice(input, device);
      torch::Tensor lazy_output =
          torch::diagonal(lazy_input, diagonal, /*dim1=*/dim1, /*dim2=*/dim2);
      AllClose(output, lazy_output);
    });
  }
}
9218 
// Checks torch::flatten on a 4-D tensor for every (start_dim, end_dim) pair
// with start_dim <= end_dim, spelling each index both as a non-negative
// value and as its negative equivalent (index - rank).
TEST_F(LazyOpsTest, TestFlatten) {
  torch::Tensor input = torch::rand({4, 7, 5, 3});
  const int rank = input.dim();
  // Maps a non-negative dimension index to the requested spelling.
  const auto spell_dim = [rank](int pos_dim, bool negative) {
    return negative ? pos_dim - rank : pos_dim;
  };
  for (int pos_start_dim = 0; pos_start_dim < rank; ++pos_start_dim) {
    for (int pos_end_dim = pos_start_dim; pos_end_dim < rank; ++pos_end_dim) {
      for (bool negative_start_dim : {false, true}) {
        for (bool negative_end_dim : {false, true}) {
          const int start_dim = spell_dim(pos_start_dim, negative_start_dim);
          const int end_dim = spell_dim(pos_end_dim, negative_end_dim);
          torch::Tensor expected = torch::flatten(input, start_dim, end_dim);
          ForEachDevice([&](const torch::Device& device) {
            torch::Tensor dev_input = CopyToDevice(input, device);
            torch::Tensor actual =
                torch::flatten(dev_input, start_dim, end_dim);
            AllClose(expected, actual);
          });
        }
      }
    }
  }
}
9241 
// Checks torch::logical_and for every pairing of the listed scalar types.
// Floating-point operands come from torch::rand; integer operands come from
// torch::randint with upper bound 100 (lower bound 0 for the left operand,
// 1 for the right operand).
TEST_F(LazyOpsTest, TestLogicalAnd) {
  const std::vector<torch::ScalarType> scalar_types = {
      torch::kFloat,
      torch::kByte,
      torch::kChar,
      torch::kShort,
      torch::kInt,
      torch::kLong};
  // Builds a 3x4 operand of the given type; `low` is only used for the
  // integer branch.
  const auto make_operand = [](torch::ScalarType type,
                               int64_t low) -> torch::Tensor {
    if (isFloatingType(type)) {
      return torch::rand({3, 4}, torch::TensorOptions(type));
    }
    return torch::randint(low, 100, {3, 4}, torch::TensorOptions(type));
  };
  for (torch::ScalarType scalar_type1 : scalar_types) {
    torch::Tensor lhs = make_operand(scalar_type1, 0);
    for (torch::ScalarType scalar_type2 : scalar_types) {
      torch::Tensor rhs = make_operand(scalar_type2, 1);
      torch::Tensor expected = torch::logical_and(lhs, rhs);
      ForEachDevice([&](const torch::Device& device) {
        torch::Tensor dev_lhs = CopyToDevice(lhs, device);
        torch::Tensor dev_rhs = CopyToDevice(rhs, device);
        torch::Tensor actual = torch::logical_and(dev_lhs, dev_rhs);
        AllEqual(expected, actual);
      });
    }
  }

  // NOTE(review): the "xla::" counter prefix looks like a leftover from the
  // pytorch/xla port of this test -- confirm the intended counter name.
  ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
  ExpectCounterChanged("xla::logical_and_out", GetIgnoredCounters());
}
9276 
// Checks tensor-tensor __and__ on two random int32 tensors of shape {4, 2}.
TEST_F(LazyOpsTest, TestBitwiseAnd) {
  const auto upper = std::numeric_limits<int32_t>::max();
  const auto options = torch::TensorOptions(torch::kInt);
  torch::Tensor lhs = torch::randint(0, upper, {4, 2}, options);
  torch::Tensor rhs = torch::randint(0, upper, {4, 2}, options);
  torch::Tensor expected = lhs.__and__(rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor dev_lhs = CopyToDevice(lhs, device);
    torch::Tensor dev_rhs = CopyToDevice(rhs, device);
    torch::Tensor actual = dev_lhs.__and__(dev_rhs);
    AllEqual(expected, actual);
  });
}
9296 
// Checks in-place tensor-tensor __iand__ on random int32 tensors. Both the
// returned tensor and the mutated left operand are compared with their
// per-device counterparts.
TEST_F(LazyOpsTest, TestBitwiseAndInPlace) {
  const auto upper = std::numeric_limits<int32_t>::max();
  const auto options = torch::TensorOptions(torch::kInt);
  torch::Tensor lhs = torch::randint(0, upper, {4, 2}, options);
  torch::Tensor rhs = torch::randint(0, upper, {4, 2}, options);
  ForEachDevice([&](const torch::Device& device) {
    // Take the device copy of lhs before the reference lhs is mutated.
    torch::Tensor dev_lhs = CopyToDevice(lhs, device);
    torch::Tensor dev_rhs = CopyToDevice(rhs, device);
    torch::Tensor expected = lhs.__iand__(rhs);
    torch::Tensor actual = dev_lhs.__iand__(dev_rhs);
    AllEqual(expected, actual);
    AllEqual(lhs, dev_lhs);
  });
}
9317 
// Checks tensor-scalar __and__ on a random int32 tensor with the scalar
// 123456789.
TEST_F(LazyOpsTest, TestBitwiseAndScalar) {
  const auto upper = std::numeric_limits<int32_t>::max();
  torch::Tensor lhs =
      torch::randint(0, upper, {4, 2}, torch::TensorOptions(torch::kInt));
  torch::Scalar rhs(123456789);
  torch::Tensor expected = lhs.__and__(rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor dev_lhs = CopyToDevice(lhs, device);
    torch::Tensor actual = dev_lhs.__and__(rhs);
    AllEqual(expected, actual);
  });
}
9332 
// Checks in-place tensor-scalar __iand__ on a random int32 tensor. Both the
// returned tensor and the mutated operand are compared with their
// per-device counterparts.
TEST_F(LazyOpsTest, TestBitwiseAndScalarInPlace) {
  const auto upper = std::numeric_limits<int32_t>::max();
  torch::Tensor lhs =
      torch::randint(0, upper, {4, 2}, torch::TensorOptions(torch::kInt));
  torch::Scalar rhs(123456789);
  ForEachDevice([&](const torch::Device& device) {
    // Take the device copy of lhs before the reference lhs is mutated.
    torch::Tensor dev_lhs = CopyToDevice(lhs, device);
    torch::Tensor expected = lhs.__iand__(rhs);
    torch::Tensor actual = dev_lhs.__iand__(rhs);
    AllEqual(expected, actual);
    AllEqual(lhs, dev_lhs);
  });
}
9348 
// Checks torch::__and__ applied to the results of gt(0) and ne(0) taken on
// a reshape(-1) of a float tensor.
TEST_F(LazyOpsTest, TestBitwiseAndPromotion) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({4, 2}, options);
  torch::Tensor flat = input.reshape(-1);
  torch::Tensor expected = torch::__and__(flat.gt(0), flat.ne(0));
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor dev_input = CopyToDevice(input, device);
    torch::Tensor dev_flat = dev_input.reshape(-1);
    torch::Tensor actual = torch::__and__(dev_flat.gt(0), dev_flat.ne(0));
    AllEqual(expected, actual);
  });
}
9362 
// Checks tensor-tensor __or__ on two random int32 tensors of shape {4, 2}.
TEST_F(LazyOpsTest, TestBitwiseOr) {
  const auto upper = std::numeric_limits<int32_t>::max();
  const auto options = torch::TensorOptions(torch::kInt);
  torch::Tensor lhs = torch::randint(0, upper, {4, 2}, options);
  torch::Tensor rhs = torch::randint(0, upper, {4, 2}, options);
  torch::Tensor expected = lhs.__or__(rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor dev_lhs = CopyToDevice(lhs, device);
    torch::Tensor dev_rhs = CopyToDevice(rhs, device);
    torch::Tensor actual = dev_lhs.__or__(dev_rhs);
    AllEqual(expected, actual);
  });
}
9382 
// Checks in-place tensor-tensor __ior__ on random int32 tensors. Both the
// returned tensor and the mutated left operand are compared with their
// per-device counterparts.
TEST_F(LazyOpsTest, TestBitwiseOrInPlace) {
  const auto upper = std::numeric_limits<int32_t>::max();
  const auto options = torch::TensorOptions(torch::kInt);
  torch::Tensor lhs = torch::randint(0, upper, {4, 2}, options);
  torch::Tensor rhs = torch::randint(0, upper, {4, 2}, options);
  ForEachDevice([&](const torch::Device& device) {
    // Take the device copy of lhs before the reference lhs is mutated.
    torch::Tensor dev_lhs = CopyToDevice(lhs, device);
    torch::Tensor dev_rhs = CopyToDevice(rhs, device);
    torch::Tensor expected = lhs.__ior__(rhs);
    torch::Tensor actual = dev_lhs.__ior__(dev_rhs);
    AllEqual(expected, actual);
    AllEqual(lhs, dev_lhs);
  });
}
9403 
// Checks tensor-scalar __or__ on a random int32 tensor with the scalar
// 123456789.
TEST_F(LazyOpsTest, TestBitwiseOrScalar) {
  const auto upper = std::numeric_limits<int32_t>::max();
  torch::Tensor lhs =
      torch::randint(0, upper, {4, 2}, torch::TensorOptions(torch::kInt));
  torch::Scalar rhs(123456789);
  torch::Tensor expected = lhs.__or__(rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor dev_lhs = CopyToDevice(lhs, device);
    torch::Tensor actual = dev_lhs.__or__(rhs);
    AllEqual(expected, actual);
  });
}
9418 
// Checks in-place tensor-scalar __ior__ on a random int32 tensor. Both the
// returned tensor and the mutated operand are compared with their
// per-device counterparts.
TEST_F(LazyOpsTest, TestBitwiseOrScalarInPlace) {
  const auto upper = std::numeric_limits<int32_t>::max();
  torch::Tensor lhs =
      torch::randint(0, upper, {4, 2}, torch::TensorOptions(torch::kInt));
  torch::Scalar rhs(123456789);
  ForEachDevice([&](const torch::Device& device) {
    // Take the device copy of lhs before the reference lhs is mutated.
    torch::Tensor dev_lhs = CopyToDevice(lhs, device);
    torch::Tensor expected = lhs.__ior__(rhs);
    torch::Tensor actual = dev_lhs.__ior__(rhs);
    AllEqual(expected, actual);
    AllEqual(lhs, dev_lhs);
  });
}
9434 
// Checks tensor-tensor __xor__ on two random int32 tensors of shape {4, 2}.
TEST_F(LazyOpsTest, TestBitwiseXor) {
  const auto upper = std::numeric_limits<int32_t>::max();
  const auto options = torch::TensorOptions(torch::kInt);
  torch::Tensor lhs = torch::randint(0, upper, {4, 2}, options);
  torch::Tensor rhs = torch::randint(0, upper, {4, 2}, options);
  torch::Tensor expected = lhs.__xor__(rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor dev_lhs = CopyToDevice(lhs, device);
    torch::Tensor dev_rhs = CopyToDevice(rhs, device);
    torch::Tensor actual = dev_lhs.__xor__(dev_rhs);
    AllEqual(expected, actual);
  });
}
9454 
// Checks in-place tensor-tensor __ixor__ on random int32 tensors. Both the
// returned tensor and the mutated left operand are compared with their
// per-device counterparts.
TEST_F(LazyOpsTest, TestBitwiseXorInPlace) {
  const auto upper = std::numeric_limits<int32_t>::max();
  const auto options = torch::TensorOptions(torch::kInt);
  torch::Tensor lhs = torch::randint(0, upper, {4, 2}, options);
  torch::Tensor rhs = torch::randint(0, upper, {4, 2}, options);
  ForEachDevice([&](const torch::Device& device) {
    // Take the device copy of lhs before the reference lhs is mutated.
    torch::Tensor dev_lhs = CopyToDevice(lhs, device);
    torch::Tensor dev_rhs = CopyToDevice(rhs, device);
    torch::Tensor expected = lhs.__ixor__(rhs);
    torch::Tensor actual = dev_lhs.__ixor__(dev_rhs);
    AllEqual(expected, actual);
    AllEqual(lhs, dev_lhs);
  });
}
9475 
// Checks tensor-scalar __xor__ on a random int32 tensor with the scalar
// 123456789.
TEST_F(LazyOpsTest, TestBitwiseXorScalar) {
  const auto upper = std::numeric_limits<int32_t>::max();
  torch::Tensor lhs =
      torch::randint(0, upper, {4, 2}, torch::TensorOptions(torch::kInt));
  torch::Scalar rhs(123456789);
  torch::Tensor expected = lhs.__xor__(rhs);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor dev_lhs = CopyToDevice(lhs, device);
    torch::Tensor actual = dev_lhs.__xor__(rhs);
    AllEqual(expected, actual);
  });
}
9490 
// Checks in-place tensor-scalar __ixor__ on a random int32 tensor. Both the
// returned tensor and the mutated operand are compared with their
// per-device counterparts.
TEST_F(LazyOpsTest, TestBitwiseXorScalarInPlace) {
  const auto upper = std::numeric_limits<int32_t>::max();
  torch::Tensor lhs =
      torch::randint(0, upper, {4, 2}, torch::TensorOptions(torch::kInt));
  torch::Scalar rhs(123456789);
  ForEachDevice([&](const torch::Device& device) {
    // Take the device copy of lhs before the reference lhs is mutated.
    torch::Tensor dev_lhs = CopyToDevice(lhs, device);
    torch::Tensor expected = lhs.__ixor__(rhs);
    torch::Tensor actual = dev_lhs.__ixor__(rhs);
    AllEqual(expected, actual);
    AllEqual(lhs, dev_lhs);
  });
}
9506 
// Checks tensor-tensor torch::__lshift__ on an int32 tensor of ones, with
// per-element shift amounts drawn by torch::randint(16, ...).
TEST_F(LazyOpsTest, TestLshift) {
  const auto options =
      torch::TensorOptions(torch::kInt32).device(DefaultDevice());
  torch::Tensor input = torch::ones({4, 2}, options);
  torch::Tensor shift_amount = torch::randint(16, input.sizes(), options);
  torch::Tensor expected = torch::__lshift__(input, shift_amount);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor dev_input = CopyToDevice(input, device);
    torch::Tensor dev_shift = CopyToDevice(shift_amount, device);
    torch::Tensor actual = torch::__lshift__(dev_input, dev_shift);
    AllClose(expected, actual);
  });
}
9523 
// Checks in-place tensor-tensor __ilshift__ on an int32 tensor of ones.
// Both the returned tensor and the mutated input are compared with their
// per-device counterparts.
TEST_F(LazyOpsTest, TestLshiftInPlace) {
  torch::Tensor input = torch::ones(
      {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
  ForEachDevice([&](const torch::Device& device) {
    // Take the device copy before the reference input is mutated.
    torch::Tensor dev_input = CopyToDevice(input, device);
    const auto shift_options =
        torch::TensorOptions(torch::kInt32).device(DefaultDevice());
    torch::Tensor shift_amount =
        torch::randint(16, input.sizes(), shift_options);
    torch::Tensor expected = input.__ilshift__(shift_amount);
    torch::Tensor dev_shift = CopyToDevice(shift_amount, device);
    torch::Tensor actual = dev_input.__ilshift__(dev_shift);
    AllClose(expected, actual);
    AllClose(input, dev_input);
  });
}
9540 
// Checks tensor-scalar torch::__lshift__ on an int32 tensor of ones with a
// shift amount of 3.
TEST_F(LazyOpsTest, TestLshiftScalar) {
  const auto options =
      torch::TensorOptions(torch::kInt32).device(DefaultDevice());
  torch::Tensor input = torch::ones({4, 2}, options);
  torch::Scalar shift_amount = 3;
  torch::Tensor expected = torch::__lshift__(input, shift_amount);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor dev_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::__lshift__(dev_input, shift_amount);
    AllClose(expected, actual);
  });
}
9552 
// Checks in-place tensor-scalar __ilshift__ on an int32 tensor of ones with
// a shift amount of 3. Both the returned tensor and the mutated input are
// compared with their per-device counterparts.
TEST_F(LazyOpsTest, TestLshiftScalarInPlace) {
  const auto options =
      torch::TensorOptions(torch::kInt32).device(DefaultDevice());
  torch::Tensor input = torch::ones({4, 2}, options);
  torch::Scalar shift_amount = 3;
  ForEachDevice([&](const torch::Device& device) {
    // Take the device copy before the reference input is mutated.
    torch::Tensor dev_input = CopyToDevice(input, device);
    torch::Tensor expected = input.__ilshift__(shift_amount);
    torch::Tensor actual = dev_input.__ilshift__(shift_amount);
    AllClose(expected, actual);
    AllClose(input, dev_input);
  });
}
9565 
// Checks tensor-tensor torch::__rshift__ on an int32 tensor of ones, with
// per-element shift amounts drawn by torch::randint(16, ...).
TEST_F(LazyOpsTest, TestRshift) {
  const auto options =
      torch::TensorOptions(torch::kInt32).device(DefaultDevice());
  torch::Tensor input = torch::ones({4, 2}, options);
  torch::Tensor shift_amount = torch::randint(16, input.sizes(), options);
  torch::Tensor expected = torch::__rshift__(input, shift_amount);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor dev_input = CopyToDevice(input, device);
    torch::Tensor dev_shift = CopyToDevice(shift_amount, device);
    torch::Tensor actual = torch::__rshift__(dev_input, dev_shift);
    AllClose(expected, actual);
  });
}
9582 
// Checks in-place tensor-tensor __irshift__ on an int32 tensor of ones.
// Both the returned tensor and the mutated input are compared with their
// per-device counterparts.
TEST_F(LazyOpsTest, TestRshiftInPlace) {
  torch::Tensor input = torch::ones(
      {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
  ForEachDevice([&](const torch::Device& device) {
    // Take the device copy before the reference input is mutated.
    torch::Tensor dev_input = CopyToDevice(input, device);
    const auto shift_options =
        torch::TensorOptions(torch::kInt32).device(DefaultDevice());
    torch::Tensor shift_amount =
        torch::randint(16, input.sizes(), shift_options);
    torch::Tensor expected = input.__irshift__(shift_amount);
    torch::Tensor dev_shift = CopyToDevice(shift_amount, device);
    torch::Tensor actual = dev_input.__irshift__(dev_shift);
    AllClose(expected, actual);
    AllClose(input, dev_input);
  });
}
9599 
// Checks tensor-scalar torch::__rshift__ on an int32 tensor of ones with a
// shift amount of 3.
TEST_F(LazyOpsTest, TestRshiftScalar) {
  const auto options =
      torch::TensorOptions(torch::kInt32).device(DefaultDevice());
  torch::Tensor input = torch::ones({4, 2}, options);
  torch::Scalar shift_amount = 3;
  torch::Tensor expected = torch::__rshift__(input, shift_amount);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor dev_input = CopyToDevice(input, device);
    torch::Tensor actual = torch::__rshift__(dev_input, shift_amount);
    AllClose(expected, actual);
  });
}
9611 
// Checks in-place tensor-scalar __irshift__ on an int32 tensor of ones with
// a shift amount of 3. Both the returned tensor and the mutated input are
// compared with their per-device counterparts.
TEST_F(LazyOpsTest, TestRshiftScalarInPlace) {
  const auto options =
      torch::TensorOptions(torch::kInt32).device(DefaultDevice());
  torch::Tensor input = torch::ones({4, 2}, options);
  torch::Scalar shift_amount = 3;
  ForEachDevice([&](const torch::Device& device) {
    // Take the device copy before the reference input is mutated.
    torch::Tensor dev_input = CopyToDevice(input, device);
    torch::Tensor expected = input.__irshift__(shift_amount);
    torch::Tensor actual = dev_input.__irshift__(shift_amount);
    AllClose(expected, actual);
    AllClose(input, dev_input);
  });
}
9624 
// torch::meshgrid over three 1-D float tensors: every grid produced from the
// device copies is compared against the corresponding eager grid.
TEST_F(LazyOpsTest, TestMeshgrid) {
  torch::Tensor a = torch::rand(
      {3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor b = torch::rand(
      {2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor c = torch::rand(
      {4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  auto d = torch::meshgrid({a, b, c});
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_a = CopyToDevice(a, device);
    torch::Tensor lazy_b = CopyToDevice(b, device);
    torch::Tensor lazy_c = CopyToDevice(c, device);
    auto lazy_d = torch::meshgrid({lazy_a, lazy_b, lazy_c});
    // Fatal assertion: the loop below indexes lazy_d with indices taken from
    // d, so a size mismatch must stop this device's checks instead of reading
    // out of bounds.
    ASSERT_EQ(d.size(), lazy_d.size());
    for (size_t i = 0; i < d.size(); ++i) {
      AllClose(d[i], lazy_d[i]);
    }
  });
}
9644 
// constant_pad_nd with a pad list that has six entries for a rank-3 input.
TEST_F(LazyOpsTest, TestConstantPad) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({4, 2, 5}, options);
  std::vector<int64_t> pad{1, 2, 3, 4, 5, 6};
  float pad_value = 5;
  torch::Tensor expected = torch::constant_pad_nd(input, pad, pad_value);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    AllClose(expected, torch::constant_pad_nd(lazy_input, pad, pad_value));
  });
}
9658 
// constant_pad_nd where the pad list has only two entries for a rank-3 input.
TEST_F(LazyOpsTest, TestConstantPadIncomplete) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({4, 2, 5}, options);
  std::vector<int64_t> pad{1, 2};
  float pad_value = 5;
  torch::Tensor expected = torch::constant_pad_nd(input, pad, pad_value);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    AllClose(expected, torch::constant_pad_nd(lazy_input, pad, pad_value));
  });
}
9672 
// reflection_pad2d applied to a rank-3 input.
TEST_F(LazyOpsTest, TestReflectionPad2dRank3) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 3, 4}, options);
  std::vector<int64_t> pad{2, 2, 2, 2};
  torch::Tensor expected = torch::reflection_pad2d(input, pad);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    AllClose(expected, torch::reflection_pad2d(lazy_input, pad));
  });
}
9684 
// reflection_pad2d applied to a rank-4 input.
TEST_F(LazyOpsTest, TestReflectionPad2dRank4) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 2, 3, 4}, options);
  std::vector<int64_t> pad{2, 2, 2, 2};
  torch::Tensor expected = torch::reflection_pad2d(input, pad);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    AllClose(expected, torch::reflection_pad2d(lazy_input, pad));
  });
}
9697 
// Backward check (via TestBackward) for reflection_pad2d on a rank-4 input
// that requires grad.
TEST_F(LazyOpsTest, TestReflectionPad2dBackward) {
  std::vector<int64_t> pad{2, 3, 1, 2};
  auto pad_fn = [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
    return torch::reflection_pad2d(args[0], pad);
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto grad_options = torch::TensorOptions(torch::kFloat)
                                  .device(DefaultDevice())
                                  .requires_grad(true);
    torch::Tensor input = torch::rand({1, 2, 4, 4}, grad_options);
    TestBackward({input}, device, pad_fn);
  });
}
9714 
// replication_pad1d with non-zero padding on both sides.
TEST_F(LazyOpsTest, TestReplicationPad1d) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({1, 4}, options);
  std::vector<int64_t> pad{1, 2};
  torch::Tensor expected = torch::replication_pad1d(input, pad);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    AllClose(expected, torch::replication_pad1d(lazy_input, pad));
  });
}
9726 
// replication_pad1d where one of the two pad entries is zero.
TEST_F(LazyOpsTest, TestReplicationPad1dZeroPad) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({1, 4}, options);
  std::vector<int64_t> pad{1, 0};
  torch::Tensor expected = torch::replication_pad1d(input, pad);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    AllClose(expected, torch::replication_pad1d(lazy_input, pad));
  });
}
9738 
// Backward check (via TestBackward) for replication_pad1d on a {2, 4} input
// that requires grad.
TEST_F(LazyOpsTest, TestReplicationPad1dBackward) {
  std::vector<int64_t> pad{2, 3};
  auto pad_fn = [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
    return torch::replication_pad1d(args[0], pad);
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto grad_options = torch::TensorOptions(torch::kFloat)
                                  .device(DefaultDevice())
                                  .requires_grad(true);
    torch::Tensor input = torch::rand({2, 4}, grad_options);
    TestBackward({input}, device, pad_fn);
  });
}
9755 
// replication_pad2d with four non-zero pad entries.
TEST_F(LazyOpsTest, TestReplicationPad2d) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({1, 3, 4}, options);
  std::vector<int64_t> pad{1, 2, 2, 1};
  torch::Tensor expected = torch::replication_pad2d(input, pad);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    AllClose(expected, torch::replication_pad2d(lazy_input, pad));
  });
}
9767 
// replication_pad2d where two of the four pad entries are zero.
TEST_F(LazyOpsTest, TestReplicationPad2dZeroPad) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({1, 3, 4}, options);
  std::vector<int64_t> pad{1, 0, 0, 1};
  torch::Tensor expected = torch::replication_pad2d(input, pad);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    AllClose(expected, torch::replication_pad2d(lazy_input, pad));
  });
}
9779 
// Backward check (via TestBackward) for replication_pad2d on a {2, 3, 4}
// input that requires grad.
TEST_F(LazyOpsTest, TestReplicationPad2dBackward) {
  std::vector<int64_t> pad{2, 3, 1, 1};
  auto pad_fn = [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
    return torch::replication_pad2d(args[0], pad);
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto grad_options = torch::TensorOptions(torch::kFloat)
                                  .device(DefaultDevice())
                                  .requires_grad(true);
    torch::Tensor input = torch::rand({2, 3, 4}, grad_options);
    TestBackward({input}, device, pad_fn);
  });
}
9796 
// Out-of-place as_strided: views a {128, 320} tensor as {128, 20, 4, 4} and
// compares the eager view with the one produced on each device.
TEST_F(LazyOpsTest, TestAsStrided) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({128, 320}, options);
  std::vector<int64_t> size = {128, 20, 4, 4};
  std::vector<int64_t> stride = {320, 16, 4, 1};
  torch::Tensor expected = torch::as_strided(input, size, stride);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor lazy_view = torch::as_strided(lazy_input, size, stride);
    AllClose(expected, lazy_view);
  });
}
9811 
// In-place as_strided_: both the returned tensors and the modified operands
// (eager vs. lazy) are compared via AllClose.
TEST_F(LazyOpsTest, TestAsStridedInPlace) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({128, 320}, options);
  std::vector<int64_t> size = {128, 20, 4, 4};
  std::vector<int64_t> stride = {320, 16, 4, 1};
  ForEachDevice([&](const torch::Device& device) {
    // Take the device copy first; as_strided_ on `input` below changes it.
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor eager_result = torch::as_strided_(input, size, stride);
    torch::Tensor lazy_result = torch::as_strided_(lazy_input, size, stride);
    AllClose(eager_result, lazy_result);
    AllClose(input, lazy_input);
  });
}
9827 
// as_strided with an explicit storage offset of 4.
TEST_F(LazyOpsTest, TestAsStridedWithOffset) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({4, 8, 2}, options);
  std::vector<int64_t> size = {4, 4, 2};
  std::vector<int64_t> stride = {8, 2, 1};
  int64_t storage_offset = 4;
  torch::Tensor expected =
      torch::as_strided(input, size, stride, storage_offset);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(input, device);
    torch::Tensor lazy_view =
        torch::as_strided(lazy_input, size, stride, storage_offset);
    AllClose(expected, lazy_view);
  });
}
9849 
// Writes into a zero tensor through an as_strided view using copy_, then
// compares the eager and lazy destination tensors.
TEST_F(LazyOpsTest, TestAsStridedWithInplaceCopy) {
  const auto options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor grad = torch::ones({4}, options);
  std::vector<int64_t> size = {4};
  std::vector<int64_t> stride = {1};
  torch::Tensor expected = torch::zeros({4}, grad.options());
  expected.as_strided(size, stride).copy_(grad);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_grad = CopyToDevice(grad, device);
    torch::Tensor lazy_dest = torch::zeros({4}, lazy_grad.options());
    lazy_dest.as_strided(size, stride).copy_(lazy_grad);
    AllClose(expected, lazy_dest);
  });
}
9864 
// empty_strided: a tensor allocated on each device must report the same sizes
// and strides as one allocated with default options. Contents are
// uninitialized, so only the metadata is compared.
TEST_F(LazyOpsTest, TestEmptyStrided) {
  std::vector<int64_t> size = {4, 4, 2};
  std::vector<int64_t> stride = {8, 2, 1};
  torch::Tensor output = torch::empty_strided(/*size=*/size, /*stride=*/stride);
  ForEachDevice([&](const torch::Device& device) {
    // Allocate on the device under test. Previously `device` was ignored, so
    // the "lazy" tensor was created with default options and the test never
    // exercised the lazy backend.
    torch::Tensor lazy_output = torch::empty_strided(
        /*size=*/size,
        /*stride=*/stride,
        torch::TensorOptions().device(device));
    EXPECT_EQ(output.sizes(), lazy_output.sizes());
    EXPECT_EQ(output.strides(), lazy_output.strides());
  });
}
9876 
// Backward check for avg_pool2d on a batched {1, 1, 7, 7} input, sweeping
// stride, padding, count_include_pad and ceil_mode.
TEST_F(LazyOpsTest, TestAvgPool2DBackward) {
  int kernel_size = 2;
  for (int stride : {1, 2}) {
    for (int padding : {0, 1}) {
      for (bool count_include_pad : {true, false}) {
        // Test ceil_mode=true through the CPU interop.
        for (bool ceil_mode : {false, true}) {
          auto pool_fn =
              [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
            return torch::avg_pool2d(
                args[0],
                /*kernel_size=*/{kernel_size, kernel_size},
                /*stride=*/{stride, stride},
                /*padding=*/{padding, padding},
                /*ceil_mode=*/ceil_mode,
                /*count_include_pad=*/count_include_pad);
          };

          ForEachDevice([&](const torch::Device& device) {
            const auto grad_options = torch::TensorOptions(torch::kFloat)
                                          .device(DefaultDevice())
                                          .requires_grad(true);
            torch::Tensor input = torch::rand({1, 1, 7, 7}, grad_options);
            TestBackward({input}, device, pool_fn);
          });
        }
      }
    }
  }
}
9910 
// Backward check for avg_pool3d on a batched {1, 1, 7, 7, 7} input, sweeping
// stride, padding, count_include_pad and ceil_mode.
TEST_F(LazyOpsTest, TestAvgPool3DBackward) {
  int kernel_size = 2;
  for (int stride : {1, 2}) {
    for (int padding : {0, 1}) {
      for (bool count_include_pad : {true, false}) {
        // Test ceil_mode=true through the CPU interop.
        for (bool ceil_mode : {false, true}) {
          auto pool_fn =
              [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
            return torch::avg_pool3d(
                args[0],
                /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
                /*stride=*/{stride, stride, stride},
                /*padding=*/{padding, padding, padding},
                /*ceil_mode=*/ceil_mode,
                /*count_include_pad=*/count_include_pad);
          };

          ForEachDevice([&](const torch::Device& device) {
            const auto grad_options = torch::TensorOptions(torch::kFloat)
                                          .device(DefaultDevice())
                                          .requires_grad(true);
            torch::Tensor input = torch::rand({1, 1, 7, 7, 7}, grad_options);
            TestBackward({input}, device, pool_fn);
          });
        }
      }
    }
  }
}
9944 
// Backward check for avg_pool2d on an unbatched {1, 7, 7} input, sweeping
// stride, padding, count_include_pad and ceil_mode.
TEST_F(LazyOpsTest, TestAvgPool2DNoBatchBackward) {
  int kernel_size = 2;
  for (int stride : {1, 2}) {
    for (int padding : {0, 1}) {
      for (bool count_include_pad : {true, false}) {
        // Test ceil_mode=true through the CPU interop.
        for (bool ceil_mode : {false, true}) {
          auto pool_fn =
              [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
            return torch::avg_pool2d(
                args[0],
                /*kernel_size=*/{kernel_size, kernel_size},
                /*stride=*/{stride, stride},
                /*padding=*/{padding, padding},
                /*ceil_mode=*/ceil_mode,
                /*count_include_pad=*/count_include_pad);
          };

          ForEachDevice([&](const torch::Device& device) {
            const auto grad_options = torch::TensorOptions(torch::kFloat)
                                          .device(DefaultDevice())
                                          .requires_grad(true);
            torch::Tensor input = torch::rand({1, 7, 7}, grad_options);
            TestBackward({input}, device, pool_fn);
          });
        }
      }
    }
  }
}
9978 
// Backward check for avg_pool3d on an unbatched {1, 7, 7, 7} input, sweeping
// stride, padding, count_include_pad and ceil_mode.
TEST_F(LazyOpsTest, TestAvgPool3DNoBatchBackward) {
  int kernel_size = 2;
  for (int stride : {1, 2}) {
    for (int padding : {0, 1}) {
      for (bool count_include_pad : {true, false}) {
        // Test ceil_mode=true through the CPU interop.
        for (bool ceil_mode : {false, true}) {
          auto pool_fn =
              [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
            return torch::avg_pool3d(
                args[0],
                /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
                /*stride=*/{stride, stride, stride},
                /*padding=*/{padding, padding, padding},
                /*ceil_mode=*/ceil_mode,
                /*count_include_pad=*/count_include_pad);
          };

          ForEachDevice([&](const torch::Device& device) {
            const auto grad_options = torch::TensorOptions(torch::kFloat)
                                          .device(DefaultDevice())
                                          .requires_grad(true);
            torch::Tensor input = torch::rand({1, 7, 7, 7}, grad_options);
            TestBackward({input}, device, pool_fn);
          });
        }
      }
    }
  }
}
10012 
// Backward check for adaptive_avg_pool3d on an unbatched {1, 56, 28, 28}
// input, for cubic output sizes 7 and 4. Skipped when IsCuda() is true.
TEST_F(LazyOpsTest, TestAdaptiveAvgPool3DNoBatchBackward) {
  if (IsCuda()) {
    GTEST_SKIP();
  }
  for (int64_t output_size : {7, 4}) {
    auto pool_fn =
        [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
      return torch::adaptive_avg_pool3d(
          args[0], {output_size, output_size, output_size});
    };
    ForEachDevice([&](const torch::Device& device) {
      const auto grad_options = torch::TensorOptions(torch::kFloat)
                                    .device(DefaultDevice())
                                    .requires_grad(true);
      torch::Tensor input = torch::rand({1, 56, 28, 28}, grad_options);
      TestBackward({input}, device, pool_fn);
    });
  }
}
10035 
// Backward check for adaptive_avg_pool3d on a batched {4, 1, 56, 28, 28}
// input, for cubic output sizes 7 and 4. Skipped when IsCuda() is true.
TEST_F(LazyOpsTest, TestAdaptiveAvgPool3DBackward) {
  if (IsCuda()) {
    GTEST_SKIP();
  }
  for (int64_t output_size : {7, 4}) {
    auto pool_fn =
        [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
      return torch::adaptive_avg_pool3d(
          args[0], {output_size, output_size, output_size});
    };
    ForEachDevice([&](const torch::Device& device) {
      const auto grad_options = torch::TensorOptions(torch::kFloat)
                                    .device(DefaultDevice())
                                    .requires_grad(true);
      torch::Tensor input = torch::rand({4, 1, 56, 28, 28}, grad_options);
      TestBackward({input}, device, pool_fn);
    });
  }
}
10058 
// Backward check for adaptive_avg_pool2d on a batched {4, 1, 56, 56} input,
// for square output sizes 7 and 8.
TEST_F(LazyOpsTest, TestAdaptiveAvgPool2DBackward) {
  for (int64_t output_size : {7, 8}) {
    auto pool_fn =
        [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
      return torch::adaptive_avg_pool2d(args[0], {output_size, output_size});
    };
    ForEachDevice([&](const torch::Device& device) {
      const auto grad_options = torch::TensorOptions(torch::kFloat)
                                    .device(DefaultDevice())
                                    .requires_grad(true);
      torch::Tensor input = torch::rand({4, 1, 56, 56}, grad_options);
      TestBackward({input}, device, pool_fn);
    });
  }
}
10077 
// Backward check for adaptive_avg_pool2d on an unbatched {1, 56, 56} input,
// for square output sizes 7 and 8.
TEST_F(LazyOpsTest, TestAdaptiveAvgPool2DNoBatchBackward) {
  for (int64_t output_size : {7, 8}) {
    auto testfn =
        [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
      return torch::adaptive_avg_pool2d(inputs[0], {output_size, output_size});
    };
    ForEachDevice([&](const torch::Device& device) {
      // Create the reference input on DefaultDevice(), as the sibling
      // backward tests do; the device was previously left unspecified here.
      TestBackward(
          {torch::rand(
              {1, 56, 56},
              torch::TensorOptions(torch::kFloat)
                  .device(DefaultDevice())
                  .requires_grad(true))},
          device,
          testfn);
    });
  }
}
10094 
// Forward conv2d on double tensors, sweeping stride, padding, bias presence,
// dilation and groups; eager and lazy outputs are compared via AllClose.
TEST_F(LazyOpsTest, TestConv2D) {
  int in_channels = 4;
  int out_channels = 4;
  int kernel_size = 3;
  for (int stride : {1, 2, 3}) {
    for (int padding : {0, 1, 2}) {
      for (bool with_bias : {true, false}) {
        for (int dilation : {1, 2, 3}) {
          // covers normal, grouped, depthwise conv.
          for (int groups : {1, 2, 4}) {
            ForEachDevice([&](const torch::Device& device) {
              const auto options =
                  torch::TensorOptions(torch::kDouble).device(DefaultDevice());
              torch::Tensor input =
                  torch::rand({1, in_channels, 7, 7}, options);
              torch::Tensor weight = torch::rand(
                  {out_channels,
                   in_channels / groups,
                   kernel_size,
                   kernel_size},
                  options);
              // An undefined tensor stands in for "no bias".
              torch::Tensor bias = with_bias
                  ? torch::rand({out_channels}, options)
                  : torch::Tensor();

              torch::Tensor lazy_input = CopyToDevice(input, device);
              torch::Tensor lazy_weight = CopyToDevice(weight, device);
              torch::Tensor lazy_bias =
                  with_bias ? CopyToDevice(bias, device) : torch::Tensor();

              torch::Tensor expected = torch::conv2d(
                  input,
                  weight,
                  bias,
                  /*stride=*/{stride, stride},
                  /*padding=*/{padding, padding},
                  /*dilation=*/{dilation, dilation},
                  groups);
              torch::Tensor actual = torch::conv2d(
                  lazy_input,
                  lazy_weight,
                  lazy_bias,
                  /*stride=*/{stride, stride},
                  /*padding=*/{padding, padding},
                  /*dilation=*/{dilation, dilation},
                  groups);
              AllClose(expected, actual);
            });
          }
        }
      }
    }
  }
}
10151 
// Backward check for conv2d on double tensors, sweeping stride, padding, bias
// presence, dilation and groups. Input and weight require grad; the bias is
// created without requires_grad.
TEST_F(LazyOpsTest, TestConv2DBackward) {
  int in_channels = 4;
  int out_channels = 4;
  int kernel_size = 3;
  for (int stride : {1, 2, 3}) {
    for (int padding : {0, 1, 2}) {
      for (bool with_bias : {true, false}) {
        for (int dilation : {1, 2, 3}) {
          // covers normal, grouped, depthwise conv.
          for (int groups : {1, 2, 4}) {
            auto conv_fn =
                [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
              return torch::conv2d(
                  args[0],
                  args[1],
                  args[2],
                  /*stride=*/{stride, stride},
                  /*padding=*/{padding, padding},
                  /*dilation=*/{dilation, dilation},
                  groups);
            };

            ForEachDevice([&](const torch::Device& device) {
              const auto options =
                  torch::TensorOptions(torch::kDouble).device(DefaultDevice());
              // An undefined tensor stands in for "no bias".
              torch::Tensor bias = with_bias
                  ? torch::rand({out_channels}, options)
                  : torch::Tensor();
              torch::Tensor input = torch::rand(
                  {1, in_channels, 7, 7}, options.requires_grad(true));
              torch::Tensor weight = torch::rand(
                  {out_channels,
                   in_channels / groups,
                   kernel_size,
                   kernel_size},
                  options.requires_grad(true));
              TestBackward({input, weight, bias}, device, conv_fn);
            });
          }
        }
      }
    }
  }
}
10205 
TEST_F(LazyOpsTest,TestTransposedConv2DBackward)10206 TEST_F(LazyOpsTest, TestTransposedConv2DBackward) {
10207   int in_channels = 4;
10208   int out_channels = 4;
10209   int kernel_size = 3;
10210   for (int stride = 1; stride <= 2; ++stride) {
10211     for (int padding = 0; padding <= 1; ++padding) {
10212       for (int dilation = 1; dilation <= 2; ++dilation) {
10213         for (int output_padding = 0;
10214              output_padding < std::max(stride, dilation);
10215              ++output_padding) {
10216           for (bool with_bias : {true, false}) {
10217             for (int groups :
10218                  {1, 2, 4}) { // covers normal, grouped, depthwise conv.
10219               auto testfn = [&](const std::vector<torch::Tensor>& inputs)
10220                   -> torch::Tensor {
10221                 return torch::conv_transpose2d(
10222                     inputs[0],
10223                     inputs[1],
10224                     inputs[2],
10225                     /*stride=*/{stride, stride + 1},
10226                     /*padding=*/{padding, padding + 1},
10227                     /*output_padding=*/output_padding,
10228                     /*groups=*/groups,
10229                     /*dilation=*/{dilation, dilation + 1});
10230               };
10231               ForEachDevice([&](const torch::Device& device) {
10232                 torch::Tensor input = torch::rand(
10233                     {4, out_channels, 7, 7},
10234                     torch::TensorOptions(torch::kFloat)
10235                         .device(DefaultDevice())
10236                         .requires_grad(true));
10237                 torch::Tensor weight = torch::rand(
10238                     {out_channels,
10239                      in_channels / groups,
10240                      kernel_size,
10241                      kernel_size},
10242                     torch::TensorOptions(torch::kFloat)
10243                         .device(DefaultDevice())
10244                         .requires_grad(true));
10245                 torch::Tensor bias = with_bias
10246                     ? torch::rand(
10247                           {in_channels},
10248                           torch::TensorOptions(torch::kFloat)
10249                               .device(DefaultDevice())
10250                               .requires_grad(true))
10251                     : torch::Tensor();
10252                 TestBackward(
10253                     {input, weight, bias},
10254                     device,
10255                     testfn,
10256                     /*rtol=*/1e-5,
10257                     /*atol=*/1e-5);
10258               });
10259             }
10260           };
10261         }
10262       }
10263     }
10264   }
10265 }
10266 
// Backward check for conv3d on double tensors, sweeping stride, padding, bias
// presence, dilation and groups. Input and weight require grad; the bias is
// created without requires_grad.
TEST_F(LazyOpsTest, TestConv3DBackward) {
  int in_channels = 4;
  int out_channels = 4;
  int kernel_size = 3;
  for (int stride : {1, 2, 3}) {
    for (int padding : {1, 2}) {
      for (bool with_bias : {true, false}) {
        for (int dilation : {1, 2}) {
          // covers normal, grouped, depthwise conv.
          for (int groups : {1, 2, 4}) {
            auto conv_fn =
                [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
              return torch::conv3d(
                  args[0],
                  args[1],
                  args[2],
                  /*stride=*/{stride, stride, stride},
                  /*padding=*/{padding, padding, padding},
                  /*dilation=*/{dilation, dilation, dilation},
                  groups);
            };

            ForEachDevice([&](const torch::Device& device) {
              const auto options =
                  torch::TensorOptions(torch::kDouble).device(DefaultDevice());
              // An undefined tensor stands in for "no bias".
              torch::Tensor bias = with_bias
                  ? torch::rand({out_channels}, options)
                  : torch::Tensor();
              torch::Tensor input = torch::rand(
                  {4, in_channels, 7, 7, 7}, options.requires_grad(true));
              torch::Tensor weight = torch::rand(
                  {out_channels,
                   in_channels / groups,
                   kernel_size,
                   kernel_size,
                   kernel_size},
                  options.requires_grad(true));
              TestBackward({input, weight, bias}, device, conv_fn);
            });
          }
        }
      }
    }
  }
}
10321 
TEST_F(LazyOpsTest,TestTransposedConv3DBackward)10322 TEST_F(LazyOpsTest, TestTransposedConv3DBackward) {
10323   int in_channels = 4;
10324   int out_channels = 4;
10325   int kernel_size = 3;
10326   for (int stride = 1; stride <= 2; ++stride) {
10327     for (int padding = 0; padding <= 1; ++padding) {
10328       for (int dilation = 1; dilation <= 2; ++dilation) {
10329         for (int output_padding = 0;
10330              output_padding < std::max(stride, dilation);
10331              ++output_padding) {
10332           for (bool with_bias : {true, false}) {
10333             for (int groups :
10334                  {1, 2, 4}) { // covers normal, grouped, depthwise conv.
10335               auto testfn = [&](const std::vector<torch::Tensor>& inputs)
10336                   -> torch::Tensor {
10337                 return torch::conv_transpose3d(
10338                     inputs[0],
10339                     inputs[1],
10340                     inputs[2],
10341                     /*stride=*/{stride, stride + 1, stride},
10342                     /*padding=*/{padding, padding + 1, stride},
10343                     /*output_padding=*/output_padding,
10344                     /*groups=*/groups,
10345                     /*dilation=*/{dilation, dilation + 1, dilation});
10346               };
10347               ForEachDevice([&](const torch::Device& device) {
10348                 torch::Tensor input = torch::rand(
10349                     {4, out_channels, 7, 7, 7},
10350                     torch::TensorOptions(torch::kDouble)
10351                         .device(DefaultDevice())
10352                         .requires_grad(true));
10353                 torch::Tensor weight = torch::rand(
10354                     {out_channels,
10355                      in_channels / groups,
10356                      kernel_size,
10357                      kernel_size,
10358                      kernel_size},
10359                     torch::TensorOptions(torch::kDouble)
10360                         .device(DefaultDevice())
10361                         .requires_grad(true));
10362                 torch::Tensor bias = with_bias
10363                     ? torch::rand(
10364                           {in_channels},
10365                           torch::TensorOptions(torch::kDouble)
10366                               .device(DefaultDevice())
10367                               .requires_grad(true))
10368                     : torch::Tensor();
10369                 TestBackward({input, weight, bias}, device, testfn);
10370               });
10371             }
10372           };
10373         }
10374       }
10375     }
10376   }
10377 }
10378 
// Runs TestBackward on torch::max_pool2d (3x3 kernel) with a batched
// {1, 2, 8, 8} float input for strides 1..2, paddings 0..1 and both
// ceil_mode settings.
TEST_F(LazyOpsTest, TestMaxPool2DBackward) {
  int kernel_size = 3;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      // Test ceil_mode=true through the CPU interop.
      for (bool ceil_mode : {false, true}) {
        const auto pool_fn = [kernel_size, stride, padding, ceil_mode](
                                 const std::vector<torch::Tensor>& args)
            -> torch::Tensor {
          return torch::max_pool2d(
              args[0],
              /*kernel_size=*/{kernel_size, kernel_size},
              /*stride=*/{stride, stride},
              /*padding=*/{padding, padding},
              /*dilation=*/{1, 1},
              /*ceil_mode=*/ceil_mode);
        };

        ForEachDevice([&](const torch::Device& device) {
          const auto options = torch::TensorOptions(torch::kFloat)
                                   .device(DefaultDevice())
                                   .requires_grad(true);
          torch::Tensor input = torch::rand({1, 2, 8, 8}, options);
          TestBackward({input}, device, pool_fn);
        });
      }
    }
  }
}
10410 
// Runs TestBackward on torch::max_pool3d (3x3x3 kernel) with a batched
// {1, 2, 4, 4, 4} float input for strides 1..2, paddings 0..1 and both
// ceil_mode settings.
TEST_F(LazyOpsTest, TestMaxPool3DBackward) {
  int kernel_size = 3;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      // Test ceil_mode=true through the CPU interop.
      for (bool ceil_mode : {false, true}) {
        const auto pool_fn = [kernel_size, stride, padding, ceil_mode](
                                 const std::vector<torch::Tensor>& args)
            -> torch::Tensor {
          return torch::max_pool3d(
              args[0],
              /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
              /*stride=*/{stride, stride, stride},
              /*padding=*/{padding, padding, padding},
              /*dilation=*/{1, 1, 1},
              /*ceil_mode=*/ceil_mode);
        };

        ForEachDevice([&](const torch::Device& device) {
          const auto options = torch::TensorOptions(torch::kFloat)
                                   .device(DefaultDevice())
                                   .requires_grad(true);
          torch::Tensor input = torch::rand({1, 2, 4, 4, 4}, options);
          TestBackward({input}, device, pool_fn);
        });
      }
    }
  }
}
10442 
// Same sweep as the batched 2D max-pool backward test, but the input has no
// batch dimension: a {2, 8, 8} float tensor.
TEST_F(LazyOpsTest, TestMaxPool2DNoBatchBackward) {
  int kernel_size = 3;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      // Test ceil_mode=true through the CPU interop.
      for (bool ceil_mode : {false, true}) {
        const auto pool_fn = [kernel_size, stride, padding, ceil_mode](
                                 const std::vector<torch::Tensor>& args)
            -> torch::Tensor {
          return torch::max_pool2d(
              args[0],
              /*kernel_size=*/{kernel_size, kernel_size},
              /*stride=*/{stride, stride},
              /*padding=*/{padding, padding},
              /*dilation=*/{1, 1},
              /*ceil_mode=*/ceil_mode);
        };

        ForEachDevice([&](const torch::Device& device) {
          const auto options = torch::TensorOptions(torch::kFloat)
                                   .device(DefaultDevice())
                                   .requires_grad(true);
          torch::Tensor input = torch::rand({2, 8, 8}, options);
          TestBackward({input}, device, pool_fn);
        });
      }
    }
  }
}
10474 
// Same sweep as the batched 3D max-pool backward test, but the input has no
// batch dimension: a {2, 4, 4, 4} float tensor.
TEST_F(LazyOpsTest, TestMaxPool3DNoBatchBackward) {
  int kernel_size = 3;
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      // Test ceil_mode=true through the CPU interop.
      for (bool ceil_mode : {false, true}) {
        const auto pool_fn = [kernel_size, stride, padding, ceil_mode](
                                 const std::vector<torch::Tensor>& args)
            -> torch::Tensor {
          return torch::max_pool3d(
              args[0],
              /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
              /*stride=*/{stride, stride, stride},
              /*padding=*/{padding, padding, padding},
              /*dilation=*/{1, 1, 1},
              /*ceil_mode=*/ceil_mode);
        };

        ForEachDevice([&](const torch::Device& device) {
          const auto options = torch::TensorOptions(torch::kFloat)
                                   .device(DefaultDevice())
                                   .requires_grad(true);
          torch::Tensor input = torch::rand({2, 4, 4, 4}, options);
          TestBackward({input}, device, pool_fn);
        });
      }
    }
  }
}
10506 
// Runs TestBackward on torch::max_unpool2d. The pooled values and indices fed
// to the unpool come from torch::max_pool2d_with_indices applied to one shared
// random {2, 2, 8, 8} input, swept over stride, padding, ceil_mode and
// dilation. The unpool target size is the spatial size of that input.
TEST_F(LazyOpsTest, TestMaxUnpool2DBackward) {
  int kernel_size = 2;
  const auto input_options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({2, 2, 8, 8}, input_options);
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      // Test ceil_mode=true through the CPU interop.
      for (bool ceil_mode : {false, true}) {
        for (int dilation = 1; dilation <= 2; ++dilation) {
          auto pooled = torch::max_pool2d_with_indices(
              input,
              /*kernel_size=*/{kernel_size, kernel_size},
              /*stride=*/{stride, stride},
              /*padding=*/{padding, padding},
              /*dilation=*/{dilation, dilation},
              /*ceil_mode=*/ceil_mode);
          torch::Tensor output = std::get<0>(pooled);
          torch::Tensor indices = std::get<1>(pooled);

          const std::vector<int64_t> output_size{
              input.size(2), input.size(3)};
          const auto unpool_fn =
              [&output_size](
                  const std::vector<torch::Tensor>& args) -> torch::Tensor {
            return torch::max_unpool2d(args[0], args[1], output_size);
          };

          ForEachDevice([&](const torch::Device& device) {
            // Gradient tracking is switched on in place on the pooled values.
            output.requires_grad_(true);
            TestBackward({output, indices}, device, unpool_fn);
          });
        }
      }
    }
  }
}
10542 
// Runs TestBackward on torch::max_unpool3d. The pooled values and indices fed
// to the unpool come from torch::max_pool3d_with_indices applied to one shared
// random {1, 1, 4, 4, 4} input, swept over stride, padding, ceil_mode and
// dilation. The unpool target size is the spatial size of that input.
TEST_F(LazyOpsTest, TestMaxUnpool3DBackward) {
  int kernel_size = 2;
  const auto input_options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::rand({1, 1, 4, 4, 4}, input_options);
  for (int stride = 1; stride <= 2; ++stride) {
    for (int padding = 0; padding <= 1; ++padding) {
      // Test ceil_mode=true through the CPU interop.
      for (bool ceil_mode : {false, true}) {
        for (int dilation = 1; dilation <= 2; ++dilation) {
          auto pooled = torch::max_pool3d_with_indices(
              input,
              /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
              /*stride=*/{stride, stride, stride},
              /*padding=*/{padding, padding, padding},
              /*dilation=*/{dilation, dilation, dilation},
              /*ceil_mode=*/ceil_mode);
          torch::Tensor output = std::get<0>(pooled);
          torch::Tensor indices = std::get<1>(pooled);

          const std::vector<int64_t> output_size{
              input.size(2), input.size(3), input.size(4)};
          const auto unpool_fn =
              [&output_size, stride, padding](
                  const std::vector<torch::Tensor>& args) -> torch::Tensor {
            return torch::max_unpool3d(
                args[0],
                args[1],
                output_size,
                /*stride=*/{stride, stride, stride},
                /*padding=*/{padding, padding, padding});
          };

          ForEachDevice([&](const torch::Device& device) {
            // Gradient tracking is switched on in place on the pooled values.
            output.requires_grad_(true);
            TestBackward({output, indices}, device, unpool_fn);
          });
        }
      }
    }
  }
}
10584 
// Runs TestBackward on torch::tanh with a random 2x2 float input, using
// rtol=1e-3 and atol=1e-5.
TEST_F(LazyOpsTest, TestTanhBackward) {
  const auto tanh_fn =
      [](const std::vector<torch::Tensor>& args) -> torch::Tensor {
    return torch::tanh(args[0]);
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto options = torch::TensorOptions(torch::kFloat)
                             .device(DefaultDevice())
                             .requires_grad(true);
    torch::Tensor input = torch::rand({2, 2}, options);
    TestBackward({input}, device, tanh_fn, /*rtol=*/1e-3, /*atol=*/1e-5);
  });
}
10602 
// Runs TestBackward on torch::sigmoid with a random 2x2 float input and the
// default tolerances of TestBackward.
TEST_F(LazyOpsTest, TestSigmoidBackward) {
  const auto sigmoid_fn =
      [](const std::vector<torch::Tensor>& args) -> torch::Tensor {
    return torch::sigmoid(args[0]);
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto options = torch::TensorOptions(torch::kFloat)
                             .device(DefaultDevice())
                             .requires_grad(true);
    torch::Tensor input = torch::rand({2, 2}, options);
    TestBackward({input}, device, sigmoid_fn);
  });
}
10618 
// Runs TestBackward on torch::log_sigmoid with a random 2x2 float input,
// using rtol=1e-3 and atol=1e-5.
TEST_F(LazyOpsTest, TestLogSigmoidBackward) {
  const auto log_sigmoid_fn =
      [](const std::vector<torch::Tensor>& args) -> torch::Tensor {
    return torch::log_sigmoid(args[0]);
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto options = torch::TensorOptions(torch::kFloat)
                             .device(DefaultDevice())
                             .requires_grad(true);
    torch::Tensor input = torch::rand({2, 2}, options);
    TestBackward(
        {input}, device, log_sigmoid_fn, /*rtol=*/1e-3, /*atol=*/1e-5);
  });
}
10636 
// Runs TestBackward on torch::log_softmax for every dim in [-4, 3] over a
// random {5, 3, 4, 2} float input, using rtol=1e-3 and atol=1e-4.
TEST_F(LazyOpsTest, TestLogSoftmaxBackward) {
  for (int dim = -4; dim < 4; ++dim) {
    const auto log_softmax_fn =
        [dim](const std::vector<torch::Tensor>& args) -> torch::Tensor {
      return torch::log_softmax(args[0], dim);
    };

    ForEachDevice([&](const torch::Device& device) {
      const auto options = torch::TensorOptions(torch::kFloat)
                               .device(DefaultDevice())
                               .requires_grad(true);
      torch::Tensor input = torch::rand({5, 3, 4, 2}, options);
      TestBackward(
          {input}, device, log_softmax_fn, /*rtol=*/1e-3, /*atol=*/1e-4);
    });
  }
}
10658 
// Runs TestBackward on torch::softmax for every dim in [-4, 3] over a random
// {5, 3, 4, 2} float input, using rtol=1e-3 and atol=1e-4.
TEST_F(LazyOpsTest, TestSoftmaxBackward) {
  for (int dim = -4; dim < 4; ++dim) {
    const auto softmax_fn =
        [dim](const std::vector<torch::Tensor>& args) -> torch::Tensor {
      return torch::softmax(args[0], dim);
    };

    ForEachDevice([&](const torch::Device& device) {
      const auto options = torch::TensorOptions(torch::kFloat)
                               .device(DefaultDevice())
                               .requires_grad(true);
      torch::Tensor input = torch::rand({5, 3, 4, 2}, options);
      TestBackward({input}, device, softmax_fn, /*rtol=*/1e-3, /*atol=*/1e-4);
    });
  }
}
10680 
// Runs TestBackward on torch::softplus with a random {2, 1, 4, 6} float
// input; only rtol is overridden (1e-4).
TEST_F(LazyOpsTest, TestSoftplusBackward) {
  const auto softplus_fn =
      [](const std::vector<torch::Tensor>& args) -> torch::Tensor {
    return torch::softplus(args[0]);
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto options = torch::TensorOptions(torch::kFloat)
                             .device(DefaultDevice())
                             .requires_grad(true);
    torch::Tensor input = torch::rand({2, 1, 4, 6}, options);
    TestBackward({input}, device, softplus_fn, /*rtol=*/1e-4);
  });
}
10697 
// Runs TestBackward on torch::relu with a random {2, 1, 4, 6} float input.
TEST_F(LazyOpsTest, TestReluBackward) {
  const auto relu_fn =
      [](const std::vector<torch::Tensor>& args) -> torch::Tensor {
    return torch::relu(args[0]);
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto options = torch::TensorOptions(torch::kFloat)
                             .device(DefaultDevice())
                             .requires_grad(true);
    torch::Tensor input = torch::rand({2, 1, 4, 6}, options);
    TestBackward({input}, device, relu_fn);
  });
}
10713 
// Runs TestBackward on torch::rrelu (called with its default arguments) with
// a random {2, 1, 4, 6} float input.
TEST_F(LazyOpsTest, TestRreluBackward) {
  const auto rrelu_fn =
      [](const std::vector<torch::Tensor>& args) -> torch::Tensor {
    return torch::rrelu(args[0]);
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto options = torch::TensorOptions(torch::kFloat)
                             .device(DefaultDevice())
                             .requires_grad(true);
    torch::Tensor input = torch::rand({2, 1, 4, 6}, options);
    TestBackward({input}, device, rrelu_fn);
  });
}
10729 
// Runs TestBackward on torch::hardshrink with a 100-element float input drawn
// from torch::randn.
TEST_F(LazyOpsTest, TestHardshrinkBackward) {
  const auto hardshrink_fn =
      [](const std::vector<torch::Tensor>& args) -> torch::Tensor {
    return torch::hardshrink(args[0]);
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto options = torch::TensorOptions(torch::kFloat)
                             .device(DefaultDevice())
                             .requires_grad(true);
    torch::Tensor input = torch::randn({100}, options);
    TestBackward({input}, device, hardshrink_fn);
  });
}
10745 
// Runs TestBackward on torch::softshrink with a 100-element float input drawn
// from torch::randn.
TEST_F(LazyOpsTest, TestSoftshrinkBackward) {
  const auto softshrink_fn =
      [](const std::vector<torch::Tensor>& args) -> torch::Tensor {
    return torch::softshrink(args[0]);
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto options = torch::TensorOptions(torch::kFloat)
                             .device(DefaultDevice())
                             .requires_grad(true);
    torch::Tensor input = torch::randn({100}, options);
    TestBackward({input}, device, softshrink_fn);
  });
}
10761 
// Runs TestBackward on torch::hardtanh with a 100-element float input drawn
// from torch::randn.
TEST_F(LazyOpsTest, TestHardtanhBackward) {
  const auto hardtanh_fn =
      [](const std::vector<torch::Tensor>& args) -> torch::Tensor {
    return torch::hardtanh(args[0]);
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto options = torch::TensorOptions(torch::kFloat)
                             .device(DefaultDevice())
                             .requires_grad(true);
    torch::Tensor input = torch::randn({100}, options);
    TestBackward({input}, device, hardtanh_fn);
  });
}
10777 
// Runs TestBackward on torch::elu with alpha=0.5, scale=2.5 and
// input_scale=1.5 over a random {2, 1, 4, 6} float input.
TEST_F(LazyOpsTest, TestEluBackward) {
  torch::Scalar alpha = 0.5;
  torch::Scalar scale = 2.5;
  torch::Scalar input_scale = 1.5;
  const auto elu_fn = [alpha, scale, input_scale](
                          const std::vector<torch::Tensor>& args)
      -> torch::Tensor {
    return torch::elu(args[0], alpha, scale, input_scale);
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto options = torch::TensorOptions(torch::kFloat)
                             .device(DefaultDevice())
                             .requires_grad(true);
    torch::Tensor input = torch::rand({2, 1, 4, 6}, options);
    TestBackward({input}, device, elu_fn);
  });
}
10796 
// Runs TestBackward on torch::gelu with a random 2x3 float input, then calls
// ExpectCounterChanged for the "lazy::gelu_backward" counter pattern.
TEST_F(LazyOpsTest, TestGeluBackward) {
  const auto gelu_fn =
      [](const std::vector<torch::Tensor>& args) -> torch::Tensor {
    return torch::gelu(args[0]);
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto options = torch::TensorOptions(torch::kFloat)
                             .device(DefaultDevice())
                             .requires_grad(true);
    torch::Tensor input = torch::rand({2, 3}, options);
    TestBackward({input}, device, gelu_fn);
  });
  ExpectCounterChanged("lazy::gelu_backward", GetIgnoredCounters());
}
10813 
// Runs TestBackward on torch::leaky_relu with negative_slope=0.01 over a
// random {2, 1, 4, 6} float input.
TEST_F(LazyOpsTest, TestLeakyReluBackward) {
  double negative_slope = 0.01;
  const auto leaky_relu_fn = [negative_slope](
                                 const std::vector<torch::Tensor>& args)
      -> torch::Tensor {
    return torch::leaky_relu(args[0], negative_slope);
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto options = torch::TensorOptions(torch::kFloat)
                             .device(DefaultDevice())
                             .requires_grad(true);
    torch::Tensor input = torch::rand({2, 1, 4, 6}, options);
    TestBackward({input}, device, leaky_relu_fn);
  });
}
10830 
// Runs TestBackward on torch::t with a random 2x3 float input.
TEST_F(LazyOpsTest, TestTransposeBackward) {
  const auto transpose_fn =
      [](const std::vector<torch::Tensor>& args) -> torch::Tensor {
    return torch::t(args[0]);
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto options = torch::TensorOptions(torch::kFloat)
                             .device(DefaultDevice())
                             .requires_grad(true);
    torch::Tensor input = torch::rand({2, 3}, options);
    TestBackward({input}, device, transpose_fn);
  });
}
10846 
// Runs TestBackward on torch::addmm for beta in {1, 2}. The operands are a
// {labels} addend, an {in_channels, out_channels} matrix and an
// {out_channels, labels} matrix, all random floats that require grad.
TEST_F(LazyOpsTest, TestAddMatMulBackward) {
  int in_channels = 32;
  int out_channels = 320;
  int labels = 50;
  // Test beta != 1. through the CPU interop.
  for (double beta : {1., 2.}) {
    const auto addmm_fn =
        [beta](const std::vector<torch::Tensor>& args) -> torch::Tensor {
      return torch::addmm(args[0], args[1], args[2], /*beta=*/beta);
    };
    ForEachDevice([&](const torch::Device& device) {
      const auto options = torch::TensorOptions(torch::kFloat)
                               .device(DefaultDevice())
                               .requires_grad(true);
      // Created in the same order as they are passed to TestBackward.
      torch::Tensor addend = torch::rand({labels}, options);
      torch::Tensor lhs = torch::rand({in_channels, out_channels}, options);
      torch::Tensor rhs = torch::rand({out_channels, labels}, options);
      TestBackward({addend, lhs, rhs}, device, addmm_fn);
    });
  }
}
10879 
// Runs TestBackward on torch::binary_cross_entropy for every reduction mode,
// with and without an explicit weight tensor. Only the input requires grad.
// Tensors are created on DefaultDevice(), matching the other loss tests in
// this file.
TEST_F(LazyOpsTest, TestBinaryCrossEntropyBackward) {
  int batch = 6;
  int classes = 2;
  // TODO(asuhan): Fix the torch::kDouble case.
  for (auto dtype : {torch::kFloat}) {
    for (bool def_weight : {false, true}) {
      torch::Tensor input = torch::rand(
          {batch, classes},
          torch::TensorOptions(dtype)
              .device(DefaultDevice())
              .requires_grad(true));
      torch::Tensor target = torch::rand(
          {batch, classes},
          torch::TensorOptions(dtype).device(DefaultDevice()));
      // Left undefined when def_weight is false.
      torch::Tensor weight;
      if (def_weight) {
        weight = torch::rand(
            {batch, classes},
            torch::TensorOptions(dtype).device(DefaultDevice()));
      }
      for (torch::Reduction::Reduction reduction :
           {torch::Reduction::Mean,
            torch::Reduction::Sum,
            torch::Reduction::None}) {
        auto testfn =
            [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
          return torch::binary_cross_entropy(
              /*self=*/inputs[0],
              /*target=*/inputs[1],
              /*weight=*/inputs[2],
              /*reduction=*/reduction);
        };
        ForEachDevice([&](const torch::Device& device) {
          TestBackward(
              {input, target, weight},
              device,
              testfn,
              /*rtol=*/1e-4,
              /*atol=*/1e-7);
        });
      }
    }
  }
}
10918 
// Currently skipped via GTEST_SKIP() (see the TODO below). The body runs
// TestBackward on torch::nll_loss over ignore_index values, optional class
// weights and every reduction mode.
TEST_F(LazyOpsTest, TestNllLossBackward) {
  // TODO(whc) debug divide-by-zero failure under ASAN
  GTEST_SKIP();

  int batch = 6;
  int classes = 2;
  // TODO(asuhan): Fix the torch::kDouble case.
  for (auto dtype : {torch::kFloat}) {
    const auto grad_options = torch::TensorOptions(dtype)
                                  .device(DefaultDevice())
                                  .requires_grad(true);
    const auto plain_options =
        torch::TensorOptions(dtype).device(DefaultDevice());
    const auto label_options =
        torch::TensorOptions(torch::kLong).device(DefaultDevice());
    for (int ignore_index : {-1, 0, 1, 5}) {
      for (bool def_weight : {false, true}) {
        torch::Tensor input = torch::rand({batch, classes}, grad_options);
        // The lower bound drops to ignore_index when that is negative.
        torch::Tensor target = torch::randint(
            std::min(ignore_index, 0), classes, {batch}, label_options);
        // Left undefined when def_weight is false.
        torch::Tensor weight;
        if (def_weight) {
          weight = torch::rand({classes}, plain_options);
        }
        for (torch::Reduction::Reduction reduction :
             {torch::Reduction::Mean,
              torch::Reduction::Sum,
              torch::Reduction::None}) {
          const auto nll_fn = [reduction, ignore_index](
                                  const std::vector<torch::Tensor>& args)
              -> torch::Tensor {
            return torch::nll_loss(
                /*self=*/args[0],
                /*target=*/args[1],
                /*weight=*/args[2],
                /*reduction=*/reduction,
                /*ignore_index=*/ignore_index);
          };
          ForEachDevice([&](const torch::Device& device) {
            TestBackward(
                {input, target, weight},
                device,
                nll_fn,
                /*rtol=*/1e-5,
                /*atol=*/1e-8);
          });
        }
      }
    }
  }
}
10970 
// Runs TestBackward on torch::nll_loss2d with a {batch, classes, height,
// width} input, sweeping ignore_index, optional class weights and every
// reduction mode.
TEST_F(LazyOpsTest, TestNllLoss2dBackward) {
  int batch = 6;
  int classes = 2;
  int height = 3;
  int width = 3;
  // TODO(asuhan): Fix the torch::kDouble case.
  for (auto dtype : {torch::kFloat}) {
    const auto grad_options = torch::TensorOptions(dtype)
                                  .device(DefaultDevice())
                                  .requires_grad(true);
    const auto plain_options =
        torch::TensorOptions(dtype).device(DefaultDevice());
    const auto label_options =
        torch::TensorOptions(torch::kLong).device(DefaultDevice());
    for (int ignore_index : {-1, 0, 1, 5}) {
      for (bool def_weight : {false, true}) {
        torch::Tensor input =
            torch::rand({batch, classes, height, width}, grad_options);
        // The lower bound drops to ignore_index when that is negative.
        torch::Tensor target = torch::randint(
            std::min(ignore_index, 0),
            classes,
            {batch, height, width},
            label_options);
        // Left undefined when def_weight is false.
        torch::Tensor weight;
        if (def_weight) {
          weight = torch::rand({classes}, plain_options);
        }
        for (torch::Reduction::Reduction reduction :
             {torch::Reduction::Mean,
              torch::Reduction::Sum,
              torch::Reduction::None}) {
          const auto nll2d_fn = [reduction, ignore_index](
                                    const std::vector<torch::Tensor>& args)
              -> torch::Tensor {
            return torch::nll_loss2d(
                /*self=*/args[0],
                /*target=*/args[1],
                /*weight=*/args[2],
                /*reduction=*/reduction,
                /*ignore_index=*/ignore_index);
          };
          ForEachDevice([&](const torch::Device& device) {
            TestBackward(
                {input, target, weight},
                device,
                nll2d_fn,
                /*rtol=*/1e-5,
                /*atol=*/1e-8);
          });
        }
      }
    }
  }
}
11021 
// Runs TestBackward on torch::smooth_l1_loss for every reduction mode and
// beta in {0.25, 1}, reusing one random 2x4 input/target pair. Only the input
// requires grad.
TEST_F(LazyOpsTest, TestSmoothL1LossBackward) {
  const auto grad_options = torch::TensorOptions(torch::kFloat)
                                .device(DefaultDevice())
                                .requires_grad(true);
  const auto plain_options =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor input = torch::randn({2, 4}, grad_options);
  torch::Tensor target = torch::randn({2, 4}, plain_options);
  for (torch::Reduction::Reduction reduction :
       {torch::Reduction::None,
        torch::Reduction::Mean,
        torch::Reduction::Sum}) {
    for (double beta : {0.25, 1.}) {
      const auto loss_fn = [reduction, beta](
                               const std::vector<torch::Tensor>& args)
          -> torch::Tensor {
        return torch::smooth_l1_loss(
            /*input=*/args[0],
            /*target=*/args[1],
            /*reduction=*/reduction,
            /*beta=*/beta);
      };
      ForEachDevice([&](const torch::Device& device) {
        TestBackward(
            {input, target}, device, loss_fn, /*rtol=*/1e-5, /*atol=*/1e-8);
      });
    }
  }
}
11054 
// Runs TestBackward on Tensor::view({-1, 320}) applied to a random
// {32, 20, 4, 4} float input.
TEST_F(LazyOpsTest, TestViewBackward) {
  const auto view_fn =
      [](const std::vector<torch::Tensor>& args) -> torch::Tensor {
    return args[0].view({-1, 320});
  };
  ForEachDevice([&](const torch::Device& device) {
    const auto options = torch::TensorOptions(torch::kFloat)
                             .device(DefaultDevice())
                             .requires_grad(true);
    torch::Tensor input = torch::rand({32, 20, 4, 4}, options);
    TestBackward({input}, device, view_fn);
  });
}
11070 
// Runs TestBackward on torch::batch_norm in training mode (momentum=0.1,
// eps=0.5, cudnn disabled) with a {2, 3, 4, 4} float input, once with random
// weight/bias and once with both left undefined. Uses rtol=1e-3, atol=1e-4.
TEST_F(LazyOpsTest, TestBatchNorm2DBackward) {
  double momentum = 0.1;
  double eps = 0.5;
  const auto batch_norm_fn = [momentum, eps](
                                 const std::vector<torch::Tensor>& args)
      -> torch::Tensor {
    return torch::batch_norm(
        /*input=*/args[0],
        /*weight=*/args[1],
        /*bias=*/args[2],
        /*running_mean=*/args[3],
        /*running_var=*/args[4],
        /*training=*/true,
        /*momentum=*/momentum,
        /*eps=*/eps,
        /*cudnn_enabled=*/false);
  };
  int num_features = 3;
  for (bool undef_weight_bias : {false, true}) {
    ForEachDevice([&](const torch::Device& device) {
      const auto grad_options = torch::TensorOptions(torch::kFloat)
                                    .device(DefaultDevice())
                                    .requires_grad(true);
      const auto stat_options =
          torch::TensorOptions(torch::kFloat).device(DefaultDevice());
      torch::Tensor input =
          torch::rand({2, num_features, 4, 4}, grad_options);
      // Default-constructed tensors are undefined; they are only filled in
      // when weight and bias are requested.
      torch::Tensor weight;
      torch::Tensor bias;
      if (!undef_weight_bias) {
        weight = torch::rand({num_features}, grad_options);
        bias = torch::rand({num_features}, grad_options);
      }
      torch::Tensor running_mean = torch::zeros({num_features}, stat_options);
      torch::Tensor running_var = torch::ones({num_features}, stat_options);
      TestBackward(
          {input, weight, bias, running_mean, running_var},
          device,
          batch_norm_fn,
          /*rtol=*/1e-3,
          /*atol=*/1e-4);
    });
  }
}
11124 
// Runs TestBackward on torch::batch_norm in training mode (momentum=0.1,
// eps=0.5, cudnn disabled) with a {2, 3, 4, 4, 2} float input, once with
// random weight/bias and once with both left undefined. Uses rtol=1e-3,
// atol=1e-3.
TEST_F(LazyOpsTest, TestBatchNorm3DBackward) {
  double momentum = 0.1;
  double eps = 0.5;
  const auto batch_norm_fn = [momentum, eps](
                                 const std::vector<torch::Tensor>& args)
      -> torch::Tensor {
    return torch::batch_norm(
        /*input=*/args[0],
        /*weight=*/args[1],
        /*bias=*/args[2],
        /*running_mean=*/args[3],
        /*running_var=*/args[4],
        /*training=*/true,
        /*momentum=*/momentum,
        /*eps=*/eps,
        /*cudnn_enabled=*/false);
  };
  int num_features = 3;
  for (bool undef_weight_bias : {false, true}) {
    ForEachDevice([&](const torch::Device& device) {
      const auto grad_options = torch::TensorOptions(torch::kFloat)
                                    .device(DefaultDevice())
                                    .requires_grad(true);
      const auto stat_options =
          torch::TensorOptions(torch::kFloat).device(DefaultDevice());
      torch::Tensor input =
          torch::rand({2, num_features, 4, 4, 2}, grad_options);
      // Default-constructed tensors are undefined; they are only filled in
      // when weight and bias are requested.
      torch::Tensor weight;
      torch::Tensor bias;
      if (!undef_weight_bias) {
        weight = torch::rand({num_features}, grad_options);
        bias = torch::rand({num_features}, grad_options);
      }
      torch::Tensor running_mean = torch::zeros({num_features}, stat_options);
      torch::Tensor running_var = torch::ones({num_features}, stat_options);
      TestBackward(
          {input, weight, bias, running_mean, running_var},
          device,
          batch_norm_fn,
          /*rtol=*/1e-3,
          /*atol=*/1e-3);
    });
  }
}
11178 
// Backward check for torch::binary_cross_entropy_with_logits on random
// {batch, classes} inputs. Covers the None/Mean/Sum reductions crossed with
// every combination of defined/undefined `weight` and `pos_weight`.
TEST_F(LazyOpsTest, TestBCEWithLogitsBackward) {
  const int batch = 10;
  const int classes = 5;
  torch::Tensor undefined;

  // Float tensor options on the default device; rebuilt on every call.
  auto float_opts = [&]() {
    return torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  };

  for (const torch::Reduction::Reduction reduction :
       {torch::Reduction::None,
        torch::Reduction::Mean,
        torch::Reduction::Sum}) {
    // Function under test. `args` is {input, target, weight, pos_weight}.
    auto bce_fn =
        [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
      return torch::binary_cross_entropy_with_logits(
          /*input=*/args[0],
          /*target=*/args[1],
          /*weight=*/args[2],
          /*pos_weight=*/args[3],
          /*reduction=*/reduction);
    };

    for (const bool no_weight : {false, true}) {
      for (const bool no_pos_weight : {false, true}) {
        // input and target require grad; the two weights do not.
        torch::Tensor input =
            torch::rand({batch, classes}, float_opts().requires_grad(true));
        torch::Tensor target =
            torch::rand({batch, classes}, float_opts().requires_grad(true));
        torch::Tensor weight =
            no_weight ? undefined : torch::rand({classes}, float_opts());
        torch::Tensor pos_weight =
            no_pos_weight ? undefined : torch::rand({classes}, float_opts());

        ForEachDevice([&](const torch::Device& device) {
          TestBackward(
              {input, target, weight, pos_weight},
              device,
              bce_fn,
              /*rtol=*/1e-3,
              /*atol=*/1e-5);
        });
      }
    }
  }
}
11230 
// Backward check for torch::kl_div on a pair of random {4, 3} float tensors
// (both requiring grad), repeated for the Mean, Sum and None reductions.
TEST_F(LazyOpsTest, TestKlDivBackward) {
  // Float options on the default device with requires_grad enabled.
  auto grad_float_opts = [&]() {
    return torch::TensorOptions(torch::kFloat)
        .device(DefaultDevice())
        .requires_grad(true);
  };
  torch::Tensor input = torch::rand({4, 3}, grad_float_opts());
  torch::Tensor target = torch::rand({4, 3}, grad_float_opts());

  for (const torch::Reduction::Reduction reduction :
       {torch::Reduction::Mean,
        torch::Reduction::Sum,
        torch::Reduction::None}) {
    // Function under test. `args` is {self, target}.
    auto kl_div_fn =
        [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
      return torch::kl_div(/*self=*/args[0], /*target=*/args[1], reduction);
    };
    ForEachDevice([&](const torch::Device& device) {
      TestBackward(
          {input, target},
          device,
          kl_div_fn,
          /*rtol=*/1e-4,
          /*atol=*/1e-5);
    });
  }
}
11260 
// Backward check for torch::embedding (dense gradients, sparse=false).
// Sweeps padding_idx from -1 up to num_weights - 1 and both settings of
// scale_grad_by_freq; the weight table and indices are regenerated for every
// device visit.
TEST_F(LazyOpsTest, TestEmbeddingBackward) {
  const int num_weights = 32;
  for (int padding_idx = -1; padding_idx < num_weights; ++padding_idx) {
    for (const bool scale_by_freq : {false, true}) {
      // Function under test. `args` is {weight, indices}.
      auto embedding_fn =
          [&](const std::vector<torch::Tensor>& args) -> torch::Tensor {
        return torch::embedding(
            args[0],
            args[1],
            /*padding_idx=*/padding_idx,
            /*scale_grad_by_freq=*/scale_by_freq,
            /*sparse=*/false);
      };

      ForEachDevice([&](const torch::Device& device) {
        // {num_weights, 7} float table that requires grad.
        torch::Tensor weight = torch::rand(
            {num_weights, 7},
            torch::TensorOptions(torch::kFloat)
                .device(DefaultDevice())
                .requires_grad(true));
        // {3, 9, 4} long indices drawn with num_weights as the upper bound.
        torch::Tensor indices = torch::randint(
            num_weights,
            {3, 9, 4},
            torch::TensorOptions(torch::kLong).device(DefaultDevice()));
        TestBackward(
            {weight, indices},
            device,
            embedding_fn,
            /*rtol=*/1e-5,
            /*atol=*/1e-8);
      });
    }
  }
}
11294 
// Exercises torch::_amp_foreach_non_finite_check_and_unscale_ twice on copies
// placed on the device under test:
//   1. all-finite grads: expects grads == grads * inv_scale and found_inf == 0;
//   2. grads containing a NaN: expects found_inf == 1.
TEST_F(LazyOpsTest, TestAmpForeachNonFiniteCheckAndUnscale) {
  if (IsCuda()) {
    // TODO(whc) debug failure on cuda
    GTEST_SKIP();
  }

  // Float tensor options on the default device; rebuilt on every call.
  auto float_opts = [&]() {
    return torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  };

  torch::Tensor finite_grads = torch::tensor({1, 2, 3, 4}, float_opts());
  torch::Tensor nan_grads =
      torch::tensor({1.0, 2.0, std::nan("1"), 4.0}, float_opts());
  torch::Tensor inv_scale = torch::scalar_tensor(0.2, float_opts());
  torch::Tensor found_inf = torch::scalar_tensor(0, float_opts());

  // Reference values the device results are compared against.
  torch::Tensor expected_unscaled = finite_grads * inv_scale;
  torch::Tensor expected_flag_clear = torch::scalar_tensor(0, float_opts());
  torch::Tensor expected_flag_set = torch::scalar_tensor(1, float_opts());

  ForEachDevice([&](const torch::Device& device) {
    // NOTE(review): this inspects the reference tensor's device, not `device`;
    // looks intentional upstream but worth confirming.
    if (finite_grads.device() == at::kCPU) {
      GTEST_SKIP();
    }

    torch::Tensor lazy_finite_grads = CopyToDevice(finite_grads, device);
    torch::Tensor lazy_inv_scale = CopyToDevice(inv_scale, device);
    torch::Tensor lazy_found_inf = CopyToDevice(found_inf, device);

    // Pass 1: finite input.
    torch::_amp_foreach_non_finite_check_and_unscale_(
        lazy_finite_grads, lazy_found_inf, lazy_inv_scale);
    AllClose(
        expected_unscaled, lazy_finite_grads, /*rtol=*/1e-2, /*atol=*/1e-4);
    AllEqual(expected_flag_clear, lazy_found_inf);

    // Pass 2: input with a NaN, reusing the same found_inf tensor.
    torch::Tensor lazy_nan_grads = CopyToDevice(nan_grads, device);
    torch::_amp_foreach_non_finite_check_and_unscale_(
        lazy_nan_grads, lazy_found_inf, lazy_inv_scale);
    AllEqual(expected_flag_set, lazy_found_inf);
  });
}
11334 
TEST_F(LazyOpsTest,TestAmpUpdateScale)11335 TEST_F(LazyOpsTest, TestAmpUpdateScale) {
11336   torch::Tensor growth_tracker = torch::scalar_tensor(
11337       0, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
11338   torch::Tensor current_scale = torch::scalar_tensor(
11339       4, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11340   torch::Tensor found_inf = torch::scalar_tensor(
11341       1, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11342   torch::Tensor not_found_inf = torch::scalar_tensor(
11343       0, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11344   float scale_growth_factor = 2.0;
11345   float scale_backoff_factor = 0.5;
11346   int growth_interval = 3;
11347 
11348   torch::Tensor growth_tracker_result0 = torch::scalar_tensor(
11349       1, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
11350   torch::Tensor current_scale_result0 = torch::scalar_tensor(
11351       4, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11352   torch::Tensor growth_tracker_result1 = torch::scalar_tensor(
11353       2, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
11354   torch::Tensor current_scale_result1 = torch::scalar_tensor(
11355       4, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11356   torch::Tensor growth_tracker_result2 = torch::scalar_tensor(
11357       0, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
11358   torch::Tensor current_scale_result2 = torch::scalar_tensor(
11359       8, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11360   torch::Tensor growth_tracker_result3 = torch::scalar_tensor(
11361       0, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
11362   torch::Tensor current_scale_result3 = torch::scalar_tensor(
11363       4, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11364 
11365   ForEachDevice([&](const torch::Device& device) {
11366     if (growth_tracker.device() == at::kCPU) {
11367       GTEST_SKIP();
11368     }
11369     torch::Tensor lazy_growth_tracker = CopyToDevice(growth_tracker, device);
11370     torch::Tensor lazy_current_scale = CopyToDevice(current_scale, device);
11371     torch::Tensor lazy_found_inf = CopyToDevice(found_inf, device);
11372     torch::Tensor lazy_not_found_inf = CopyToDevice(not_found_inf, device);
11373 
11374     torch::_amp_update_scale_(
11375         lazy_current_scale,
11376         lazy_growth_tracker,
11377         lazy_not_found_inf,
11378         scale_growth_factor,
11379         scale_backoff_factor,
11380         growth_interval);
11381     AllClose(
11382         current_scale_result0,
11383         lazy_current_scale,
11384         /*rtol=*/1e-2,
11385         /*atol=*/1e-4);
11386     AllEqual(growth_tracker_result0, lazy_growth_tracker);
11387 
11388     torch::_amp_update_scale_(
11389         lazy_current_scale,
11390         lazy_growth_tracker,
11391         lazy_not_found_inf,
11392         scale_growth_factor,
11393         scale_backoff_factor,
11394         growth_interval);
11395     AllClose(
11396         current_scale_result1,
11397         lazy_current_scale,
11398         /*rtol=*/1e-2,
11399         /*atol=*/1e-4);
11400     AllEqual(growth_tracker_result1, lazy_growth_tracker);
11401 
11402     // torch::_amp_update_scale_ returns the reference of current_scale
11403     lazy_current_scale = torch::_amp_update_scale_(
11404         lazy_current_scale,
11405         lazy_growth_tracker,
11406         lazy_not_found_inf,
11407         scale_growth_factor,
11408         scale_backoff_factor,
11409         growth_interval);
11410     AllClose(
11411         current_scale_result2,
11412         lazy_current_scale,
11413         /*rtol=*/1e-2,
11414         /*atol=*/1e-4);
11415     AllEqual(growth_tracker_result2, lazy_growth_tracker);
11416 
11417     lazy_current_scale = torch::_amp_update_scale_(
11418         lazy_current_scale,
11419         lazy_growth_tracker,
11420         lazy_found_inf,
11421         scale_growth_factor,
11422         scale_backoff_factor,
11423         growth_interval);
11424     AllClose(
11425         current_scale_result3,
11426         lazy_current_scale,
11427         /*rtol=*/1e-2,
11428         /*atol=*/1e-4);
11429     AllEqual(growth_tracker_result3, lazy_growth_tracker);
11430   });
11431   ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
11432   ExpectCounterChanged("lazy::_amp_update_scale_", GetIgnoredCounters());
11433 }
11434 
// Reads a scalar back with item() from both the reference tensor and its
// per-device copy and requires the two float values to be equal. Afterwards
// the EarlySyncLiveTensorsCount counter is expected to have changed only when
// the "early_sync" experiment is enabled, and aten::_local_scalar_dense is
// expected to have changed in either case.
TEST_F(LazyOpsTest, TestEarlySyncLiveTensors) {
  const auto float_opts =
      torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  torch::Tensor scalar_tensor = torch::scalar_tensor(1., float_opts);
  torch::Scalar scalar1 = scalar_tensor.item();

  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_scalar_tensor = CopyToDevice(scalar_tensor, device);
    torch::Scalar scalar2 = lazy_scalar_tensor.item();
    ASSERT_EQ(scalar1.to<float>(), scalar2.to<float>());
  });

  // Queried only after the device loop has run, as in the original ordering.
  const bool early_sync_enabled = DebugUtil::ExperimentEnabled("early_sync");
  if (!early_sync_enabled) {
    ExpectCounterNotChanged("EarlySyncLiveTensorsCount", GetIgnoredCounters());
  } else {
    ExpectCounterChanged("EarlySyncLiveTensorsCount", GetIgnoredCounters());
  }
  ExpectCounterChanged("aten::_local_scalar_dense", GetIgnoredCounters());
}
11451 
// torch::lerp with a tensor weight: the result computed from per-device copies
// must be close to the reference computed from the original {3, 4} tensors.
TEST_F(LazyOpsTest, TestLerp) {
  // Float tensor options on the default device; rebuilt on every call.
  auto float_opts = [&]() {
    return torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  };
  torch::Tensor start = torch::rand({3, 4}, float_opts());
  torch::Tensor end = torch::rand({3, 4}, float_opts());
  torch::Tensor weight = torch::rand({3, 4}, float_opts());
  torch::Tensor expected = torch::lerp(start, end, weight);

  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_start = CopyToDevice(start, device);
    torch::Tensor lazy_end = CopyToDevice(end, device);
    torch::Tensor lazy_weight = CopyToDevice(weight, device);
    torch::Tensor actual = torch::lerp(lazy_start, lazy_end, lazy_weight);
    AllClose(expected, actual);
  });

  // Counter expectations: nothing matching aten::.*, but lazy::lerp.
  ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
  ExpectCounterChanged("lazy::lerp", GetIgnoredCounters());
}
11470 
// torch::lerp with a scalar weight (3.0): the result computed from per-device
// copies must be close to the reference computed from the original tensors.
TEST_F(LazyOpsTest, TestLerpScalar) {
  // Float tensor options on the default device; rebuilt on every call.
  auto float_opts = [&]() {
    return torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  };
  torch::Tensor start = torch::rand({3, 4}, float_opts());
  torch::Tensor end = torch::rand({3, 4}, float_opts());
  const torch::Scalar weight(3.0);
  torch::Tensor expected = torch::lerp(start, end, weight);

  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_start = CopyToDevice(start, device);
    torch::Tensor lazy_end = CopyToDevice(end, device);
    torch::Tensor actual = torch::lerp(lazy_start, lazy_end, weight);
    AllClose(expected, actual);
  });

  // Counter expectations: nothing matching aten::.*, but lazy::lerp.
  ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
  ExpectCounterChanged("lazy::lerp", GetIgnoredCounters());
}
11487 
// In-place Tensor::lerp_ with a tensor weight. A clone of the input is taken
// before the reference tensor is mutated; each device copy starts from that
// pristine clone and, after its own lerp_, is compared with the mutated
// reference.
TEST_F(LazyOpsTest, TestLerpInplace) {
  // Float tensor options on the default device; rebuilt on every call.
  auto float_opts = [&]() {
    return torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  };
  torch::Tensor input = torch::rand({3, 4}, float_opts());
  torch::Tensor end = torch::rand({3, 4}, float_opts());
  torch::Tensor weight = torch::rand({3, 4}, float_opts());

  torch::Tensor pristine_input = input.clone();
  input.lerp_(end, weight);

  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(pristine_input, device);
    torch::Tensor lazy_end = CopyToDevice(end, device);
    torch::Tensor lazy_weight = CopyToDevice(weight, device);
    lazy_input.lerp_(lazy_end, lazy_weight);
    AllClose(lazy_input, input);
  });

  // Counter expectations: nothing matching aten::.*, but lazy::lerp.
  ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
  ExpectCounterChanged("lazy::lerp", GetIgnoredCounters());
}
11507 
// In-place Tensor::lerp_ with a scalar weight (3.0). A clone of the input is
// taken before the reference tensor is mutated; each device copy starts from
// that clone and, after its own lerp_, is compared with the mutated reference.
TEST_F(LazyOpsTest, TestLerpScalarInplace) {
  // Float tensor options on the default device; rebuilt on every call.
  auto float_opts = [&]() {
    return torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  };
  torch::Tensor input = torch::rand({3, 4}, float_opts());
  torch::Tensor end = torch::rand({3, 4}, float_opts());
  const torch::Scalar weight(3.0);

  torch::Tensor pristine_input = input.clone();
  input.lerp_(end, weight);

  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_input = CopyToDevice(pristine_input, device);
    torch::Tensor lazy_end = CopyToDevice(end, device);
    lazy_input.lerp_(lazy_end, weight);
    AllClose(lazy_input, input);
  });

  // Counter expectations: nothing matching aten::.*, but lazy::lerp.
  ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
  ExpectCounterChanged("lazy::lerp", GetIgnoredCounters());
}
11525 
// torch::lerp_out with a tensor weight: the reference result is written into a
// pre-allocated {3, 4} tensor on the default device; on each device the op
// writes into a tensor allocated with the device copy's options, and the two
// outputs must be close.
TEST_F(LazyOpsTest, TestLerpOut) {
  torch::Tensor start = torch::rand(
      {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor end = torch::rand(
      {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::Tensor weight = torch::rand(
      {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  // Output buffer for the reference computation (the stray empty statement
  // that used to follow this declaration has been dropped).
  torch::Tensor res = torch::empty(
      {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
  torch::lerp_out(res, start, end, weight);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_start = CopyToDevice(start, device);
    torch::Tensor lazy_end = CopyToDevice(end, device);
    torch::Tensor lazy_weight = CopyToDevice(weight, device);
    // Output buffer allocated with the same options as the device copy.
    torch::Tensor lazy_res = torch::empty({3, 4}, lazy_start.options());
    torch::lerp_out(lazy_res, lazy_start, lazy_end, lazy_weight);
    AllClose(res, lazy_res);
  });
  // Counter expectations: nothing matching aten::.*, but lazy::lerp.
  ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
  ExpectCounterChanged("lazy::lerp", GetIgnoredCounters());
}
11548 
// torch::lerp_out with a scalar weight (3.0): the reference result is written
// into a pre-allocated {3, 4} tensor; on each device the op writes into a
// tensor allocated with the device copy's options, and the two outputs must
// be close.
TEST_F(LazyOpsTest, TestLerpScalarOut) {
  // Float tensor options on the default device; rebuilt on every call.
  auto float_opts = [&]() {
    return torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  };
  torch::Tensor start = torch::rand({3, 4}, float_opts());
  torch::Tensor end = torch::rand({3, 4}, float_opts());
  const torch::Scalar weight(3.0);
  torch::Tensor expected = torch::empty({3, 4}, float_opts());
  torch::lerp_out(expected, start, end, weight);

  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_start = CopyToDevice(start, device);
    torch::Tensor lazy_end = CopyToDevice(end, device);
    torch::Tensor actual = torch::empty({3, 4}, lazy_start.options());
    torch::lerp_out(actual, lazy_start, lazy_end, weight);
    AllClose(expected, actual);
  });

  // Counter expectations: nothing matching aten::.*, but lazy::lerp.
  ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
  ExpectCounterChanged("lazy::lerp", GetIgnoredCounters());
}
11568 
// Tensor::is_alias_of must give the same answer for per-device copies as for
// the reference tensors: two independent tensors, a view of a tensor, and a
// view of that view checked against both its parent and the root.
TEST_F(LazyOpsTest, IsAliasOf) {
  // Float tensor options on the default device; rebuilt on every call.
  auto float_opts = [&]() {
    return torch::TensorOptions(torch::kFloat).device(DefaultDevice());
  };
  auto a = torch::empty(4, float_opts());
  auto b = torch::empty(4, float_opts());

  ForEachDevice([&](const torch::Device& device) {
    auto lazy_a = CopyToDevice(a, device);
    auto lazy_b = CopyToDevice(b, device);

    // Two separately allocated tensors.
    EXPECT_EQ(!a.is_alias_of(b), !lazy_a.is_alias_of(lazy_b));

    // First-level view.
    auto c = a.view({2, 2});
    auto lazy_c = lazy_a.view({2, 2});
    EXPECT_EQ(a.is_alias_of(c), lazy_a.is_alias_of(lazy_c));

    // View of a view, against its direct parent and against the root.
    auto d = c.view({1, 4});
    auto lazy_d = lazy_c.view({1, 4});
    EXPECT_EQ(d.is_alias_of(c), lazy_d.is_alias_of(lazy_c));
    EXPECT_EQ(d.is_alias_of(a), lazy_d.is_alias_of(lazy_a));
  });
}
11590 
11591 #endif // FBCODE_CAFFE2
11592 
11593 } // namespace lazy
11594 } // namespace torch
11595