1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/cancellation.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <numeric>
21 #include <random>
22 #include <vector>
23
24 #include "tensorflow/core/lib/core/notification.h"
25 #include "tensorflow/core/lib/core/threadpool.h"
26 #include "tensorflow/core/platform/status.h"
27 #include "tensorflow/core/platform/test.h"
28
29 namespace tensorflow {
30
// A callback that is registered and then deregistered must never run, not
// even when the manager is destroyed afterwards.
TEST(Cancellation, SimpleNoCancel) {
  bool is_cancelled = false;
  auto manager = std::make_unique<CancellationManager>();
  auto token = manager->get_cancellation_token();
  bool registered = manager->RegisterCallback(
      token, [&is_cancelled]() { is_cancelled = true; });
  EXPECT_TRUE(registered);
  bool deregistered = manager->DeregisterCallback(token);
  EXPECT_TRUE(deregistered);
  // Destroy the manager explicitly (instead of at scope exit) so the check
  // below runs after destruction.
  manager.reset();
  EXPECT_FALSE(is_cancelled);
}
43
// StartCancel() runs a registered callback, and the cancelled manager can
// then be destroyed.
TEST(Cancellation, SimpleCancel) {
  bool is_cancelled = false;
  auto manager = std::make_unique<CancellationManager>();
  auto token = manager->get_cancellation_token();
  bool registered = manager->RegisterCallback(
      token, [&is_cancelled]() { is_cancelled = true; });
  EXPECT_TRUE(registered);
  manager->StartCancel();
  EXPECT_TRUE(is_cancelled);
  // Explicitly destroy the already-cancelled manager, as the original test did.
  manager.reset();
}
55
// StartCancel() must invoke every registered callback, whether it was added
// through RegisterCallbackWithErrorLogging() or RegisterCallback().
TEST(Cancellation, StartCancelTriggersAllCallbacks) {
  // Flags are declared before the manager so they outlive it.
  bool logged_callback_ran = false;
  bool plain_callback_ran = false;
  auto cm = std::make_unique<CancellationManager>();

  const auto logged_token = cm->get_cancellation_token();
  const bool logged_ok = cm->RegisterCallbackWithErrorLogging(
      logged_token, [&logged_callback_ran]() { logged_callback_ran = true; },
      "TestCallback");
  EXPECT_TRUE(logged_ok);

  const auto plain_token = cm->get_cancellation_token();
  const bool plain_ok = cm->RegisterCallback(
      plain_token, [&plain_callback_ran]() { plain_callback_ran = true; });
  EXPECT_TRUE(plain_ok);

  cm->StartCancel();

  EXPECT_TRUE(logged_callback_ran);
  EXPECT_TRUE(plain_callback_ran);
}
70
// StartCancelWithStatus() must invoke every registered callback, whether it
// was added through RegisterCallbackWithErrorLogging() or RegisterCallback().
TEST(Cancellation, StartCancelWithStatusTriggersAllCallbacks) {
  // Flags are declared before the manager so they outlive it.
  bool logged_callback_ran = false;
  bool plain_callback_ran = false;
  auto cm = std::make_unique<CancellationManager>();

  const auto logged_token = cm->get_cancellation_token();
  const bool logged_ok = cm->RegisterCallbackWithErrorLogging(
      logged_token, [&logged_callback_ran]() { logged_callback_ran = true; },
      "TestCallback");
  EXPECT_TRUE(logged_ok);

  const auto plain_token = cm->get_cancellation_token();
  const bool plain_ok = cm->RegisterCallback(
      plain_token, [&plain_callback_ran]() { plain_callback_ran = true; });
  EXPECT_TRUE(plain_ok);

  cm->StartCancelWithStatus(OkStatus());

  EXPECT_TRUE(logged_callback_ran);
  EXPECT_TRUE(plain_callback_ran);
}
85
// After StartCancel(), RegisterCallback() returns false, even for a token
// that was obtained before cancellation.
TEST(Cancellation, CancelBeforeRegister) {
  auto cm = std::make_unique<CancellationManager>();
  const auto early_token = cm->get_cancellation_token();
  cm->StartCancel();
  EXPECT_FALSE(cm->RegisterCallback(early_token, nullptr));
}
93
// Once StartCancel() has run a callback, DeregisterCallback() for that same
// token returns false.
TEST(Cancellation, DeregisterAfterCancel) {
  bool callback_ran = false;
  auto cm = std::make_unique<CancellationManager>();
  const auto token = cm->get_cancellation_token();

  EXPECT_TRUE(
      cm->RegisterCallback(token, [&callback_ran]() { callback_ran = true; }));

  cm->StartCancel();
  EXPECT_TRUE(callback_ran);

  EXPECT_FALSE(cm->DeregisterCallback(token));
}
106
// Every callback registered before StartCancel() runs. A callback registered
// afterwards is rejected and never runs.
TEST(Cancellation, CancelMultiple) {
  bool ran_first = false;
  bool ran_second = false;
  bool ran_late = false;
  auto cm = std::make_unique<CancellationManager>();

  const auto first = cm->get_cancellation_token();
  EXPECT_TRUE(cm->RegisterCallback(first, [&ran_first]() { ran_first = true; }));
  const auto second = cm->get_cancellation_token();
  EXPECT_TRUE(
      cm->RegisterCallback(second, [&ran_second]() { ran_second = true; }));

  // Registration alone must not run anything.
  EXPECT_FALSE(ran_first);
  EXPECT_FALSE(ran_second);

  cm->StartCancel();
  EXPECT_TRUE(ran_first);
  EXPECT_TRUE(ran_second);
  EXPECT_FALSE(ran_late);

  // Registering after cancellation fails, and the callback stays unrun.
  const auto late = cm->get_cancellation_token();
  EXPECT_FALSE(cm->RegisterCallback(late, [&ran_late]() { ran_late = true; }));
  EXPECT_FALSE(ran_late);
}
130
// Worker threads spin until IsCancelled() becomes true. At that point each
// worker must see IsCancelling() == false, and then signals completion.
TEST(Cancellation, IsCancelled) {
  auto cm = std::make_unique<CancellationManager>();
  thread::ThreadPool w(Env::Default(), "test", 4);
  std::vector<Notification> done(8);
  for (Notification& slot : done) {
    Notification* n = &slot;
    w.Schedule([n, &cm]() {
      while (!cm->IsCancelled()) {
      }
      // Non-fatal on purpose. A fatal ASSERT_* would return from this lambda
      // before n->Notify(), leaving the main thread blocked forever in
      // WaitForNotification() below instead of reporting the failure.
      EXPECT_FALSE(cm->IsCancelling());
      n->Notify();
    });
  }
  // Presumably gives the workers time to start spinning before cancellation.
  Env::Default()->SleepForMicroseconds(1000000 /* 1 second */);
  cm->StartCancel();
  for (Notification& n : done) {
    n.WaitForNotification();
  }
}
150
// While a callback is still running inside StartCancel(), IsCancelling() is
// true. After StartCancel() returns, IsCancelling() is false and IsCancelled()
// is true.
TEST(Cancellation, IsCancelling) {
  CancellationManager cm;
  Notification started_cancelling;
  Notification can_finish_cancel;
  Notification cancel_done;
  thread::ThreadPool w(Env::Default(), "test", 1);
  auto token = cm.get_cancellation_token();
  // A fatal assertion is safe here: nothing has been scheduled on `w` yet.
  ASSERT_TRUE(
      cm.RegisterCallback(token, [&started_cancelling, &can_finish_cancel]() {
        started_cancelling.Notify();
        can_finish_cancel.WaitForNotification();
      }));
  w.Schedule([&cm, &cancel_done]() {
    cm.StartCancel();
    cancel_done.Notify();
  });
  started_cancelling.WaitForNotification();
  // Must be non-fatal. Returning early here would skip
  // can_finish_cancel.Notify(), so the worker would stay blocked inside the
  // callback and the test would hang instead of failing.
  EXPECT_TRUE(cm.IsCancelling());
  can_finish_cancel.Notify();
  cancel_done.WaitForNotification();
  EXPECT_FALSE(cm.IsCancelling());
  EXPECT_TRUE(cm.IsCancelled());
}
174
// With no cancellation in progress, TryDeregisterCallback() succeeds and the
// callback is never invoked.
TEST(Cancellation, TryDeregisterWithoutCancel) {
  bool callback_ran = false;
  auto cm = std::make_unique<CancellationManager>();
  const auto token = cm->get_cancellation_token();

  const bool registered =
      cm->RegisterCallback(token, [&callback_ran]() { callback_ran = true; });
  EXPECT_TRUE(registered);

  const bool removed = cm->TryDeregisterCallback(token);
  EXPECT_TRUE(removed);
  EXPECT_FALSE(callback_ran);
}
186
// Once StartCancel() has run a callback, TryDeregisterCallback() for that
// token returns false.
TEST(Cancellation, TryDeregisterAfterCancel) {
  bool callback_ran = false;
  auto cm = std::make_unique<CancellationManager>();
  const auto token = cm->get_cancellation_token();

  const bool registered =
      cm->RegisterCallback(token, [&callback_ran]() { callback_ran = true; });
  EXPECT_TRUE(registered);

  cm->StartCancel();
  EXPECT_TRUE(callback_ran);

  const bool removed = cm->TryDeregisterCallback(token);
  EXPECT_FALSE(removed);
}
199
// While the callback is still running inside StartCancel() on another thread,
// TryDeregisterCallback() must return false without waiting for the callback.
// The callback is only released (finish_callback) after TryDeregisterCallback()
// has returned, so waiting would deadlock the test.
TEST(Cancellation, TryDeregisterDuringCancel) {
  Notification cancel_started, finish_callback, cancel_complete;
  auto manager = std::make_unique<CancellationManager>();
  auto token = manager->get_cancellation_token();
  bool registered = manager->RegisterCallback(token, [&]() {
    cancel_started.Notify();
    finish_callback.WaitForNotification();
  });
  // Fatal on purpose. Without a registered callback, cancel_started is never
  // notified and the wait below would hang. Nothing is scheduled yet, so
  // returning early is safe.
  ASSERT_TRUE(registered);

  thread::ThreadPool w(Env::Default(), "test", 1);
  w.Schedule([&]() {
    manager->StartCancel();
    cancel_complete.Notify();
  });
  cancel_started.WaitForNotification();

  bool deregistered = manager->TryDeregisterCallback(token);
  EXPECT_FALSE(deregistered);

  finish_callback.Notify();
  cancel_complete.WaitForNotification();
}
223
// Cancelling a parent manager cancels every child constructed with it.
TEST(Cancellation, Parent_CancelManyChildren) {
  constexpr size_t kNumChildren = 5;
  CancellationManager parent;
  // Declared after `parent`, so all children are destroyed before it.
  std::vector<std::unique_ptr<CancellationManager>> children;
  children.reserve(kNumChildren);
  for (size_t i = 0; i < kNumChildren; ++i) {
    children.push_back(std::make_unique<CancellationManager>(&parent));
    EXPECT_FALSE(children.back()->IsCancelled());
  }
  parent.StartCancel();
  for (const auto& child : children) {
    EXPECT_TRUE(child->IsCancelled());
  }
}
236
// Cancelling a child (and then destroying it) leaves its parent uncancelled.
TEST(Cancellation, Parent_NotCancelled) {
  CancellationManager parent;
  {
    CancellationManager scoped_child(&parent);
    scoped_child.StartCancel();
    EXPECT_TRUE(scoped_child.IsCancelled());
  }  // scoped_child is destroyed here, before the parent is inspected.
  EXPECT_FALSE(parent.IsCancelled());
}
246
// A child constructed from an already-cancelled parent reports itself as
// cancelled immediately.
TEST(Cancellation, Parent_AlreadyCancelled) {
  CancellationManager parent;
  parent.StartCancel();
  EXPECT_TRUE(parent.IsCancelled());

  CancellationManager late_child(&parent);
  EXPECT_TRUE(late_child.IsCancelled());
}
255
// Randomized test of a parent's child bookkeeping. Each round attaches a random
// number of children to `parent`, then destroys them in a random order.
TEST(Cancellation, Parent_RandomDestructionOrder) {
  CancellationManager parent;
  std::random_device rd;
  // Both the round sizes and the destruction orders are drawn from this one
  // generator, so a single seed determines the whole run.
  std::mt19937 g(rd());
  std::uniform_int_distribution<size_t> dist(1, 9);

  // To cover the linked-list codepaths, perform multiple randomized rounds of
  // registering and deregistering children with `parent`.
  for (int rounds = 0; rounds < 100; ++rounds) {
    std::vector<std::unique_ptr<CancellationManager>> children;

    // 1. Register a random number of children with the parent.
    const size_t round_size = dist(g);
    children.reserve(round_size);
    for (size_t i = 0; i < round_size; ++i) {
      children.push_back(std::make_unique<CancellationManager>(&parent));
      EXPECT_FALSE(children.back()->IsCancelled());
    }

    // 2. Deregister the children in a random order.
    std::vector<size_t> destruction_order(round_size);
    std::iota(destruction_order.begin(), destruction_order.end(), 0);
    std::shuffle(destruction_order.begin(), destruction_order.end(), g);
    for (size_t index : destruction_order) {
      children[index].reset();
    }
  }
}
283
284 } // namespace tensorflow
285