// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/weakref_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/ATen.h>
4 #include <ATen/core/ivalue.h>
5 
6 #include <iostream>
7 #include <chrono>
8 #include <sstream>
9 
10 using at::Tensor;
11 using c10::WeakIValue;
12 using c10::IValue;
13 
// Weak pointer tests
15 // gets invalidated
TEST(TestWeakPointer,WeakPointerGetsInvalidated)16 TEST(TestWeakPointer, WeakPointerGetsInvalidated) {
17   IValue a = at::ones({2, 2});
18   WeakIValue b = a;
19   a = IValue();
20   ASSERT_TRUE(b.lock().isNone());
21 }
22 
23 // can successfully lock
TEST(TestWeakPointer,WeakPointerLock)24 TEST(TestWeakPointer, WeakPointerLock) {
25   IValue a = at::ones({2, 2});
26   WeakIValue b = a;
27   auto c = b.lock();
28   ASSERT_TRUE(c.isTensor());
29 
30   a = IValue();
31   ASSERT_TRUE(!b.lock().isNone());
32   c = IValue();
33   ASSERT_TRUE(b.lock().isNone());
34 }
35 
36 // updates refcounts correctly
TEST(TestWeakPointer,WeakUpdatesRefcountsTest)37 TEST(TestWeakPointer, WeakUpdatesRefcountsTest) {
38   at::Tensor a = at::ones({2, 2});
39   ASSERT_EQ(a.use_count(), 1);
40   ASSERT_EQ(a.weak_use_count(), 1);
41   {
42     WeakIValue b = IValue(a);
43     ASSERT_EQ(a.use_count(), 1);
44     ASSERT_EQ(a.weak_use_count(), 2);
45   }
46   ASSERT_EQ(a.use_count(), 1);
47   ASSERT_EQ(a.weak_use_count(), 1);
48   {
49     WeakIValue b = IValue(a);
50     ASSERT_EQ(a.use_count(), 1);
51     auto locked = b.lock();
52     ASSERT_FALSE(locked.isNone());
53     ASSERT_EQ(a.use_count(), 2);
54   }
55   ASSERT_EQ(a.use_count(), 1);
56   ASSERT_EQ(a.weak_use_count(), 1);
57   {
58     WeakIValue b = IValue(a);
59     ASSERT_EQ(a.use_count(), 1);
60     ASSERT_EQ(a.weak_use_count(), 2);
61     a.reset();
62     ASSERT_EQ(b.use_count(), 0);
63     ASSERT_EQ(b.weak_use_count(), 1);
64   }
65 }
66