// xref: /aosp_15_r20/external/pytorch/c10/test/util/LeftRight_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/util/LeftRight.h>
#include <gtest/gtest.h>

#include <atomic>
#include <chrono>
#include <exception>
#include <thread>
#include <type_traits>
#include <vector>

using c10::LeftRight;
using std::vector;

TEST(LeftRightTest, givenInt_whenWritingAndReading_thenChangesArePresent) {
  LeftRight<int> obj;
  const auto read_value = [](const int& value) { return value; };

  obj.write([](int& value) { value = 5; });
  int read = obj.read(read_value);
  EXPECT_EQ(5, read);

  // An empty write flips which copy is in the foreground, so reading again
  // verifies that the other copy received the change as well.
  obj.write([](int&) {});
  read = obj.read(read_value);
  EXPECT_EQ(5, read);
}

TEST(LeftRightTest, givenVector_whenWritingAndReading_thenChangesArePresent) {
  LeftRight<vector<int>> obj;
  // Returns a copy of whichever vector is currently in the foreground.
  const auto snapshot = [](const vector<int>& values) { return values; };

  obj.write([](vector<int>& values) { values.push_back(5); });
  vector<int> read = obj.read(snapshot);
  EXPECT_EQ((vector<int>{5}), read);

  // A second write must build on top of the first one.
  obj.write([](vector<int>& values) { values.push_back(6); });
  read = obj.read(snapshot);
  EXPECT_EQ((vector<int>{5, 6}), read);
}

TEST(LeftRightTest, givenVector_whenWritingReturnsValue_thenValueIsReturned) {
  LeftRight<vector<int>> obj;

  // Whatever the write callback returns is handed back by write(), and the
  // result keeps the callback's declared return type.
  auto a = obj.write([](vector<int>&) -> int { return 5; });
  static_assert(std::is_same_v<decltype(a), int>);
  EXPECT_EQ(5, a);
}

TEST(LeftRightTest, readsCanBeConcurrent) {
  LeftRight<int> obj;
  std::atomic<int> readers_inside{0};

  // Each reader registers itself inside its read callback and then stays there
  // until the other reader has registered too.
  const auto reader_body = [&]() {
    obj.read([&](const int&) {
      ++readers_inside;
      while (readers_inside.load() < 2) {
      }
    });
  };

  std::thread first_reader(reader_body);
  std::thread second_reader(reader_body);

  // Both threads can only return once both read callbacks were active at the
  // same time; if LeftRight serialized readers, joining would deadlock.
  first_reader.join();
  second_reader.join();
}

TEST(LeftRightTest, writesCanBeConcurrentWithReads_readThenWrite) {
  LeftRight<int> obj;
  std::atomic<bool> reader_inside{false};
  std::atomic<bool> writer_inside{false};

  // The reader enters its read callback first and remains there until the
  // writer has entered its write callback.
  std::thread reader_thread([&]() {
    obj.read([&](const int&) {
      reader_inside = true;
      while (!writer_inside.load()) {
      }
    });
  });

  // The writer only starts writing once the reader is inside its callback.
  std::thread writer_thread([&]() {
    while (!reader_inside.load()) {
    }
    obj.write([&](int&) { writer_inside = true; });
  });

  // Both threads finish only if the write callback could run while the read
  // callback was still active; otherwise joining would deadlock.
  reader_thread.join();
  writer_thread.join();
}

// Verifies that a reader can enter read() while a write callback is still
// running. The writer enters its write callback first and stays in it until
// the reader has entered its read callback.
TEST(LeftRightTest, writesCanBeConcurrentWithReads_writeThenRead) {
  LeftRight<int> obj;
  std::atomic<bool> writer_running{false};
  std::atomic<bool> reader_running{false};

  std::thread writer([&]() {
    // Must be a write (not a read) for this test to cover write/read
    // concurrency. Each write invokes its callback twice, once per copy. By the
    // second invocation reader_running is already true, so the spin-wait exits
    // immediately.
    obj.write([&](int&) {
      writer_running = true;
      while (!reader_running.load()) {
      }
    });
  });

  std::thread reader([&]() {
    // run write first, read second
    while (!writer_running.load()) {
    }

    obj.read([&](const int&) { reader_running = true; });
  });

  // The threads only finish after both entered their callbacks.
  // If LeftRight blocked reads while a write callback is running, this would
  // cause a deadlock.
  writer.join();
  reader.join();
}

