xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/cancellation_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/cancellation.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <numeric>
21 #include <random>
22 #include <vector>
23 
24 #include "tensorflow/core/lib/core/notification.h"
25 #include "tensorflow/core/lib/core/threadpool.h"
26 #include "tensorflow/core/platform/status.h"
27 #include "tensorflow/core/platform/test.h"
28 
29 namespace tensorflow {
30 
// Registers a callback, deregisters it again, and then destroys the manager
// without ever cancelling. The callback must never have run.
TEST(Cancellation, SimpleNoCancel) {
  bool is_cancelled = false;
  // Owned through unique_ptr (like the other tests in this file) instead of a
  // raw new/delete pair.
  auto manager = std::make_unique<CancellationManager>();
  auto token = manager->get_cancellation_token();
  bool registered = manager->RegisterCallback(
      token, [&is_cancelled]() { is_cancelled = true; });
  EXPECT_TRUE(registered);
  bool deregistered = manager->DeregisterCallback(token);
  EXPECT_TRUE(deregistered);
  // Destroy the manager *before* the final check so that the expectation also
  // covers destruction of the manager.
  manager.reset();
  EXPECT_FALSE(is_cancelled);
}
43 
// Registers a single callback and checks that StartCancel() has run it by the
// time it returns.
TEST(Cancellation, SimpleCancel) {
  bool is_cancelled = false;
  // Owned through unique_ptr (like the other tests in this file) instead of a
  // raw new/delete pair; the manager is released at the end of the test.
  auto manager = std::make_unique<CancellationManager>();
  auto token = manager->get_cancellation_token();
  bool registered = manager->RegisterCallback(
      token, [&is_cancelled]() { is_cancelled = true; });
  EXPECT_TRUE(registered);
  manager->StartCancel();
  EXPECT_TRUE(is_cancelled);
}
55 
// Registers one callback through RegisterCallbackWithErrorLogging and one
// through RegisterCallback, then checks that StartCancel() runs both.
TEST(Cancellation, StartCancelTriggersAllCallbacks) {
  bool first_fired = false;
  bool second_fired = false;
  CancellationManager manager;

  const auto first_token = manager.get_cancellation_token();
  const bool first_registered = manager.RegisterCallbackWithErrorLogging(
      first_token, [&first_fired]() { first_fired = true; }, "TestCallback");
  EXPECT_TRUE(first_registered);

  const auto second_token = manager.get_cancellation_token();
  const bool second_registered = manager.RegisterCallback(
      second_token, [&second_fired]() { second_fired = true; });
  EXPECT_TRUE(second_registered);

  manager.StartCancel();

  EXPECT_TRUE(first_fired);
  EXPECT_TRUE(second_fired);
}
70 
// Same scenario as the plain StartCancel() test, but cancellation is started
// through StartCancelWithStatus(OkStatus()). Both callbacks must run.
TEST(Cancellation, StartCancelWithStatusTriggersAllCallbacks) {
  bool logged_callback_ran = false;
  bool plain_callback_ran = false;
  CancellationManager manager;

  const auto logged_token = manager.get_cancellation_token();
  const bool logged_ok = manager.RegisterCallbackWithErrorLogging(
      logged_token, [&logged_callback_ran]() { logged_callback_ran = true; },
      "TestCallback");
  EXPECT_TRUE(logged_ok);

  const auto plain_token = manager.get_cancellation_token();
  const bool plain_ok = manager.RegisterCallback(
      plain_token, [&plain_callback_ran]() { plain_callback_ran = true; });
  EXPECT_TRUE(plain_ok);

  manager.StartCancelWithStatus(OkStatus());

  EXPECT_TRUE(logged_callback_ran);
  EXPECT_TRUE(plain_callback_ran);
}
85 
// Once cancellation has started, an attempt to register a callback must be
// rejected (RegisterCallback returns false).
TEST(Cancellation, CancelBeforeRegister) {
  CancellationManager manager;
  const auto token = manager.get_cancellation_token();

  manager.StartCancel();

  EXPECT_FALSE(manager.RegisterCallback(token, nullptr));
}
93 
// After StartCancel() has run the callback, DeregisterCallback() for the same
// token must report failure.
TEST(Cancellation, DeregisterAfterCancel) {
  bool callback_ran = false;
  CancellationManager manager;
  const auto token = manager.get_cancellation_token();

  const bool was_registered =
      manager.RegisterCallback(token, [&callback_ran]() { callback_ran = true; });
  EXPECT_TRUE(was_registered);

  manager.StartCancel();
  EXPECT_TRUE(callback_ran);

  const bool was_deregistered = manager.DeregisterCallback(token);
  EXPECT_FALSE(was_deregistered);
}
106 
// Two callbacks registered before cancellation must both run on StartCancel();
// a third one registered afterwards must be rejected and never run.
TEST(Cancellation, CancelMultiple) {
  bool first_ran = false;
  bool second_ran = false;
  bool third_ran = false;
  CancellationManager manager;

  // Callbacks registered ahead of the cancellation.
  const auto first_token = manager.get_cancellation_token();
  const bool first_registered =
      manager.RegisterCallback(first_token, [&first_ran]() { first_ran = true; });
  EXPECT_TRUE(first_registered);

  const auto second_token = manager.get_cancellation_token();
  const bool second_registered = manager.RegisterCallback(
      second_token, [&second_ran]() { second_ran = true; });
  EXPECT_TRUE(second_registered);

  // Nothing may have fired yet.
  EXPECT_FALSE(first_ran);
  EXPECT_FALSE(second_ran);

  manager.StartCancel();
  EXPECT_TRUE(first_ran);
  EXPECT_TRUE(second_ran);
  EXPECT_FALSE(third_ran);

  // Late registration: refused, and the callback stays untouched.
  const auto third_token = manager.get_cancellation_token();
  const bool third_registered =
      manager.RegisterCallback(third_token, [&third_ran]() { third_ran = true; });
  EXPECT_FALSE(third_registered);
  EXPECT_FALSE(third_ran);
}
130 
// Schedules eight closures on a four-thread pool. Each closure spins until
// IsCancelled() becomes true, checks that IsCancelling() is then false, and
// signals its Notification. The main thread starts the cancellation after one
// second and waits for every closure to finish.
TEST(Cancellation, IsCancelled) {
  auto cm = std::make_unique<CancellationManager>();
  thread::ThreadPool w(Env::Default(), "test", 4);
  std::vector<Notification> done(8);
  for (size_t i = 0; i < done.size(); ++i) {
    Notification* n = &done[i];
    w.Schedule([n, &cm]() {
      while (!cm->IsCancelled()) {
      }
      // EXPECT_ rather than ASSERT_: a fatal assertion would return from this
      // closure before n->Notify(), and the main thread would then block
      // forever in WaitForNotification() instead of reporting the failure.
      EXPECT_FALSE(cm->IsCancelling());
      n->Notify();
    });
  }
  Env::Default()->SleepForMicroseconds(1000000 /* 1 second */);
  cm->StartCancel();
  for (size_t i = 0; i < done.size(); ++i) {
    done[i].WaitForNotification();
  }
}
150 
// Checks the IsCancelling() / IsCancelled() states around a StartCancel() call
// that is held open by a blocking callback:
//   * while the callback is running, IsCancelling() is true;
//   * after StartCancel() has returned, IsCancelling() is false and
//     IsCancelled() is true.
TEST(Cancellation, IsCancelling) {
  CancellationManager cm;
  Notification started_cancelling;
  Notification can_finish_cancel;
  Notification cancel_done;
  thread::ThreadPool w(Env::Default(), "test", 1);
  auto token = cm.get_cancellation_token();
  // The callback signals that cancellation is in progress and then blocks
  // until the main thread allows it to finish.
  ASSERT_TRUE(
      cm.RegisterCallback(token, [&started_cancelling, &can_finish_cancel]() {
        started_cancelling.Notify();
        can_finish_cancel.WaitForNotification();
      }));
  w.Schedule([&cm, &cancel_done]() {
    cm.StartCancel();
    cancel_done.Notify();
  });
  started_cancelling.WaitForNotification();
  // EXPECT_ rather than ASSERT_: a fatal assertion would return from the test
  // before can_finish_cancel.Notify(), leaving the callback on the worker
  // thread blocked while the local objects it references are torn down.
  EXPECT_TRUE(cm.IsCancelling());
  can_finish_cancel.Notify();
  cancel_done.WaitForNotification();
  ASSERT_FALSE(cm.IsCancelling());
  ASSERT_TRUE(cm.IsCancelled());
}
174 
// Without any cancellation in flight, TryDeregisterCallback() must succeed and
// the callback must not have been invoked.
TEST(Cancellation, TryDeregisterWithoutCancel) {
  bool callback_ran = false;
  CancellationManager manager;
  const auto token = manager.get_cancellation_token();

  const bool was_registered =
      manager.RegisterCallback(token, [&callback_ran]() { callback_ran = true; });
  EXPECT_TRUE(was_registered);

  const bool was_deregistered = manager.TryDeregisterCallback(token);
  EXPECT_TRUE(was_deregistered);
  EXPECT_FALSE(callback_ran);
}
186 
// After StartCancel() has run the callback, TryDeregisterCallback() for the
// same token must report failure.
TEST(Cancellation, TryDeregisterAfterCancel) {
  bool callback_ran = false;
  CancellationManager manager;
  const auto token = manager.get_cancellation_token();

  const bool was_registered =
      manager.RegisterCallback(token, [&callback_ran]() { callback_ran = true; });
  EXPECT_TRUE(was_registered);

  manager.StartCancel();
  EXPECT_TRUE(callback_ran);

  const bool was_deregistered = manager.TryDeregisterCallback(token);
  EXPECT_FALSE(was_deregistered);
}
199 
// While StartCancel() is still executing a (blocked) callback on another
// thread, TryDeregisterCallback() must return false.
TEST(Cancellation, TryDeregisterDuringCancel) {
  Notification callback_entered;
  Notification release_callback;
  Notification start_cancel_returned;
  CancellationManager manager;
  const auto token = manager.get_cancellation_token();

  // The callback announces that it is running and then parks until released.
  const bool was_registered =
      manager.RegisterCallback(token, [&callback_entered, &release_callback]() {
        callback_entered.Notify();
        release_callback.WaitForNotification();
      });
  EXPECT_TRUE(was_registered);

  // Drive the cancellation from a worker thread so this thread stays free.
  thread::ThreadPool pool(Env::Default(), "test", 1);
  pool.Schedule([&manager, &start_cancel_returned]() {
    manager.StartCancel();
    start_cancel_returned.Notify();
  });
  callback_entered.WaitForNotification();

  // Cancellation is in progress at this point.
  EXPECT_FALSE(manager.TryDeregisterCallback(token));

  release_callback.Notify();
  start_cancel_returned.WaitForNotification();
}
223 
// Creates five child managers attached to one parent and checks that
// cancelling the parent marks every child as cancelled.
TEST(Cancellation, Parent_CancelManyChildren) {
  CancellationManager parent;
  std::vector<std::unique_ptr<CancellationManager>> children;
  for (size_t i = 0; i < 5; ++i) {
    // std::make_unique (from <memory>, already included) matches the rest of
    // this file and avoids relying on an absl header this file never includes.
    children.push_back(std::make_unique<CancellationManager>(&parent));
    EXPECT_FALSE(children.back()->IsCancelled());
  }
  parent.StartCancel();
  for (const auto& child : children) {
    EXPECT_TRUE(child->IsCancelled());
  }
}
236 
// Cancelling a child and then destroying it must leave the parent
// uncancelled.
TEST(Cancellation, Parent_NotCancelled) {
  CancellationManager parent;

  auto child = std::make_unique<CancellationManager>(&parent);
  child->StartCancel();
  EXPECT_TRUE(child->IsCancelled());
  // Destroy the child before inspecting the parent.
  child.reset();

  EXPECT_FALSE(parent.IsCancelled());
}
246 
// A child constructed from a parent that is already cancelled must report
// itself as cancelled right away.
TEST(Cancellation, Parent_AlreadyCancelled) {
  CancellationManager cancelled_parent;
  cancelled_parent.StartCancel();
  EXPECT_TRUE(cancelled_parent.IsCancelled());

  CancellationManager late_child(&cancelled_parent);
  EXPECT_TRUE(late_child.IsCancelled());
}
255 
// Repeatedly attaches a random number of children to `parent` and destroys
// them in a shuffled order, never cancelling anything.
TEST(Cancellation, Parent_RandomDestructionOrder) {
  CancellationManager parent;
  std::random_device rd;
  std::mt19937 g(rd());
  // Number of children per round; loop-invariant, so constructed once.
  std::uniform_int_distribution<int> dist(1, 9);

  // To cover the linked-list codepaths, perform multiple randomized rounds of
  // registering and deregistering children with `parent`.
  for (int rounds = 0; rounds < 100; ++rounds) {
    std::vector<std::unique_ptr<CancellationManager>> children;

    // 1. Register a random number of children with the parent.
    // Draw from the seeded engine `g` (also used for the shuffle below) rather
    // than querying std::random_device on every round.
    const size_t round_size = dist(g);
    for (size_t i = 0; i < round_size; ++i) {
      // std::make_unique (from <memory>, already included) matches the rest of
      // this file and avoids relying on an absl header this file never
      // includes.
      children.push_back(std::make_unique<CancellationManager>(&parent));
      EXPECT_FALSE(children.back()->IsCancelled());
    }

    // 2. Deregister the children in a random order.
    std::vector<size_t> destruction_order(round_size);
    std::iota(destruction_order.begin(), destruction_order.end(), 0);
    std::shuffle(destruction_order.begin(), destruction_order.end(), g);
    for (size_t index : destruction_order) {
      children[index].reset();
    }
  }
}
283 
284 }  // namespace tensorflow
285