1 #include <ATen/ATen.h>
2 #include <ATen/core/Dict.h>
3 #include <c10/util/intrusive_ptr.h>
4 #include <c10/util/irange.h>
5 #include <gmock/gmock.h>
6 #include <gtest/gtest.h>
7 #include <torch/torch.h>
8
9 // Snippets for checking assembly.
inspectTupleConstruction()10 c10::IValue inspectTupleConstruction() {
11 std::tuple<std::string, std::string> s = std::make_tuple(
12 "abcdefghijklmnopqrstuvwxyz", "ABCDEFGHIJKLMNOPQRSTUVWXYZ");
13 return c10::IValue(s);
14 }
15
16 namespace c10 {
17
// Broad smoke test of IValue. It covers construction from int, double and
// bool lists, tuples, tensors and complex values; copy vs. move semantics,
// observed through use_count() of the shared payload; and conversion back
// to C++ types.
TEST(IValueTest, Basic) {
  c10::List<int64_t> foo({3, 4, 5});
  ASSERT_EQ(foo.use_count(), 1);
  IValue bar{foo};
  ASSERT_EQ(foo.use_count(), 2);
  auto baz = bar;
  ASSERT_EQ(foo.use_count(), 3);
  // Moving an IValue hands over its reference: the count is unchanged and
  // the source becomes None.
  auto foo2 = std::move(bar);
  ASSERT_EQ(foo.use_count(), 3);
  ASSERT_TRUE(foo2.isIntList());
  // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
  ASSERT_TRUE(bar.isNone());
  // Overwriting foo2 with a double releases its reference to the list.
  foo2 = IValue(4.0);
  ASSERT_TRUE(foo2.isDouble());
  ASSERT_EQ(foo2.toDouble(), 4.0);
  ASSERT_EQ(foo.use_count(), 2);
  ASSERT_TRUE(baz.toIntVector() == std::vector<int64_t>({3, 4, 5}));
  ASSERT_TRUE(baz.toDimVector() == at::DimVector({3, 4, 5}));

  // Converting an rvalue IValue moves the list out and leaves the source
  // None without changing the list's reference count.
  auto move_it = std::move(baz).toIntList();
  ASSERT_EQ(foo.use_count(), 2);
  // NOLINTNEXTLINE(bugprone-use-after-move)
  ASSERT_TRUE(baz.isNone());
  IValue i(4);
  ASSERT_TRUE(i.isInt());
  ASSERT_EQ(i.toInt(), 4);
  IValue dlist(c10::List<double>({3.5}));
  ASSERT_TRUE(dlist.isDoubleList());
  ASSERT_TRUE(dlist.toDoubleVector() == std::vector<double>({3.5}));
  std::move(dlist).toDoubleList();
  // NOLINTNEXTLINE(bugprone-use-after-move)
  ASSERT_TRUE(dlist.isNone());
  dlist = IValue(c10::List<double>({3.4}));
  ASSERT_TRUE(dlist.toDoubleVector() == std::vector<double>({3.4}));
  dlist = IValue(std::vector<double>({3.3, 3.2}));
  ASSERT_TRUE(dlist.toDoubleVector() == std::vector<double>({3.3, 3.2}));
  IValue blist(std::vector<bool>{true, false});
  ASSERT_TRUE(blist.isList());
  const auto blistRef = blist.toListRef();
  ASSERT_EQ(blistRef.size(), 2);
  ASSERT_TRUE(blistRef[0].toBool());
  ASSERT_FALSE(blistRef[1].toBool());
  IValue the_list(
      at::ivalue::Tuple::create({IValue(3.4), IValue(4), IValue(foo)}));
  // The tuple holds one more reference to foo.
  ASSERT_EQ(foo.use_count(), 3);
  ASSERT_TRUE(the_list.isTuple());
  // Read the element through toTuple() first...
  auto first = the_list.toTuple()->elements()[1];
  ASSERT_EQ(first.toInt(), 4);
  // Make sure toTupleRef has test coverage too.
  first = the_list.toTupleRef().elements()[1];
  ASSERT_EQ(first.toInt(), 4);
  at::Tensor tv = at::rand({3, 4});
  IValue ten(tv);
  ASSERT_EQ(tv.use_count(), 2);
  auto ten2 = ten;
  ASSERT_EQ(tv.use_count(), 3);
  ASSERT_TRUE(ten2.toTensor().equal(ten.toTensor()));
  // Moving the tensor out of ten2 drops ten2's reference.
  std::move(ten2).toTensor();
  ASSERT_EQ(tv.use_count(), 2);

  // The same copy/move checks for a list of complex numbers.
  auto elem1 = c10::complex<double>(3, 4);
  auto elem2 = c10::complex<double>(3, -4);
  auto elem3 = c10::complex<double>(5, 0);
  c10::List<c10::complex<double>> foo1({elem1, elem2, elem3});
  ASSERT_EQ(foo1.use_count(), 1);
  IValue bar1{foo1};
  ASSERT_EQ(foo1.use_count(), 2);
  auto baz1 = bar1;
  ASSERT_EQ(foo1.use_count(), 3);
  auto foo12 = std::move(bar1);
  ASSERT_EQ(foo1.use_count(), 3);
  ASSERT_TRUE(foo12.isComplexDoubleList());
  ASSERT_EQ(foo12.toComplexDoubleList(), foo1);

  // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
  ASSERT_TRUE(bar1.isNone());
  auto foo3 = IValue(c10::complex<double>(3, 4));
  ASSERT_TRUE(foo3.isComplexDouble());
  ASSERT_EQ(foo3.toComplexDouble(), c10::complex<double>(3,4));

  ASSERT_TRUE(baz1.toComplexDoubleVector() == std::vector<c10::complex<double>>({elem1, elem2, elem3}));
  IValue complex_tuple(
      at::ivalue::Tuple::create({IValue(c10::complex<double>(3.4, 4.7)), IValue(foo1)}));
  ASSERT_TRUE(complex_tuple.isTuple());
  ASSERT_EQ(complex_tuple.toTupleRef().elements()[0].toComplexDouble(), c10::complex<double>(3.4, 4.7));
  ASSERT_EQ(complex_tuple.toTupleRef().elements()[1], foo1);
}
105
// An IValue built from a Storage, whether default-constructed or taken from
// a tensor, reports isStorage() and returns the very same StorageImpl.
TEST(IValueTest, BasicStorage) {
  const at::Storage storages[] = {
      at::Storage(), at::Storage(at::rand({3, 4}).storage())};
  for (const auto& storage : storages) {
    IValue wrapped(storage);
    ASSERT_TRUE(wrapped.isStorage());
    ASSERT_EQ(
        storage.unsafeGetStorageImpl(),
        wrapped.toStorage().unsafeGetStorageImpl());
  }
}
117
// A Dict keyed and valued by complex<double> keeps its entries intact after a
// round trip through IValue and toGenericDict().
TEST(IValueTest, ComplexDict) {
  using c_type = c10::complex<double>;
  const c_type num1(2.3, -3.5);
  const c_type num2(0, 5);
  c10::Dict<c_type, c_type> m;
  for (const c_type& key : {num1, num2}) {
    m.insert(key, 2 * key);
  }
  IValue dict(std::move(m));
  auto generic = dict.toGenericDict();
  ASSERT_EQ(generic.at(num1), 2 * num1);
  ASSERT_EQ(generic.at(num2), 2 * num2);
}
130
131 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
makeSampleIValues()132 static std::array<IValue, 16> makeSampleIValues() {
133 return {
134 IValue(),
135 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
136 at::rand({3, 4}),
137 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
138 at::rand({3, 4}).storage(),
139 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
140 1.5,
141 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
142 c10::complex<double>(2.5, -0.5),
143 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
144 42,
145 true,
146 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
147 std::make_tuple(23, "hello"),
148 "hello",
149 c10::make_intrusive<caffe2::Blob>(),
150 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
151 c10::List<int64_t>({1, 2, 3}),
152 c10::Dict<std::string, std::string>(),
153 c10::make_intrusive<ivalue::Future>(FloatType::get()),
154 c10::Device(c10::DeviceType::CPU, 0),
155 c10::Stream(c10::Stream::DEFAULT, c10::Device(c10::DeviceType::CPU, 0)),
156 c10::make_intrusive<ivalue::Object>(c10::StrongTypePtr(nullptr, ClassType::create("class1", {})), 1),
157 };
158 }
159
160 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
makeMoreSampleIValues()161 static std::array<IValue, 16> makeMoreSampleIValues() {
162 return {
163 IValue(),
164 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
165 at::rand({3, 4}),
166 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
167 at::rand({3, 4}).storage(),
168 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
169 2.5,
170 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
171 c10::complex<double>(2.7, -0.3),
172 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
173 43,
174 false,
175 std::make_tuple(1, "goodbye"),
176 "goodbye",
177 c10::make_intrusive<caffe2::Blob>(),
178 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
179 c10::List<int64_t>({4, 5, 6}),
180 c10::Dict<std::string, std::string>(),
181 c10::make_intrusive<ivalue::Future>(IntType::get()),
182 c10::Device(c10::DeviceType::CUDA, 2),
183 c10::Stream(c10::Stream::DEFAULT, c10::Device(c10::DeviceType::CUDA, 1)),
184 c10::make_intrusive<ivalue::Object>(c10::StrongTypePtr(nullptr, ClassType::create("class2", {})), 2),
185 };}
186
// Expects two IValues to be equal. Tensors are compared with
// at::Tensor::equal() rather than IValue::operator==, because operator== on
// multi-element tensors is an ambiguous elementwise comparison and throws
// (see the TensorEquality test). The expansion is wrapped in
// do { ... } while (0) so the macro behaves as a single statement, e.g. as
// the body of an unbraced if/else. Each argument may be evaluated more than
// once, so arguments should be free of side effects.
#define EXPECT_IVALUE_EQ(a, b)                           \
  do {                                                   \
    EXPECT_EQ((a).isTensor(), (b).isTensor());           \
    if ((a).isTensor()) {                                \
      EXPECT_TRUE((a).toTensor().equal((b).toTensor())); \
    } else {                                             \
      EXPECT_EQ((a), (b));                               \
    }                                                    \
  } while (0)
195
// IValue::swap() must exchange contents for every combination of payload
// kinds. swap() distinguishes tensors, intrusive_ptr payloads and everything
// else, so every pair drawn from the two sample arrays is exercised.
TEST(IValueTest, Swap) {
  const auto lhsSamples = makeSampleIValues();
  const auto rhsSamples = makeMoreSampleIValues();
  for (const IValue& lhsOriginal : lhsSamples) {
    for (const IValue& rhsOriginal : rhsSamples) {
      IValue left = lhsOriginal;
      IValue right = rhsOriginal;
      EXPECT_IVALUE_EQ(left, lhsOriginal);
      EXPECT_IVALUE_EQ(right, rhsOriginal);
      left.swap(right);
      EXPECT_IVALUE_EQ(left, rhsOriginal);
      EXPECT_IVALUE_EQ(right, lhsOriginal);
    }
  }
}
214
// Copy-constructing an IValue of any sample kind yields a value equal to its
// source.
TEST(IValueTest, CopyConstruct) {
  const auto samples = makeSampleIValues();
  for (const auto& original : samples) {
    const IValue duplicate = original;
    EXPECT_IVALUE_EQ(duplicate, original);
  }
}
222
// Move-constructing an IValue transfers the payload to the new value and
// leaves the source as None.
TEST(IValueTest, MoveConstruct) {
  const auto samples = makeSampleIValues();
  for (const auto& expected : samples) {
    IValue donor = expected;
    IValue receiver = std::move(donor);
    EXPECT_IVALUE_EQ(receiver, expected);
    // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
    EXPECT_TRUE(donor.isNone());
  }
}
233
// Copy-assigning over an existing value of any kind replaces the target's
// contents with the source's and leaves the source unchanged.
TEST(IValueTest, CopyAssign) {
  const auto sources = makeSampleIValues();
  const auto destinations = makeMoreSampleIValues();

  for (const auto& source : sources) {
    for (const auto& destination : destinations) {
      IValue assignee(destination);
      IValue assigner(source);
      assignee = assigner;
      EXPECT_IVALUE_EQ(assignee, source);
      EXPECT_IVALUE_EQ(assigner, source);
      EXPECT_IVALUE_EQ(assignee, assigner);
    }
  }
}
249
// Move-assigning over an existing value of any kind gives the target the
// source's contents and leaves the source as None.
TEST(IValueTest, MoveAssign) {
  const auto sources = makeSampleIValues();
  const auto destinations = makeMoreSampleIValues();

  for (const auto& source : sources) {
    for (const auto& destination : destinations) {
      IValue assignee(destination);
      IValue donor(source);
      assignee = std::move(donor);
      EXPECT_IVALUE_EQ(assignee, source);
      // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
      EXPECT_TRUE(donor.isNone());
    }
  }
}
265
// A std::tuple<int64_t, Tensor> round-trips through IValue::to<>() with both
// elements preserved.
TEST(IValueTest, Tuple) {
  using IntTensorPair = std::tuple<int64_t, at::Tensor>;
  const IntTensorPair original = std::make_tuple(123, at::randn({1}));
  const IValue boxed(original);
  const auto roundTripped = boxed.to<IntTensorPair>();
  ASSERT_EQ(std::get<0>(roundTripped), 123);
  const float expected = std::get<1>(original).item().to<float>();
  ASSERT_EQ(std::get<1>(roundTripped).item().to<float>(), expected);
}
274
// Object::unsafeRemoveAttr() removes the slot from the object only; the
// attribute stays declared on the class type.
TEST(IValueTest, unsafeRemoveAttr) {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu);
  for (const char* attrName : {"attr1", "attr2"}) {
    cls->addAttribute(attrName, TensorType::get());
  }
  auto obj = c10::ivalue::Object::create(
      c10::StrongTypePtr(cu, cls), cls->numAttributes());
  obj->unsafeRemoveAttr("attr1");
  // attr1 is not removed in the type...
  ASSERT_TRUE(cls->hasAttribute("attr1"));
  ASSERT_TRUE(cls->hasAttribute("attr2"));
  // ...but the object is left with a single slot.
  ASSERT_EQ(obj->slots().size(), 1);
}
288
// Streaming tuple IValues: a one-element tuple is printed with a trailing
// comma, and elements are separated by ", ".
TEST(IValueTest, TuplePrint) {
  const auto render = [](const IValue& value) {
    std::stringstream out;
    out << value;
    return out.str();
  };
  ASSERT_EQ(render(std::make_tuple(3)), "(3,)");
  ASSERT_EQ(render(std::make_tuple(3, 3)), "(3, 3)");
}
305
// Streaming complex IValues prints "<real><sign><imag>j", with both parts
// printed with a trailing '.', including zero parts.
TEST(IValueTest, ComplexIValuePrint) {
  const auto render = [](c10::complex<double> z) {
    std::stringstream out;
    out << IValue(z);
    return out.str();
  };
  ASSERT_EQ(render(c10::complex<double>(2, -3)), "2.-3.j");
  ASSERT_EQ(render(c10::complex<double>(2, 0)), "2.+0.j");
  ASSERT_EQ(render(c10::complex<double>(0, 3)), "0.+3.j");
}
328
// Complex IValues: built from complex<double> directly or from a complex
// Scalar, compared with ==, and converted back to a Scalar.
TEST(IValueTest, Complex) {
  const c10::complex<double> value(2, 3);
  const c10::complex<double> conjugate(2, -3);
  const IValue direct(value);
  const IValue other(conjugate);
  const IValue fromScalar{at::Scalar(value)};

  ASSERT_TRUE(direct.isComplexDouble());
  ASSERT_TRUE(fromScalar.isComplexDouble());

  ASSERT_EQ(value, direct.toComplexDouble());
  ASSERT_FALSE(direct == other);
  ASSERT_TRUE(direct == fromScalar);

  ASSERT_TRUE(direct.isScalar());
  ASSERT_TRUE(other.toScalar().equal(conjugate));
}
344
// A Future becomes completed once markCompleted() is called, and its value
// is readable both directly and through an IValue wrapping the future.
TEST(IValueTest, BasicFuture) {
  auto future = c10::make_intrusive<ivalue::Future>(IntType::get());
  ASSERT_FALSE(future->completed());

  future->markCompleted(IValue(42));
  ASSERT_TRUE(future->completed());
  ASSERT_EQ(42, future->value().toInt());

  const IValue wrapped(future);
  ASSERT_EQ(42, wrapped.toFuture()->value().toInt());
}
355
// A callback added before completion runs when markCompleted() is called;
// a callback added after completion runs right away during addCallback().
TEST(IValueTest, FutureCallbacks) {
  auto future = c10::make_intrusive<ivalue::Future>(IntType::get());
  int callsA = 0;
  int callsB = 0;

  // Registered while the future is still pending.
  future->addCallback([&callsA](ivalue::Future& done) {
    ASSERT_TRUE(done.completed());
    ASSERT_EQ(done.value().toInt(), 43);
    ++callsA;
  });
  future->markCompleted(IValue(43));
  ASSERT_EQ(callsA, 1);
  ASSERT_EQ(callsB, 0);

  // Post-markCompleted()
  future->addCallback([&callsB](ivalue::Future& done) {
    ASSERT_TRUE(done.completed());
    ASSERT_EQ(done.value().toInt(), 43);
    ++callsB;
  });
  ASSERT_EQ(callsA, 1);
  ASSERT_EQ(callsB, 1);
  ASSERT_FALSE(future->hasError());
}
378
// setError() completes the future with an exception. A registered callback
// sees the future as completed, value() throws the stored error, and
// tryRetrieveErrorMessage() returns its message.
TEST(IValueTest, FutureExceptions) {
  auto future = c10::make_intrusive<ivalue::Future>(IntType::get());
  int matchingErrors = 0;
  future->addCallback([&matchingErrors](ivalue::Future& done) {
    ASSERT_TRUE(done.completed());
    try {
      (void)done.value();
    } catch (const std::exception& e) {
      // Only count the error injected by this test.
      if (std::string(e.what()) == "My Error") {
        ++matchingErrors;
      }
    }
  });
  const ivalue::Future::FutureError injected("My Error");
  future->setError(std::make_exception_ptr(injected));
  ASSERT_EQ(matchingErrors, 1);
  ASSERT_TRUE(future->hasError());
  ASSERT_EQ(future->tryRetrieveErrorMessage(), std::string("My Error"));
}
398
// Calling setError() a second time throws. The exception message reports
// that an error was already set and names both the old and the new error.
TEST(IValueTest, FutureSetError) {
  auto future = c10::make_intrusive<ivalue::Future>(IntType::get());
  future->setError(std::make_exception_ptr(std::runtime_error("foo")));
  try {
    future->setError(std::make_exception_ptr(std::runtime_error("bar")));
    FAIL() << "Expected to throw";
  } catch (const std::exception& e) {
    const char* message = e.what();
    EXPECT_THAT(message, ::testing::HasSubstr("Error already set"));
    EXPECT_THAT(message, ::testing::HasSubstr("foo"));
    EXPECT_THAT(message, ::testing::HasSubstr("bar"));
  }
}
411
// operator== compares strings and ints by value, and a string never equals
// an int. equals() returns the comparison result as a bool IValue.
TEST(IValueTest, ValueEquality) {
  const IValue asdf("asdf");
  EXPECT_EQ(asdf, IValue("asdf"));
  EXPECT_NE(asdf, IValue("ASDF"));
  EXPECT_NE(IValue("2"), IValue(2));
  EXPECT_EQ(IValue(1), IValue(1));

  // Check the equals() variant that returns an IValue
  IValue verdict = IValue("asdf").equals("asdf");
  EXPECT_TRUE(verdict.isBool());
  EXPECT_TRUE(verdict.toBool());

  verdict = IValue("asdf").equals(1);
  EXPECT_TRUE(verdict.isBool());
  EXPECT_FALSE(verdict.toBool());
}
427
// Tensor IValues: operator== throws for multi-element tensors, equals()
// returns an elementwise bool tensor, and is() tests identity rather than
// contents.
TEST(IValueTest, TensorEquality) {
  auto zeros = torch::zeros({2, 3});
  const IValue original(zeros);
  const IValue cloned(zeros.clone());

  // This should throw, because elementwise equality is ambiguous for
  // multi-element Tensors.
  auto compareDistinctTensors = []() {
    return IValue(torch::ones({2, 3})) == IValue(torch::rand({2, 3}));
  };
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_ANY_THROW(compareDistinctTensors());

  // equals() should return a tensor of all `true`.
  const IValue elementwise = original.equals(cloned);
  EXPECT_TRUE(elementwise.isTensor());
  const auto allTrue = torch::ones({2, 3}).to(torch::kBool);
  EXPECT_TRUE(elementwise.toTensor().equal(allTrue));

  // Identity: a clone with equal contents is not the same value, a copy of
  // the IValue is.
  EXPECT_TRUE(original.is(original));
  EXPECT_FALSE(original.is(cloned));
  // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
  const IValue alias = original;
  EXPECT_TRUE(original.is(alias));
}
455
// Int lists compare by contents: two separately built lists with the same
// elements are equal without being the same object, and a longer list
// differs.
TEST(IValueTest, ListEquality) {
  // Builds an IValue holding the int list [0, 1, ..., n - 1].
  const auto countingList = [](int64_t n) {
    std::vector<int64_t> values;
    for (const auto v : c10::irange(n)) {
      values.push_back(v);
    }
    return IValue(std::move(values));
  };
  const IValue c1 = countingList(4);
  const IValue c2 = countingList(4);
  const IValue c3 = countingList(5);
  EXPECT_EQ(c1, c1);
  EXPECT_EQ(c1, c2);
  EXPECT_FALSE(c1.is(c2));
  EXPECT_NE(c1, c3);
  EXPECT_NE(c2, c3);
}
466
// Nested dicts compare by contents. Outer dicts with the same keys mapping to
// equal inner dicts are equal, whether or not the inner dicts are shared.
// An extra key, or a differing inner dict, makes them unequal.
TEST(IValueTest, DictEquality) {
  using InnerDict = c10::Dict<std::string, std::string>;
  using OuterDict = c10::Dict<std::string, InnerDict>;

  auto innerDict = InnerDict();
  innerDict.insert("foo", "bar");

  // c1: every key maps to innerDict itself.
  auto d1 = OuterDict();
  for (const char* key : {"one", "two", "three"}) {
    d1.insert(key, innerDict);
  }
  const auto c1 = IValue(d1);

  // c2: same keys, each mapped to a fresh copy of innerDict.
  auto d2 = OuterDict();
  for (const char* key : {"one", "two", "three"}) {
    d2.insert(key, innerDict.copy());
  }
  const auto c2 = IValue(d2);

  // c3: like c2 plus one extra key.
  auto d3 = OuterDict();
  for (const char* key : {"one", "two", "three", "four"}) {
    d3.insert(key, innerDict.copy());
  }
  const auto c3 = IValue(d3);

  // c4: same keys as c2, but "three" maps to a different inner dict.
  auto d4 = OuterDict();
  for (const char* key : {"one", "two"}) {
    d4.insert(key, innerDict.copy());
  }
  auto innerDictNotEqual = InnerDict();
  innerDictNotEqual.insert("bar", "foo");
  d4.insert("three", innerDictNotEqual);
  const auto c4 = IValue(d4);

  EXPECT_EQ(c1, c1);
  EXPECT_EQ(c1, c2);
  EXPECT_FALSE(c1.is(c2));
  EXPECT_NE(c1, c3);
  EXPECT_NE(c2, c3);
  EXPECT_NE(c1, c4);
  EXPECT_NE(c2, c4);
}
506
// Dict equality does not depend on insertion order.
TEST(IValueTest, DictEqualityDifferentOrder) {
  c10::Dict<std::string, int64_t> forward;
  forward.insert("one", 1);
  forward.insert("two", 2);

  c10::Dict<std::string, int64_t> reversed;
  reversed.insert("two", 2);
  reversed.insert("one", 1);

  EXPECT_EQ(forward, reversed);
}
517
// Lists of int lists compare by contents. Changing a single inner element
// makes them unequal.
TEST(IValueTest, ListNestedEquality) {
  using Nested = std::vector<std::vector<int64_t>>;
  const Nested base({{0}, {0, 1}, {0, 1, 2}});
  Nested altered = base;
  altered.front().front() = 1;

  const IValue c1(base);
  const IValue c2(base);
  const IValue c3(altered);
  EXPECT_EQ(c1, c1);
  EXPECT_EQ(c1, c2);
  EXPECT_NE(c1, c3);
  EXPECT_NE(c2, c3);
}
527
// Stream IValues compare equal for the same default stream and unequal for
// default streams on different CUDA devices.
TEST(IValueTest, StreamEquality) {
  const auto defaultStreamOn = [](c10::DeviceIndex index) {
    return c10::Stream(c10::Stream::Default::DEFAULT, at::Device(kCUDA, index));
  };
  const c10::Stream streamOnDevice0 = defaultStreamOn(0);
  const IValue lhs(streamOnDevice0);
  const IValue rhs_different(defaultStreamOn(1));
  const IValue rhs_same(streamOnDevice0);
  EXPECT_FALSE(lhs.equals(rhs_different).toBool());
  EXPECT_TRUE(lhs.equals(rhs_same).toBool());
}
539
// EnumHolder IValues are equal only when enum type, member name and value
// all match. A different class, a different member, or a different value
// type makes them unequal.
TEST(IValueTest, EnumEquality) {
  auto cu = std::make_shared<CompilationUnit>();
  const IValue int_ivalue_1(1);
  const IValue int_ivalue_2(2);
  const IValue str_ivalue_1("1");
  auto int_enum_type1 = EnumType::create(
      "enum_class_1",
      IntType::get(),
      {{"enum_name_1", int_ivalue_1}, {"enum_name_2", int_ivalue_2}},
      cu);
  auto int_enum_type2 = EnumType::create(
      "enum_class_2",
      IntType::get(),
      {{"enum_name_1", int_ivalue_1}, {"enum_name_2", int_ivalue_2}},
      cu);
  auto string_enum_type = EnumType::create(
      "enum_class_3", StringType::get(), {{"enum_name_1", str_ivalue_1}}, cu);

  // Wraps a fresh EnumHolder for (type, name, value) in an IValue.
  const auto makeEnum = [](const auto& type, const char* name, const IValue& value) {
    return IValue(c10::make_intrusive<ivalue::EnumHolder>(type, name, value));
  };

  // Identical type, name and value.
  EXPECT_EQ(
      makeEnum(int_enum_type1, "enum_name_1", int_ivalue_1),
      makeEnum(int_enum_type1, "enum_name_1", int_ivalue_1));

  // Same member name and value, different enum class.
  EXPECT_NE(
      makeEnum(int_enum_type1, "enum_name_1", int_ivalue_1),
      makeEnum(int_enum_type2, "enum_name_1", int_ivalue_1));

  // Same enum class, different member.
  EXPECT_NE(
      makeEnum(int_enum_type1, "enum_name_1", int_ivalue_1),
      makeEnum(int_enum_type1, "enum_name_2", int_ivalue_2));

  // Int-valued enum vs. string-valued enum.
  EXPECT_NE(
      makeEnum(int_enum_type1, "enum_name_1", int_ivalue_1),
      makeEnum(string_enum_type, "enum_name_1", str_ivalue_1));
}
586
// isPtrType() is true for a defined tensor and for a string, and false for
// an undefined tensor and for an int.
TEST(IValueTest, isPtrType) {
  const std::pair<IValue, bool> cases[] = {
      {IValue(at::rand({3, 4})), true},
      {IValue(at::Tensor()), false},
      {IValue(42), false},
      {IValue("hello"), true},
  };
  for (const auto& [value, expectPtr] : cases) {
    EXPECT_EQ(value.isPtrType(), expectPtr);
  }
}
598
// Among the sample values, an IValue aliases only itself, and only when it
// is a pointer type.
TEST(IValueTest, isAliasOf) {
  auto samples = makeSampleIValues();
  for (const auto i : c10::irange(samples.size())) {
    for (const auto j : c10::irange(samples.size())) {
      const bool expectAlias = (i == j) && samples[i].isPtrType();
      EXPECT_EQ(samples[i].isAliasOf(samples[j]), expectAlias);
    }
  }
}
611
// internalToPointer() returns the TensorImpl for a tensor, a non-null
// pointer for a string, and nullptr for a string IValue holding a null
// ConstantString.
TEST(IValueTest, internalToPointer) {
  const IValue tensor(at::rand({3, 4}));
  EXPECT_EQ(tensor.internalToPointer(), tensor.unsafeToTensorImpl());

  const IValue str("hello");
  EXPECT_NE(str.internalToPointer(), nullptr);

  const IValue nullStr((c10::intrusive_ptr<ivalue::ConstantString>()));
  ASSERT_TRUE(nullStr.isString());
  EXPECT_EQ(nullStr.internalToPointer(), nullptr);
}
623
// Tensor IValues hash and compare by identity (is()). None and an undefined
// tensor are is()-equal but hash differently. For each hashable sample kind,
// it also checks how hashes relate across two independently built sample
// arrays, and against the different values from makeMoreSampleIValues().
TEST(IValueTest, IdentityComparisonAndHashing) {
  at::Tensor t1 = at::rand({3, 4});
  at::Tensor t2 = at::rand({3, 4});
  IValue tv1(t1), tv2(t2);
  IValue tv1b(t1);

  // Two IValues over the same tensor share a hash; different tensors do not.
  EXPECT_EQ(tv1.hash(), tv1b.hash());
  EXPECT_NE(tv1.hash(), tv2.hash());

  EXPECT_TRUE(tv1.is(tv1));
  EXPECT_TRUE(tv1.is(tv1b));
  EXPECT_TRUE(tv1b.is(tv1));
  EXPECT_TRUE(tv2.is(tv2));

  EXPECT_FALSE(tv1.is(tv2));
  EXPECT_FALSE(tv2.is(tv1));

  IValue none;
  IValue undefinedTensor((at::Tensor()));

  EXPECT_TRUE(none.is(undefinedTensor));
  EXPECT_TRUE(undefinedTensor.is(none));

  // Is this a bug? We should probably have a is b => a.hash() == b.hash()
  EXPECT_NE(none.hash(), undefinedTensor.hash());

  auto sampleIValues = makeSampleIValues();
  auto sampleIValues2 = makeSampleIValues();
  auto moreSampleIValues = makeMoreSampleIValues();

  ASSERT_EQ(sampleIValues.size(), moreSampleIValues.size());
  for (const auto ii : c10::irange(sampleIValues.size())) {
    const IValue& sample = sampleIValues[ii];
    const IValue& rebuilt = sampleIValues2[ii];
    const IValue& different = moreSampleIValues[ii];

    const bool unhashable = sample.isComplexDouble() || sample.isBlob() ||
        sample.isList() || sample.isFuture() || sample.isStream() ||
        sample.isObject() || sample.isGenericDict();
    if (unhashable) {
      continue;
    }
    // Tuples may or may not have the same hash across instantiations.
    if (!sample.isTuple()) {
      // Constant strings will have the same pointer value.
      const bool expectDistinct = sample.isPtrType() && !sample.isString();
      if (expectDistinct) {
        EXPECT_NE(sample.hash(), rebuilt.hash()) << " at index " << ii;
      } else {
        EXPECT_EQ(sample.hash(), rebuilt.hash()) << " at index " << ii;
      }
    }
    if (!(sample.isNone() || different.isNone())) {
      EXPECT_NE(sample.hash(), different.hash()) << " at index " << ii;
    }
  }
}
683
684 // Sparse tensors do not work with static CPU dispatch
685 #ifndef ATEN_CPU_STATIC_DISPATCH
TEST(IValueTest,IdentityAndHashing_SparseCOO)686 TEST(IValueTest, IdentityAndHashing_SparseCOO) {
687 using namespace torch::indexing;
688
689 at::Tensor t1 = at::rand({3, 4}).to_sparse();
690 at::Tensor t2 = at::rand({3, 4}).to_sparse();
691 at::Tensor t3 = at::rand({3, 4});
692
693 IValue tv1(t1), tv1b(t1), tv2(t2), tv3(t3);
694
695 EXPECT_EQ(tv1.hash(), tv1b.hash());
696 EXPECT_NE(tv1.hash(), tv2.hash());
697
698 EXPECT_TRUE(tv1.is(tv1b));
699 EXPECT_FALSE(tv1.is(tv2));
700
701 EXPECT_TRUE(tv1.isAliasOf(tv1b));
702 EXPECT_FALSE(tv1.isAliasOf(tv2));
703 EXPECT_FALSE(tv1.isAliasOf(tv3));
704
705 std::vector<int64_t> idx_array1 = {0, 1, 1, 0, 0, 1};
706 at::Tensor idx1 = torch::from_blob(
707 idx_array1.data(),
708 {2, 3},
709 torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU));
710 std::vector<int64_t> idx_array2 = {1, 1, 2, 0, 1, 2};
711 at::Tensor idx2 = torch::from_blob(
712 idx_array2.data(),
713 {2, 3},
714 torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU));
715 std::vector<int32_t> val_array = {3, -5, 7};
716 at::Tensor val = torch::from_blob(
717 val_array.data(),
718 {3},
719 torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU));
720 at::Tensor sparse1 = torch::sparse_coo_tensor(
721 idx1, val, {3, 3}, torch::TensorOptions().dtype(torch::kInt32));
722 at::Tensor sparse2 = torch::sparse_coo_tensor(
723 idx2, val, {3, 3}, torch::TensorOptions().dtype(torch::kInt32));
724
725 IValue idx1_v(idx1), idx2_v(idx2);
726 IValue val_v(val);
727 IValue sparse1_v(sparse1), sparse2_v(sparse2);
728
729 EXPECT_TRUE(sparse1_v.isAliasOf(sparse2_v));
730 EXPECT_TRUE(sparse1_v.isAliasOf(idx1_v));
731 EXPECT_TRUE(sparse1_v.isAliasOf(val_v));
732 EXPECT_TRUE(sparse2_v.isAliasOf(idx2_v));
733 EXPECT_TRUE(sparse2_v.isAliasOf(val_v));
734 EXPECT_FALSE(idx1_v.isAliasOf(idx2_v));
735 EXPECT_FALSE(idx1_v.isAliasOf(val_v));
736 EXPECT_FALSE(sparse1_v.isAliasOf(idx2_v));
737 }
738 #endif // ATEN_CPU_STATIC_DISPATCH
739
// getSubValues() collects nothing for scalars and the tensor itself for a
// tensor. For a list, tuple, dict or object holding two tensors, it collects
// the container plus both tensors.
TEST(IValueTest, getSubValues) {
  IValue::HashAliasedIValues subvalues;

  // Scalars have no subvalues.
  for (const IValue& scalar :
       {IValue(42), IValue(1.5), IValue(c10::complex<double>(2, 3))}) {
    scalar.getSubValues(subvalues);
    EXPECT_TRUE(subvalues.empty());
    subvalues.clear();
  }

  at::Tensor t1(at::rand({3, 4})), t2(at::rand({3, 4}));
  IValue tv1(t1), tv2(t2);

  IValue list(std::vector<at::Tensor>{t1, t2});
  IValue tuple(ivalue::Tuple::create({tv1, tv2}));

  c10::Dict<int64_t, at::Tensor> tensorsByKey;
  tensorsByKey.insert(1, t1);
  tensorsByKey.insert(2, t2);
  IValue dict(std::move(tensorsByKey));

  auto objType = ClassType::create(std::nullopt, {});
  objType->addAttribute("t1", tv1.type());
  objType->addAttribute("t2", tv2.type());
  auto obj = ivalue::Object::create(StrongTypePtr(nullptr, objType), 2);
  obj->setSlot(0, tv1);
  obj->setSlot(1, tv2);
  IValue object(obj);

  // A bare tensor contributes only itself.
  tv1.getSubValues(subvalues);
  EXPECT_EQ(subvalues.size(), 1);
  EXPECT_EQ(subvalues.count(tv1), 1);
  subvalues.clear();

  for (auto& container : {list, tuple, dict, object}) {
    container.getSubValues(subvalues);
    EXPECT_EQ(subvalues.size(), 3);
    EXPECT_EQ(subvalues.count(container), 1);
    EXPECT_EQ(subvalues.count(tv1), 1);
    EXPECT_EQ(subvalues.count(tv2), 1);
    subvalues.clear();
  }
}
797
// A boolean Scalar stored in an IValue converts back to a boolean Scalar
// holding true.
TEST(IValueTest, ScalarBool) {
  const IValue boxed(Scalar(true));
  const Scalar unboxed = boxed.toScalar();
  EXPECT_TRUE(unboxed.isBoolean());
  EXPECT_TRUE(unboxed.toBool());
}
805
// Converting each sample IValue to a WeakIValue and locking it again yields
// an equal IValue (the strong original is still alive).
TEST(IValueTest, ToWeakAndBack) {
  const auto samples = makeSampleIValues();
  for (const IValue& strong : samples) {
    WeakIValue weak(strong);
    EXPECT_IVALUE_EQ(strong, weak.lock());
  }
}
813
814 // Storage and Generator did not set is_intrusive_ptr if they were
815 // undefined, which led use_count to return 1 instead of 0 for these
816 // cases.
TEST(IValueTest,UseCountCornerCases)817 TEST(IValueTest, UseCountCornerCases) {
818 at::Storage undefinedStorage;
819 at::Generator undefinedGenerator;
820 at::Tensor undefinedTensor;
821
822 IValue ivEmptyStorage(undefinedStorage);
823 IValue ivEmptyGenerator(undefinedGenerator);
824 IValue ivEmptyTensor(undefinedTensor);
825
826 ASSERT_EQ(1, ivEmptyStorage.use_count());
827 ASSERT_EQ(1, ivEmptyGenerator.use_count());
828 ASSERT_EQ(0, ivEmptyTensor.use_count());
829 }
830
831 // TODO(gmagogsfm): Add type conversion test?
832
833 using ivalue::TupleElements;
834
835 namespace {
validateTupleElements(TupleElements & te,c10::ArrayRef<IValue> contents)836 void validateTupleElements(TupleElements& te, c10::ArrayRef<IValue> contents) {
837 EXPECT_EQ(te.empty(), contents.empty());
838 EXPECT_EQ(te.size(), contents.size());
839 for (const auto idx: c10::irange(contents.size())) {
840 EXPECT_IVALUE_EQ(te[idx], contents[idx]);
841 EXPECT_IVALUE_EQ(te.at(idx), contents[idx]);
842 EXPECT_IVALUE_EQ(*(te.begin() + idx), contents[idx]);
843 }
844 if (!contents.empty()) {
845 EXPECT_IVALUE_EQ(te.back(), contents.back());
846 }
847 auto v = std::move(te).vec();
848 EXPECT_EQ(v.size(), contents.size());
849 for (const auto idx: c10::irange(contents.size())) {
850 EXPECT_IVALUE_EQ(v[idx], contents[idx]);
851 }
852 }
853 } // namespace
854
// TupleElements built from no argument, from one to three IValue arguments,
// or from a vector of all the sample values exposes exactly those elements.
TEST(TupleElementsTest, Basic) {
  TupleElements none;
  validateTupleElements(none, {});
  TupleElements one(1);
  validateTupleElements(one, {1});
  TupleElements two(1, 2);
  validateTupleElements(two, {1, 2});
  TupleElements three(1, 2, 3);
  validateTupleElements(three, {1, 2, 3});

  const auto samples = makeSampleIValues();
  TupleElements many(std::vector<IValue>(samples.begin(), samples.end()));
  validateTupleElements(many, samples);
}
869
870 namespace {
871
872 std::array<TupleElements(*)(), 3> factories = {
__anon2940f6690702() 873 []() { return TupleElements();},
__anon2940f6690802() 874 []() { return TupleElements(1, 2, 3);},
__anon2940f6690902() 875 []() { return TupleElements(std::vector<IValue>({1, 2, 3, "hello"})); }
876 };
877
878 std::array<std::vector<IValue>, 3> expectedContents = {
879 std::vector<IValue>(),
880 std::vector<IValue>({1, 2, 3}),
881 std::vector<IValue>({1, 2, 3, "hello"}),
882 };
883
884 }
885
// setContents() replaces whatever a TupleElements held (0, 3 or 4 elements)
// with new contents of 0, 3 or 4 elements.
TEST(TupleElementsTest, Resize) {
  const std::array<std::vector<IValue>, 3> replacements = {
      std::vector<IValue>(),
      std::vector<IValue>({4, 5, 6}),
      std::vector<IValue>({7, 8, 9, "hello"})};

  for (const auto& makeElements : factories) {
    for (const std::vector<IValue>& replacement : replacements) {
      TupleElements elements = makeElements();
      std::vector<IValue> scratch(replacement);
      elements.setContents(std::move(scratch));
      validateTupleElements(elements, replacement);
    }
  }
}
898
// Move- and copy-constructing a TupleElements of each size preserves its
// contents.
TEST(TupleElementsTest, CopyAndMoveConstruct) {
  for (const auto i : c10::irange(factories.size())) {
    const auto makeElements = factories[i];
    auto moveSource = makeElements();
    TupleElements moved(std::move(moveSource));
    validateTupleElements(moved, expectedContents[i]);
    auto copySource = makeElements();
    TupleElements copied(copySource);
    validateTupleElements(copied, expectedContents[i]);
  }
}
911
// Move- and copy-assigning between TupleElements of every pair of sizes gives
// the target exactly the source's contents.
TEST(TupleElementsTest, CopyAndMoveAssign) {
  for (const auto src : c10::irange(factories.size())) {
    const auto makeSource = factories[src];
    for (const auto& makeTarget : factories) {
      auto moveSource = makeSource();
      auto moveTarget = makeTarget();
      auto copySource = makeSource();
      auto copyTarget = makeTarget();
      moveTarget = std::move(moveSource);
      validateTupleElements(moveTarget, expectedContents[src]);
      copyTarget = copySource;
      validateTupleElements(copyTarget, expectedContents[src]);
    }
  }
}
928
929 } // namespace c10
930