xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/ivalue_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/ATen.h>
2 #include <ATen/core/Dict.h>
3 #include <c10/util/intrusive_ptr.h>
4 #include <c10/util/irange.h>
5 #include <gmock/gmock.h>
6 #include <gtest/gtest.h>
7 #include <torch/torch.h>
8 
// Snippets for checking assembly.
// inspectTupleConstruction() is not called by any test in this file; it is
// kept so the code generated for constructing an IValue from an lvalue
// std::tuple<std::string, std::string> can be inspected. Keep the statements
// exactly as written: changing them changes the code being inspected.
c10::IValue inspectTupleConstruction() {
  // Built from two string literals, then converted to std::string elements.
  std::tuple<std::string, std::string> s = std::make_tuple(
      "abcdefghijklmnopqrstuvwxyz", "ABCDEFGHIJKLMNOPQRSTUVWXYZ");
  return c10::IValue(s);
}
14 }
15 
16 namespace c10 {
17 
TEST(IValueTest,Basic)18 TEST(IValueTest, Basic) {
19   c10::List<int64_t> foo({3, 4, 5});
20   ASSERT_EQ(foo.use_count(), 1);
21   IValue bar{foo};
22   ASSERT_EQ(foo.use_count(), 2);
23   auto baz = bar;
24   ASSERT_EQ(foo.use_count(), 3);
25   auto foo2 = std::move(bar);
26   ASSERT_EQ(foo.use_count(), 3);
27   ASSERT_TRUE(foo2.isIntList());
28   // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
29   ASSERT_TRUE(bar.isNone());
30   foo2 = IValue(4.0);
31   ASSERT_TRUE(foo2.isDouble());
32   ASSERT_EQ(foo2.toDouble(), 4.0);
33   ASSERT_EQ(foo.use_count(), 2);
34   ASSERT_TRUE(baz.toIntVector() == std::vector<int64_t>({3, 4, 5}));
35   ASSERT_TRUE(baz.toDimVector() == at::DimVector({3, 4, 5}));
36 
37   auto move_it = std::move(baz).toIntList();
38   ASSERT_EQ(foo.use_count(), 2);
39   // NOLINTNEXTLINE(bugprone-use-after-move)
40   ASSERT_TRUE(baz.isNone());
41   IValue i(4);
42   ASSERT_TRUE(i.isInt());
43   ASSERT_EQ(i.toInt(), 4);
44   IValue dlist(c10::List<double>({3.5}));
45   ASSERT_TRUE(dlist.isDoubleList());
46   ASSERT_TRUE(dlist.toDoubleVector() == std::vector<double>({3.5}));
47   std::move(dlist).toDoubleList();
48   // NOLINTNEXTLINE(bugprone-use-after-move)
49   ASSERT_TRUE(dlist.isNone());
50   dlist = IValue(c10::List<double>({3.4}));
51   ASSERT_TRUE(dlist.toDoubleVector() == std::vector<double>({3.4}));
52   dlist = IValue(std::vector<double>({3.3, 3.2}));
53   ASSERT_TRUE(dlist.toDoubleVector() == std::vector<double>({3.3, 3.2}));
54   IValue blist(std::vector<bool>{true, false});
55   ASSERT_TRUE(blist.isList());
56   const auto blistRef = blist.toListRef();
57   ASSERT_EQ(blistRef.size(), 2);
58   ASSERT_TRUE(blistRef[0].toBool());
59   ASSERT_FALSE(blistRef[1].toBool());
60   IValue the_list(
61       at::ivalue::Tuple::create({IValue(3.4), IValue(4), IValue(foo)}));
62   ASSERT_EQ(foo.use_count(), 3);
63   ASSERT_TRUE(the_list.isTuple());
64   auto first = the_list.toTupleRef().elements()[1];
65   ASSERT_EQ(first.toInt(), 4);
66   // Make sure toTupleRef has test coverage too.
67   first = the_list.toTupleRef().elements()[1];
68   ASSERT_EQ(first.toInt(), 4);
69   at::Tensor tv = at::rand({3, 4});
70   IValue ten(tv);
71   ASSERT_EQ(tv.use_count(), 2);
72   auto ten2 = ten;
73   ASSERT_EQ(tv.use_count(), 3);
74   ASSERT_TRUE(ten2.toTensor().equal(ten.toTensor()));
75   std::move(ten2).toTensor();
76   ASSERT_EQ(tv.use_count(), 2);
77 
78   auto elem1 = c10::complex<double>(3, 4);
79   auto elem2 = c10::complex<double>(3, -4);
80   auto elem3 = c10::complex<double>(5, 0);
81   c10::List<c10::complex<double>> foo1({elem1, elem2, elem3});
82   ASSERT_EQ(foo1.use_count(), 1);
83   IValue bar1{foo1};
84   ASSERT_EQ(foo1.use_count(), 2);
85   auto baz1 = bar1;
86   ASSERT_EQ(foo1.use_count(), 3);
87   auto foo12 = std::move(bar1);
88   ASSERT_EQ(foo1.use_count(), 3);
89   ASSERT_TRUE(foo12.isComplexDoubleList());
90   ASSERT_EQ(foo12.toComplexDoubleList(), foo1);
91 
92   // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
93   ASSERT_TRUE(bar1.isNone());
94   auto foo3 = IValue(c10::complex<double>(3, 4));
95   ASSERT_TRUE(foo3.isComplexDouble());
96   ASSERT_EQ(foo3.toComplexDouble(), c10::complex<double>(3,4));
97 
98   ASSERT_TRUE(baz1.toComplexDoubleVector() == std::vector<c10::complex<double>>({elem1, elem2, elem3}));
99   IValue complex_tuple(
100       at::ivalue::Tuple::create({IValue(c10::complex<double>(3.4, 4.7)), IValue(foo1)}));
101   ASSERT_TRUE(complex_tuple.isTuple());
102   ASSERT_EQ(complex_tuple.toTupleRef().elements()[0].toComplexDouble(), c10::complex<double>(3.4, 4.7));
103   ASSERT_EQ(complex_tuple.toTupleRef().elements()[1], foo1);
104 }
105 
// Wrapping an empty or a tensor-backed Storage in an IValue yields
// isStorage() == true, and toStorage() hands back the same StorageImpl.
TEST(IValueTest, BasicStorage) {
  at::Storage emptyStorage;
  at::Storage nonemptyStorage(at::rand({3, 4}).storage());
  const IValue wrappedEmpty(emptyStorage);
  const IValue wrappedNonempty(nonemptyStorage);

  ASSERT_TRUE(wrappedEmpty.isStorage());
  ASSERT_TRUE(wrappedNonempty.isStorage());
  ASSERT_EQ(
      emptyStorage.unsafeGetStorageImpl(),
      wrappedEmpty.toStorage().unsafeGetStorageImpl());
  ASSERT_EQ(
      nonemptyStorage.unsafeGetStorageImpl(),
      wrappedNonempty.toStorage().unsafeGetStorageImpl());
}
117 
// A Dict keyed and valued by complex numbers can be read back through
// IValue::toGenericDict().
TEST(IValueTest, ComplexDict) {
  using c_type = c10::complex<double>;
  const auto num1 = c_type(2.3, -3.5);
  const auto num2 = c_type(0, 5);

  c10::Dict<c_type, c_type> doubled;
  doubled.insert(num1, 2 * num1);
  doubled.insert(num2, 2 * num2);

  IValue dict(std::move(doubled));
  auto generic = dict.toGenericDict();
  ASSERT_EQ(generic.at(num1), 2 * num1);
  ASSERT_EQ(generic.at(num2), 2 * num2);
}
130 
131 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
makeSampleIValues()132 static std::array<IValue, 16> makeSampleIValues() {
133   return {
134     IValue(),
135     // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
136     at::rand({3, 4}),
137     // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
138     at::rand({3, 4}).storage(),
139     // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
140     1.5,
141     // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
142     c10::complex<double>(2.5, -0.5),
143     // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
144     42,
145     true,
146     // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
147     std::make_tuple(23, "hello"),
148     "hello",
149     c10::make_intrusive<caffe2::Blob>(),
150     // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
151     c10::List<int64_t>({1, 2, 3}),
152     c10::Dict<std::string, std::string>(),
153     c10::make_intrusive<ivalue::Future>(FloatType::get()),
154     c10::Device(c10::DeviceType::CPU, 0),
155     c10::Stream(c10::Stream::DEFAULT, c10::Device(c10::DeviceType::CPU, 0)),
156     c10::make_intrusive<ivalue::Object>(c10::StrongTypePtr(nullptr, ClassType::create("class1", {})), 1),
157   };
158 }
159 
160 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
makeMoreSampleIValues()161 static std::array<IValue, 16> makeMoreSampleIValues() {
162   return {
163     IValue(),
164     // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
165     at::rand({3, 4}),
166     // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
167     at::rand({3, 4}).storage(),
168     // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
169     2.5,
170     // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
171     c10::complex<double>(2.7, -0.3),
172     // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
173     43,
174     false,
175     std::make_tuple(1, "goodbye"),
176     "goodbye",
177     c10::make_intrusive<caffe2::Blob>(),
178     // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
179     c10::List<int64_t>({4, 5, 6}),
180     c10::Dict<std::string, std::string>(),
181     c10::make_intrusive<ivalue::Future>(IntType::get()),
182     c10::Device(c10::DeviceType::CUDA, 2),
183     c10::Stream(c10::Stream::DEFAULT, c10::Device(c10::DeviceType::CUDA, 1)),
184     c10::make_intrusive<ivalue::Object>(c10::StrongTypePtr(nullptr, ClassType::create("class2", {})), 2),
185   };}
186 
// Expects two IValues to be equal. IValue::operator== throws for
// multi-element Tensors (see the TensorEquality test), so tensors are
// compared with Tensor::equal() instead. Each argument is evaluated exactly
// once, and the expansion is a single statement (do/while), so the macro is
// safe inside unbraced if/else bodies.
#define EXPECT_IVALUE_EQ(a, b)                                         \
  do {                                                                 \
    const auto& ivalue_eq_lhs_ = (a);                                  \
    const auto& ivalue_eq_rhs_ = (b);                                  \
    EXPECT_EQ(ivalue_eq_lhs_.isTensor(), ivalue_eq_rhs_.isTensor())    \
        << #a << " vs " << #b;                                         \
    if (ivalue_eq_lhs_.isTensor()) {                                   \
      EXPECT_TRUE(                                                     \
          ivalue_eq_lhs_.toTensor().equal(ivalue_eq_rhs_.toTensor()))  \
          << #a << " vs " << #b;                                       \
    } else {                                                           \
      EXPECT_EQ(ivalue_eq_lhs_, ivalue_eq_rhs_) << #a << " vs " << #b; \
    }                                                                  \
  } while (false)
195 
// swap() has three code paths (tensor, intrusive_ptr, or neither), so swap
// every sample against every other sample to hit all pairings.
TEST(IValueTest, Swap) {
  const auto lhsSamples = makeSampleIValues();
  const auto rhsSamples = makeMoreSampleIValues();
  for (const IValue& lhsOriginal : lhsSamples) {
    for (const IValue& rhsOriginal : rhsSamples) {
      IValue lhs(lhsOriginal);
      IValue rhs(rhsOriginal);
      EXPECT_IVALUE_EQ(lhs, lhsOriginal);
      EXPECT_IVALUE_EQ(rhs, rhsOriginal);

      lhs.swap(rhs);

      EXPECT_IVALUE_EQ(lhs, rhsOriginal);
      EXPECT_IVALUE_EQ(rhs, lhsOriginal);
    }
  }
}
214 
// A copy-constructed IValue compares equal to its source, for every sample
// kind.
TEST(IValueTest, CopyConstruct) {
  const auto samples = makeSampleIValues();
  for (const IValue& original : samples) {
    IValue duplicate(original);
    EXPECT_IVALUE_EQ(duplicate, original);
  }
}
222 
// Move-constructing takes over the source's value and leaves the source as
// None, for every sample kind.
TEST(IValueTest, MoveConstruct) {
  const auto samples = makeSampleIValues();
  for (const IValue& original : samples) {
    IValue donor(original);
    IValue receiver(std::move(donor));
    EXPECT_IVALUE_EQ(receiver, original);
    // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
    EXPECT_TRUE(donor.isNone());
  }
}
233 
// Copy-assigning over a target of any kind leaves both sides equal to the
// source value.
TEST(IValueTest, CopyAssign) {
  const auto sources = makeSampleIValues();
  const auto targets = makeMoreSampleIValues();

  for (const IValue& source : sources) {
    for (const IValue& target : targets) {
      IValue destination(target);
      IValue origin(source);

      destination = origin;

      EXPECT_IVALUE_EQ(destination, source);
      EXPECT_IVALUE_EQ(origin, source);
      EXPECT_IVALUE_EQ(destination, origin);
    }
  }
}
249 
// Move-assigning over a target of any kind transfers the source value and
// leaves the moved-from IValue as None.
TEST(IValueTest, MoveAssign) {
  const auto sources = makeSampleIValues();
  const auto targets = makeMoreSampleIValues();

  for (const IValue& source : sources) {
    for (const IValue& target : targets) {
      IValue destination(target);
      IValue donor(source);

      destination = std::move(donor);

      EXPECT_IVALUE_EQ(destination, source);
      // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
      EXPECT_TRUE(donor.isNone());
    }
  }
}
265 
// A std::tuple<int64_t, Tensor> round-trips through IValue::to<>().
TEST(IValueTest, Tuple) {
  using IntAndTensor = std::tuple<int64_t, at::Tensor>;
  IntAndTensor original = std::make_tuple(123, at::randn({1}));
  auto boxed = IValue(original);
  const auto unboxed = boxed.to<IntAndTensor>();

  ASSERT_EQ(std::get<0>(unboxed), 123);
  const float expectedItem = std::get<1>(original).item().to<float>();
  ASSERT_EQ(std::get<1>(unboxed).item().to<float>(), expectedItem);
}
274 
// Object::unsafeRemoveAttr drops the slot from the object instance, but the
// ClassType still declares the attribute.
TEST(IValueTest, unsafeRemoveAttr) {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu);
  cls->addAttribute("attr1", TensorType::get());
  cls->addAttribute("attr2", TensorType::get());
  auto obj = c10::ivalue::Object::create(
      c10::StrongTypePtr(cu, cls), cls->numAttributes());
  obj->unsafeRemoveAttr("attr1");
  // attr1 is not removed in the type
  ASSERT_TRUE(cls->hasAttribute("attr1"));
  ASSERT_TRUE(cls->hasAttribute("attr2"));
  // ...but the object keeps only one of its two slots. ASSERT_EQ reports the
  // actual size on failure, unlike ASSERT_TRUE(size() == 1).
  ASSERT_EQ(obj->slots().size(), 1);
}
288 
// Tuples print in parentheses; a single-element tuple keeps a trailing comma.
TEST(IValueTest, TuplePrint) {
  const auto printed = [](const IValue& value) {
    std::stringstream stream;
    stream << value;
    return stream.str();
  };

  IValue single = std::make_tuple(3);
  ASSERT_EQ(printed(single), "(3,)");

  IValue pair = std::make_tuple(3, 3);
  ASSERT_EQ(printed(pair), "(3, 3)");
}
305 
// Complex IValues print as "<real><sign><imag>j"; whole-number components are
// printed with a trailing '.'.
TEST(IValueTest, ComplexIValuePrint) {
  const auto printed = [](double real, double imag) {
    std::stringstream stream;
    stream << IValue(c10::complex<double>(real, imag));
    return stream.str();
  };

  ASSERT_EQ(printed(2, -3), "2.-3.j");
  ASSERT_EQ(printed(2, 0), "2.+0.j");
  ASSERT_EQ(printed(0, 3), "0.+3.j");
}
328 
// A complex IValue can be built directly or from a complex Scalar. It compares
// by value and converts back to a Scalar.
TEST(IValueTest, Complex) {
  const auto value = c10::complex<double>(2, 3);
  const auto conjugate = c10::complex<double>(2, -3);
  IValue direct(value);
  IValue directConjugate(conjugate);
  IValue viaScalar{at::Scalar(value)};

  ASSERT_TRUE(direct.isComplexDouble());
  ASSERT_TRUE(viaScalar.isComplexDouble());

  ASSERT_EQ(value, direct.toComplexDouble());
  ASSERT_FALSE(direct == directConjugate);
  ASSERT_TRUE(direct == viaScalar);

  ASSERT_TRUE(direct.isScalar());
  ASSERT_TRUE(directConjugate.toScalar().equal(conjugate));
}
344 
// A Future completed with a value reports completed() and exposes the value,
// both directly and through an IValue that wraps it.
TEST(IValueTest, BasicFuture) {
  auto future = c10::make_intrusive<ivalue::Future>(IntType::get());
  ASSERT_FALSE(future->completed());

  future->markCompleted(IValue(42));
  ASSERT_TRUE(future->completed());
  ASSERT_EQ(42, future->value().toInt());

  IValue wrapped(future);
  ASSERT_EQ(42, wrapped.toFuture()->value().toInt());
}
355 
// A callback added before markCompleted() runs on completion. A callback added
// afterwards runs right away. Each runs exactly once.
TEST(IValueTest, FutureCallbacks) {
  auto future = c10::make_intrusive<ivalue::Future>(IntType::get());
  int preCompletionCalls = 0;
  int postCompletionCalls = 0;

  future->addCallback([&preCompletionCalls](ivalue::Future& done) {
    ASSERT_TRUE(done.completed());
    ASSERT_EQ(done.value().toInt(), 43);
    ++preCompletionCalls;
  });
  future->markCompleted(IValue(43));
  ASSERT_EQ(preCompletionCalls, 1);
  ASSERT_EQ(postCompletionCalls, 0);

  // Post-markCompleted()
  future->addCallback([&postCompletionCalls](ivalue::Future& done) {
    ASSERT_TRUE(done.completed());
    ASSERT_EQ(done.value().toInt(), 43);
    ++postCompletionCalls;
  });
  ASSERT_EQ(preCompletionCalls, 1);
  ASSERT_EQ(postCompletionCalls, 1);
  ASSERT_FALSE(future->hasError());
}
378 
// setError() completes the Future. Callbacks see value() rethrow the stored
// error, and the error message is retrievable afterwards.
TEST(IValueTest, FutureExceptions) {
  auto future = c10::make_intrusive<ivalue::Future>(IntType::get());
  int matchingErrorCalls = 0;

  future->addCallback([&matchingErrorCalls](ivalue::Future& done) {
    ASSERT_TRUE(done.completed());
    try {
      (void)done.value();
    } catch (const std::exception& e) {
      const std::string message(e.what());
      if (message == "My Error") {
        ++matchingErrorCalls;
      }
    }
  });

  ivalue::Future::FutureError error("My Error");
  future->setError(std::make_exception_ptr(error));

  ASSERT_EQ(matchingErrorCalls, 1);
  ASSERT_TRUE(future->hasError());
  ASSERT_EQ(future->tryRetrieveErrorMessage(), std::string("My Error"));
}
398 
// Setting an error on a Future that already has one must throw. The message
// must mention "Error already set" plus both the existing error ("foo") and
// the rejected one ("bar").
TEST(IValueTest, FutureSetError) {
  auto f1 = c10::make_intrusive<ivalue::Future>(IntType::get());
  f1->setError(std::make_exception_ptr(std::runtime_error("foo")));
  try {
    f1->setError(std::make_exception_ptr(std::runtime_error("bar")));
    FAIL() << "Expected to throw";
  } catch (const std::exception& e) {
    EXPECT_THAT(e.what(), ::testing::HasSubstr("Error already set"));
    EXPECT_THAT(e.what(), ::testing::HasSubstr("foo"));
    EXPECT_THAT(e.what(), ::testing::HasSubstr("bar"));
  }
}
411 
// operator== compares strings and ints by value, and never equates values of
// different kinds. equals() returns its answer boxed in a bool IValue.
TEST(IValueTest, ValueEquality) {
  const IValue asdf("asdf");
  EXPECT_EQ(asdf, IValue("asdf"));
  EXPECT_NE(asdf, IValue("ASDF"));
  EXPECT_NE(IValue("2"), IValue(2));
  EXPECT_EQ(IValue(1), IValue(1));

  // Check the equals() variant that returns an IValue
  IValue verdict = asdf.equals("asdf");
  EXPECT_TRUE(verdict.isBool());
  EXPECT_TRUE(verdict.toBool());

  verdict = asdf.equals(1);
  EXPECT_TRUE(verdict.isBool());
  EXPECT_FALSE(verdict.toBool());
}
427 
// Tensor IValues: operator== on multi-element tensors throws, equals() yields
// an elementwise bool tensor, and is() checks identity.
TEST(IValueTest, TensorEquality) {
  auto zeros = torch::zeros({2, 3});
  auto zerosClone = zeros.clone();
  auto t = IValue(zeros);
  auto tCopy = IValue(zerosClone);

  // Elementwise equality is ambiguous for multi-element tensors, so
  // operator== must throw.
  auto compareOnesWithRandom = []() {
    return IValue(torch::ones({2, 3})) == IValue(torch::rand({2, 3}));
  };
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_ANY_THROW(compareOnesWithRandom());

  // equals() of two equal tensors is an all-`true` bool tensor.
  IValue elementwise = t.equals(tCopy);
  EXPECT_TRUE(elementwise.isTensor());
  const auto allTrue = torch::ones({2, 3}).to(torch::kBool);
  EXPECT_TRUE(elementwise.toTensor().equal(allTrue));

  // Identity: a value is itself and its copies, but not an equal clone.
  EXPECT_TRUE(t.is(t));
  EXPECT_FALSE(t.is(tCopy));
  // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
  IValue sameTensor = t;
  EXPECT_TRUE(t.is(sameTensor));
}
455 
// Lists compare by contents, not identity.
TEST(IValueTest, ListEquality) {
  using IntVector = std::vector<int64_t>;
  IValue c1 = IntVector{0, 1, 2, 3};
  IValue c2 = IntVector{0, 1, 2, 3};
  IValue c3 = IntVector{0, 1, 2, 3, 4};

  EXPECT_EQ(c1, c1);
  EXPECT_EQ(c1, c2);
  EXPECT_FALSE(c1.is(c2));
  EXPECT_NE(c1, c3);
  EXPECT_NE(c2, c3);
}
466 
// Nested dicts compare by contents: sharing vs. copying the inner dict does
// not matter, while an extra key or a differing inner dict does.
TEST(IValueTest, DictEquality) {
  using StrDict = c10::Dict<std::string, std::string>;
  using NestedDict = c10::Dict<std::string, StrDict>;

  auto innerDict = StrDict();
  innerDict.insert("foo", "bar");

  // Every entry refers to the same innerDict.
  auto sharedInner = NestedDict();
  for (const char* key : {"one", "two", "three"}) {
    sharedInner.insert(key, innerDict);
  }
  auto c1 = IValue(sharedInner);

  // Same keys, each holding its own copy of innerDict.
  auto copiedInner = NestedDict();
  for (const char* key : {"one", "two", "three"}) {
    copiedInner.insert(key, innerDict.copy());
  }
  auto c2 = IValue(copiedInner);

  // One extra key.
  auto extraKey = NestedDict();
  for (const char* key : {"one", "two", "three", "four"}) {
    extraKey.insert(key, innerDict.copy());
  }
  auto c3 = IValue(extraKey);

  // Same keys, but "three" maps to a different inner dict.
  auto differentInner = NestedDict();
  differentInner.insert("one", innerDict.copy());
  differentInner.insert("two", innerDict.copy());
  auto innerDictNotEqual = StrDict();
  innerDictNotEqual.insert("bar", "foo");
  differentInner.insert("three", innerDictNotEqual);
  auto c4 = IValue(differentInner);

  EXPECT_EQ(c1, c1);
  EXPECT_EQ(c1, c2);
  EXPECT_FALSE(c1.is(c2));
  EXPECT_NE(c1, c3);
  EXPECT_NE(c2, c3);
  EXPECT_NE(c1, c4);
  EXPECT_NE(c2, c4);
}
506 
// Dict equality does not depend on insertion order.
TEST(IValueTest, DictEqualityDifferentOrder) {
  auto oneThenTwo = c10::Dict<std::string, int64_t>();
  oneThenTwo.insert("one", 1);
  oneThenTwo.insert("two", 2);

  auto twoThenOne = c10::Dict<std::string, int64_t>();
  twoThenOne.insert("two", 2);
  twoThenOne.insert("one", 1);

  EXPECT_EQ(oneThenTwo, twoThenOne);
}
517 
// Nested lists compare element by element, recursively.
TEST(IValueTest, ListNestedEquality) {
  using Nested = std::vector<std::vector<int64_t>>;
  IValue c1 = Nested({{0}, {0, 1}, {0, 1, 2}});
  IValue c2 = Nested({{0}, {0, 1}, {0, 1, 2}});
  IValue c3 = Nested({{1}, {0, 1}, {0, 1, 2}});

  EXPECT_EQ(c1, c1);
  EXPECT_EQ(c1, c2);
  EXPECT_NE(c1, c3);
  EXPECT_NE(c2, c3);
}
527 
// Default streams on different CUDA devices are unequal; the same stream
// wrapped twice is equal.
TEST(IValueTest, StreamEquality) {
  const at::Device cuda0(kCUDA, 0);
  const at::Device cuda1(kCUDA, 1);
  const c10::Stream onCuda0(c10::Stream::Default::DEFAULT, cuda0);
  const c10::Stream onCuda1(c10::Stream::Default::DEFAULT, cuda1);

  IValue lhs(onCuda0);
  IValue rhsDifferent(onCuda1);
  IValue rhsSame(onCuda0);

  EXPECT_FALSE(lhs.equals(rhsDifferent).toBool());
  EXPECT_TRUE(lhs.equals(rhsSame).toBool());
}
539 
// EnumHolder IValues built from the same enum type, member name and value
// are equal. Changing the enum type or the member makes them unequal.
TEST(IValueTest, EnumEquality) {
  auto cu = std::make_shared<CompilationUnit>();
  IValue int_ivalue_1(1);
  IValue int_ivalue_2(2);
  IValue str_ivalue_1("1");
  auto int_enum_type1 = EnumType::create(
      "enum_class_1",
      IntType::get(),
      {{"enum_name_1", int_ivalue_1}, {"enum_name_2", int_ivalue_2}},
      cu);
  auto int_enum_type2 = EnumType::create(
      "enum_class_2",
      IntType::get(),
      {{"enum_name_1", int_ivalue_1}, {"enum_name_2", int_ivalue_2}},
      cu);
  auto string_enum_type = EnumType::create(
      "enum_class_3", StringType::get(), {{"enum_name_1", str_ivalue_1}}, cu);

  // Wraps a freshly allocated EnumHolder in an IValue.
  const auto makeEnum = [](const auto& type, const char* name, const IValue& value) {
    return IValue(c10::make_intrusive<ivalue::EnumHolder>(type, name, value));
  };

  // Identical type, name and value.
  EXPECT_EQ(
      makeEnum(int_enum_type1, "enum_name_1", int_ivalue_1),
      makeEnum(int_enum_type1, "enum_name_1", int_ivalue_1));

  // Same name and value, different enum type.
  EXPECT_NE(
      makeEnum(int_enum_type1, "enum_name_1", int_ivalue_1),
      makeEnum(int_enum_type2, "enum_name_1", int_ivalue_1));

  // Same enum type, different member.
  EXPECT_NE(
      makeEnum(int_enum_type1, "enum_name_1", int_ivalue_1),
      makeEnum(int_enum_type1, "enum_name_2", int_ivalue_2));

  // Int-valued vs. string-valued enum.
  EXPECT_NE(
      makeEnum(int_enum_type1, "enum_name_1", int_ivalue_1),
      makeEnum(string_enum_type, "enum_name_1", str_ivalue_1));
}
586 
// isPtrType() is true for a defined tensor and a string, and false for an
// undefined tensor and an int.
TEST(IValueTest, isPtrType) {
  IValue tensor(at::rand({3, 4}));
  IValue undefinedTensor((at::Tensor()));
  IValue integer(42);
  IValue str("hello");

  EXPECT_TRUE(tensor.isPtrType());
  EXPECT_TRUE(str.isPtrType());

  EXPECT_FALSE(undefinedTensor.isPtrType());
  EXPECT_FALSE(integer.isPtrType());
}
598 
// Among the sample IValues, a value aliases only itself, and only when
// isPtrType() is true.
TEST(IValueTest, isAliasOf) {
  auto samples = makeSampleIValues();
  for (const auto i : c10::irange(samples.size())) {
    for (const auto j : c10::irange(samples.size())) {
      const IValue& lhs = samples[i];
      const IValue& rhs = samples[j];
      if (i == j && lhs.isPtrType()) {
        EXPECT_TRUE(lhs.isAliasOf(rhs));
      } else {
        EXPECT_FALSE(lhs.isAliasOf(rhs));
      }
    }
  }
}
611 
// internalToPointer() exposes the payload pointer: the TensorImpl for a
// tensor, non-null for a string, and nullptr for a null string pointer.
TEST(IValueTest, internalToPointer) {
  IValue tensor(at::rand({3, 4}));
  IValue str("hello");

  EXPECT_EQ(tensor.internalToPointer(), tensor.unsafeToTensorImpl());
  EXPECT_NE(str.internalToPointer(), nullptr);

  const c10::intrusive_ptr<ivalue::ConstantString> noString;
  IValue nullStr(noString);
  ASSERT_TRUE(nullStr.isString());
  EXPECT_EQ(nullStr.internalToPointer(), nullptr);
}
623 
// is() and hash() agree for tensor IValues that share a tensor. For the
// hashable sample kinds, hashes are stable across separately built value
// sets, except for pointer-backed non-string values, and differ between
// distinct values.
TEST(IValueTest, IdentityComparisonAndHashing) {
  at::Tensor t1 = at::rand({3, 4});
  at::Tensor t2 = at::rand({3, 4});
  IValue tv1(t1);
  IValue tv2(t2);
  IValue tv1b(t1);

  EXPECT_EQ(tv1.hash(), tv1b.hash());
  EXPECT_NE(tv1.hash(), tv2.hash());

  // Reflexive, and symmetric for IValues sharing one tensor.
  EXPECT_TRUE(tv1.is(tv1));
  EXPECT_TRUE(tv1.is(tv1b));
  EXPECT_TRUE(tv1b.is(tv1));
  EXPECT_TRUE(tv2.is(tv2));

  EXPECT_FALSE(tv1.is(tv2));
  EXPECT_FALSE(tv2.is(tv1));

  IValue none;
  IValue undefinedTensor((at::Tensor()));

  EXPECT_TRUE(none.is(undefinedTensor));
  EXPECT_TRUE(undefinedTensor.is(none));

  // Is this a bug? We should probably have a is b => a.hash() == b.hash()
  EXPECT_NE(none.hash(), undefinedTensor.hash());

  auto firstSamples = makeSampleIValues();
  auto secondSamples = makeSampleIValues();
  auto otherSamples = makeMoreSampleIValues();
  ASSERT_EQ(firstSamples.size(), otherSamples.size());

  // Kinds this test treats as not hashable.
  const auto skipHashing = [](const IValue& v) {
    return v.isComplexDouble() || v.isBlob() || v.isList() || v.isFuture() ||
        v.isStream() || v.isObject() || v.isGenericDict();
  };

  for (const auto ii : c10::irange(firstSamples.size())) {
    const IValue& sample = firstSamples[ii];
    const IValue& rebuilt = secondSamples[ii];
    const IValue& different = otherSamples[ii];
    if (skipHashing(sample)) {
      // Not hashable.
      continue;
    }
    // Tuples may or may not have the same hash across instantiations.
    if (!sample.isTuple()) {
      // Constant strings will have the same pointer value.
      const bool identityHashed = sample.isPtrType() && !sample.isString();
      if (identityHashed) {
        EXPECT_NE(sample.hash(), rebuilt.hash()) << " at index " << ii;
      } else {
        EXPECT_EQ(sample.hash(), rebuilt.hash()) << " at index " << ii;
      }
    }
    if (!sample.isNone() && !different.isNone()) {
      EXPECT_NE(sample.hash(), different.hash()) << " at index " << ii;
    }
  }
}
683 
684 // Sparse tensors do not work with static CPU dispatch
685 #ifndef ATEN_CPU_STATIC_DISPATCH
TEST(IValueTest,IdentityAndHashing_SparseCOO)686 TEST(IValueTest, IdentityAndHashing_SparseCOO) {
687   using namespace torch::indexing;
688 
689   at::Tensor t1 = at::rand({3, 4}).to_sparse();
690   at::Tensor t2 = at::rand({3, 4}).to_sparse();
691   at::Tensor t3 = at::rand({3, 4});
692 
693   IValue tv1(t1), tv1b(t1), tv2(t2), tv3(t3);
694 
695   EXPECT_EQ(tv1.hash(), tv1b.hash());
696   EXPECT_NE(tv1.hash(), tv2.hash());
697 
698   EXPECT_TRUE(tv1.is(tv1b));
699   EXPECT_FALSE(tv1.is(tv2));
700 
701   EXPECT_TRUE(tv1.isAliasOf(tv1b));
702   EXPECT_FALSE(tv1.isAliasOf(tv2));
703   EXPECT_FALSE(tv1.isAliasOf(tv3));
704 
705   std::vector<int64_t> idx_array1 = {0, 1, 1, 0, 0, 1};
706   at::Tensor idx1 = torch::from_blob(
707       idx_array1.data(),
708       {2, 3},
709       torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU));
710   std::vector<int64_t> idx_array2 = {1, 1, 2, 0, 1, 2};
711   at::Tensor idx2 = torch::from_blob(
712       idx_array2.data(),
713       {2, 3},
714       torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU));
715   std::vector<int32_t> val_array = {3, -5, 7};
716   at::Tensor val = torch::from_blob(
717       val_array.data(),
718       {3},
719       torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU));
720   at::Tensor sparse1 = torch::sparse_coo_tensor(
721       idx1, val, {3, 3}, torch::TensorOptions().dtype(torch::kInt32));
722   at::Tensor sparse2 = torch::sparse_coo_tensor(
723       idx2, val, {3, 3}, torch::TensorOptions().dtype(torch::kInt32));
724 
725   IValue idx1_v(idx1), idx2_v(idx2);
726   IValue val_v(val);
727   IValue sparse1_v(sparse1), sparse2_v(sparse2);
728 
729   EXPECT_TRUE(sparse1_v.isAliasOf(sparse2_v));
730   EXPECT_TRUE(sparse1_v.isAliasOf(idx1_v));
731   EXPECT_TRUE(sparse1_v.isAliasOf(val_v));
732   EXPECT_TRUE(sparse2_v.isAliasOf(idx2_v));
733   EXPECT_TRUE(sparse2_v.isAliasOf(val_v));
734   EXPECT_FALSE(idx1_v.isAliasOf(idx2_v));
735   EXPECT_FALSE(idx1_v.isAliasOf(val_v));
736   EXPECT_FALSE(sparse1_v.isAliasOf(idx2_v));
737 }
738 #endif // ATEN_CPU_STATIC_DISPATCH
739 
// getSubValues() collects nothing for scalars and just the value itself for
// a tensor. For list/tuple/dict/object containers of two tensors it collects
// the container plus both tensors.
TEST(IValueTest, getSubValues) {
  // Gathers the sub-values of `value` into a fresh set.
  const auto collect = [](const IValue& value) {
    IValue::HashAliasedIValues found;
    value.getSubValues(found);
    return found;
  };

  // Scalars have no subvalues.
  IValue integer(42);
  IValue float_(1.5);
  IValue complexValue(c10::complex<double>(2, 3));
  EXPECT_TRUE(collect(integer).empty());
  EXPECT_TRUE(collect(float_).empty());
  EXPECT_TRUE(collect(complexValue).empty());

  at::Tensor t1 = at::rand({3, 4});
  at::Tensor t2 = at::rand({3, 4});
  IValue tv1(t1);
  IValue tv2(t2);
  IValue list(std::vector<at::Tensor>{t1, t2});
  IValue tuple(ivalue::Tuple::create({tv1, tv2}));

  c10::Dict<int64_t, at::Tensor> tensorsByKey;
  tensorsByKey.insert(1, t1);
  tensorsByKey.insert(2, t2);
  IValue dict(std::move(tensorsByKey));

  auto objType = ClassType::create(std::nullopt, {});
  objType->addAttribute("t1", tv1.type());
  objType->addAttribute("t2", tv2.type());
  auto o = ivalue::Object::create(StrongTypePtr(nullptr, objType), 2);
  o->setSlot(0, tv1);
  o->setSlot(1, tv2);
  IValue object(o);

  const auto tensorSubValues = collect(tv1);
  EXPECT_EQ(tensorSubValues.size(), 1);
  EXPECT_EQ(tensorSubValues.count(tv1), 1);

  for (const IValue& container : {list, tuple, dict, object}) {
    const auto found = collect(container);
    EXPECT_EQ(found.size(), 3);
    EXPECT_EQ(found.count(container), 1);
    EXPECT_EQ(found.count(tv1), 1);
    EXPECT_EQ(found.count(tv2), 1);
  }
}
797 
// A boolean Scalar survives a round trip through IValue as a boolean.
TEST(IValueTest, ScalarBool) {
  const Scalar original(true);
  IValue boxed(original);
  const Scalar roundTripped = boxed.toScalar();
  EXPECT_TRUE(roundTripped.isBoolean());
  EXPECT_TRUE(roundTripped.toBool());
}
805 
// Locking a WeakIValue made from a live sample gives back an equal IValue.
TEST(IValueTest, ToWeakAndBack) {
  const auto samples = makeSampleIValues();
  for (const IValue& strong : samples) {
    WeakIValue weak(strong);
    EXPECT_IVALUE_EQ(strong, weak.lock());
  }
}
813 
814 // Storage and Generator did not set is_intrusive_ptr if they were
815 // undefined, which led use_count to return 1 instead of 0 for these
816 // cases.
TEST(IValueTest,UseCountCornerCases)817 TEST(IValueTest, UseCountCornerCases) {
818   at::Storage undefinedStorage;
819   at::Generator undefinedGenerator;
820   at::Tensor undefinedTensor;
821 
822   IValue ivEmptyStorage(undefinedStorage);
823   IValue ivEmptyGenerator(undefinedGenerator);
824   IValue ivEmptyTensor(undefinedTensor);
825 
826   ASSERT_EQ(1, ivEmptyStorage.use_count());
827   ASSERT_EQ(1, ivEmptyGenerator.use_count());
828   ASSERT_EQ(0, ivEmptyTensor.use_count());
829 }
830 
831 // TODO(gmagogsfm): Add type conversion test?
832 
833 using ivalue::TupleElements;
834 
835 namespace {
validateTupleElements(TupleElements & te,c10::ArrayRef<IValue> contents)836 void validateTupleElements(TupleElements& te, c10::ArrayRef<IValue> contents) {
837   EXPECT_EQ(te.empty(), contents.empty());
838   EXPECT_EQ(te.size(), contents.size());
839   for (const auto idx: c10::irange(contents.size())) {
840     EXPECT_IVALUE_EQ(te[idx], contents[idx]);
841     EXPECT_IVALUE_EQ(te.at(idx), contents[idx]);
842     EXPECT_IVALUE_EQ(*(te.begin() + idx), contents[idx]);
843   }
844   if (!contents.empty()) {
845     EXPECT_IVALUE_EQ(te.back(), contents.back());
846   }
847   auto v = std::move(te).vec();
848   EXPECT_EQ(v.size(), contents.size());
849   for (const auto idx: c10::irange(contents.size())) {
850     EXPECT_IVALUE_EQ(v[idx], contents[idx]);
851   }
852 }
853 } // namespace
854 
// TupleElements built with 0, 1, 2 or 3 IValues, or from a vector of all
// sample IValues, expose exactly those elements.
TEST(TupleElementsTest, Basic) {
  TupleElements none;
  validateTupleElements(none, {});

  TupleElements one(1);
  validateTupleElements(one, {1});

  TupleElements two(1, 2);
  validateTupleElements(two, {1, 2});

  TupleElements three(1, 2, 3);
  validateTupleElements(three, {1, 2, 3});

  const auto samples = makeSampleIValues();
  TupleElements fromVector(std::vector<IValue>(samples.begin(), samples.end()));
  validateTupleElements(fromVector, samples);
}
869 
870 namespace {
871 
872 std::array<TupleElements(*)(), 3> factories = {
__anon2940f6690702() 873   []() { return TupleElements();},
__anon2940f6690802() 874   []() { return  TupleElements(1, 2, 3);},
__anon2940f6690902() 875   []() { return TupleElements(std::vector<IValue>({1, 2, 3, "hello"})); }
876 };
877 
878 std::array<std::vector<IValue>, 3> expectedContents = {
879   std::vector<IValue>(),
880   std::vector<IValue>({1, 2, 3}),
881   std::vector<IValue>({1, 2, 3, "hello"}),
882 };
883 
884 }
885 
// setContents() replaces whatever a TupleElements held, whichever factory
// built it and whatever size the new contents have.
TEST(TupleElementsTest, Resize) {
  const std::array<std::vector<IValue>, 3> replacements = {
      std::vector<IValue>(),
      std::vector<IValue>({4, 5, 6}),
      std::vector<IValue>({7, 8, 9, "hello"})};

  for (const auto makeElements : factories) {
    for (const std::vector<IValue>& replacement : replacements) {
      auto te = makeElements();
      std::vector<IValue> owned(replacement);
      te.setContents(std::move(owned));
      validateTupleElements(te, replacement);
    }
  }
}
898 
// Move- and copy-constructing from each factory's output preserves the
// contents.
TEST(TupleElementsTest, CopyAndMoveConstruct) {
  for (const auto i : c10::irange(factories.size())) {
    const auto makeElements = factories[i];

    auto moveSource = makeElements();
    TupleElements moved(std::move(moveSource));
    validateTupleElements(moved, expectedContents[i]);

    auto copySource = makeElements();
    TupleElements copied(copySource);
    validateTupleElements(copied, expectedContents[i]);
  }
}
911 
// Move- and copy-assigning between every pair of factory outputs leaves the
// target holding the source's contents.
TEST(TupleElementsTest, CopyAndMoveAssign) {
  for (const auto fromIdx : c10::irange(factories.size())) {
    const auto makeSource = factories[fromIdx];
    for (const auto makeTarget : factories) {
      auto from = makeSource();
      auto to = makeTarget();
      auto copyFrom = makeSource();
      auto toCopy = makeTarget();

      to = std::move(from);
      validateTupleElements(to, expectedContents[fromIdx]);

      toCopy = copyFrom;
      validateTupleElements(toCopy, expectedContents[fromIdx]);
    }
  }
}
928 
929 } // namespace c10
930