TEST(LeftRightTest, writesCannotBeConcurrentWithWrites) {
  LeftRight<int> obj;
  std::atomic<bool> first_started{false};
  std::atomic<bool> first_writer_finished{false};

  // The first writer flags that it started, lingers for a while so the second
  // writer gets a chance to interfere, and then flags that it finished.
  std::thread first_writer([&]() {
    obj.write([&](int&) {
      first_started = true;
      std::this_thread::sleep_for(std::chrono::milliseconds{50});
      first_writer_finished = true;
    });
  });

  std::thread second_writer([&]() {
    // Hold off until the first writer is inside its callback.
    while (!first_started.load()) {
    }

    obj.write([&](int&) {
      // Writes are serialized, so the first writer must already be done.
      EXPECT_TRUE(first_writer_finished.load());
    });
  });

  first_writer.join();
  second_writer.join();
}

namespace {
// Test-only exception type thrown from read/write callbacks below.
struct MyException : std::exception {};
} // namespace

TEST(LeftRightTest, whenReadThrowsException_thenThrowsThrough) {
  LeftRight<int> obj;
  const auto throwing_reader = [](const int&) { throw MyException(); };

  // An exception raised inside the read callback must reach the caller.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_THROW(obj.read(throwing_reader), MyException);
}

TEST(LeftRightTest, whenWriteThrowsException_thenThrowsThrough) {
  LeftRight<int> obj;
  const auto throwing_writer = [](int&) { throw MyException(); };

  // An exception raised inside the write callback must reach the caller.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_THROW(obj.write(throwing_writer), MyException);
}

TEST(
    LeftRightTest,
    givenInt_whenWriteThrowsExceptionOnFirstCall_thenResetsToOldState) {
  LeftRight<int> obj;
  const auto current = [](const int& value) { return value; };

  obj.write([](int& value) { value = 5; });

  // The callback modifies the value and throws on its very first invocation,
  // so the attempted change must not become visible.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_THROW(
      obj.write([](int& value) {
        value = 6;
        throw MyException();
      }),
      MyException);

  // The old value is still what readers see...
  int read = obj.read(current);
  EXPECT_EQ(5, read);

  // ...and it is also what the other copy holds, which becomes the foreground
  // after this empty write.
  obj.write([](int&) {});
  read = obj.read(current);
  EXPECT_EQ(5, read);
}

192 // note: each write is executed twice, on the foreground and background copy.
193 // We need to test a thrown exception in either call is handled correctly.
TEST(LeftRightTest,givenInt_whenWriteThrowsExceptionOnSecondCall_thenKeepsNewState)194 TEST(
195     LeftRightTest,
196     givenInt_whenWriteThrowsExceptionOnSecondCall_thenKeepsNewState) {
197   LeftRight<int> obj;
198 
199   obj.write([](int& obj) { obj = 5; });
200   bool write_called = false;
201 
202   // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
203   EXPECT_THROW(
204       obj.write([&](int& obj) {
205         obj = 6;
206         if (write_called) {
207           // this is the second time the write callback is executed
208           throw MyException();
209         } else {
210           write_called = true;
211         }
212       }),
213       MyException);
214 
215   // check reading it returns new value
216   int read = obj.read([](const int& obj) { return obj; });
217   EXPECT_EQ(6, read);
218 
219   // check changes are also present in background copy
220   obj.write([](int&) {}); // this switches to the background copy
221   read = obj.read([](const int& obj) { return obj; });
222   EXPECT_EQ(6, read);
223 }
224 
TEST(LeftRightTest, givenVector_whenWriteThrowsException_thenResetsToOldState) {
  LeftRight<vector<int>> obj;
  const auto snapshot = [](const vector<int>& values) { return values; };

  obj.write([](vector<int>& values) { values.push_back(5); });

  // Appends an element and then throws; the append must not survive.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_THROW(
      obj.write([](vector<int>& values) {
        values.push_back(6);
        throw MyException();
      }),
      MyException);

  // Readers still see the old contents...
  vector<int> read = obj.read(snapshot);
  EXPECT_EQ((vector<int>{5}), read);

  // ...and so does the other copy, brought to the foreground by an empty write.
  obj.write([](vector<int>&) {});
  read = obj.read(snapshot);
  EXPECT_EQ((vector<int>{5}), read);
}
