1 #include <gtest/gtest.h>
2
3 #include <ATen/ATen.h>
4 #include <ATen/core/ivalue.h>
5
6 #include <iostream>
7 #include <chrono>
8 #include <sstream>
9
10 using at::Tensor;
11 using c10::WeakIValue;
12 using c10::IValue;
13
// Weak pointer tests (c10::WeakIValue)
15 // gets invalidated
TEST(TestWeakPointer,WeakPointerGetsInvalidated)16 TEST(TestWeakPointer, WeakPointerGetsInvalidated) {
17 IValue a = at::ones({2, 2});
18 WeakIValue b = a;
19 a = IValue();
20 ASSERT_TRUE(b.lock().isNone());
21 }
22
23 // can successfully lock
TEST(TestWeakPointer,WeakPointerLock)24 TEST(TestWeakPointer, WeakPointerLock) {
25 IValue a = at::ones({2, 2});
26 WeakIValue b = a;
27 auto c = b.lock();
28 ASSERT_TRUE(c.isTensor());
29
30 a = IValue();
31 ASSERT_TRUE(!b.lock().isNone());
32 c = IValue();
33 ASSERT_TRUE(b.lock().isNone());
34 }
35
36 // updates refcounts correctly
TEST(TestWeakPointer,WeakUpdatesRefcountsTest)37 TEST(TestWeakPointer, WeakUpdatesRefcountsTest) {
38 at::Tensor a = at::ones({2, 2});
39 ASSERT_EQ(a.use_count(), 1);
40 ASSERT_EQ(a.weak_use_count(), 1);
41 {
42 WeakIValue b = IValue(a);
43 ASSERT_EQ(a.use_count(), 1);
44 ASSERT_EQ(a.weak_use_count(), 2);
45 }
46 ASSERT_EQ(a.use_count(), 1);
47 ASSERT_EQ(a.weak_use_count(), 1);
48 {
49 WeakIValue b = IValue(a);
50 ASSERT_EQ(a.use_count(), 1);
51 auto locked = b.lock();
52 ASSERT_FALSE(locked.isNone());
53 ASSERT_EQ(a.use_count(), 2);
54 }
55 ASSERT_EQ(a.use_count(), 1);
56 ASSERT_EQ(a.weak_use_count(), 1);
57 {
58 WeakIValue b = IValue(a);
59 ASSERT_EQ(a.use_count(), 1);
60 ASSERT_EQ(a.weak_use_count(), 2);
61 a.reset();
62 ASSERT_EQ(b.use_count(), 0);
63 ASSERT_EQ(b.weak_use_count(), 1);
64 }
65 }
66