/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/run_handler.h"

#include <atomic>
#include <functional>
#include <memory>
#include <vector>

#define EIGEN_USE_THREADS
#include "absl/memory/memory.h"
#include "absl/synchronization/barrier.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace {

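// Requests multiple handlers from a shared pool on separate threads and
// schedules inter-op and intra-op closures on each; the test passes once
// every scheduled closure has run, i.e. the pool drains all work.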
TEST(RunHandlerUtilTest, TestBasicScheduling) {
  int num_threads = 2;
  int num_handlers = 10;

  std::unique_ptr<RunHandlerPool> pool(
      new RunHandlerPool(num_threads, num_threads));

  // RunHandler should always be able to run num_threads inter-op closures
  // concurrently.
  absl::Barrier barrier(num_threads);

  BlockingCounter counter(2 * num_handlers * num_threads);

  thread::ThreadPool test_pool(Env::Default(), "test", num_handlers);
  for (int i = 0; i < num_handlers; ++i) {
    test_pool.Schedule([&counter, &barrier, &pool, i, num_threads]() {
      auto handler = pool->Get(i);
      BlockingCounter local_counter(2 * num_threads);
      auto intra_thread_pool = handler->AsIntraThreadPoolInterface();

      for (int j = 0; j < num_threads; ++j) {
        handler->ScheduleInterOpClosure(
            [&local_counter, &counter, &barrier, i]() {
              if (i == 2) {
                barrier.Block();
              }
              counter.DecrementCount();
              local_counter.DecrementCount();
            });
        intra_thread_pool->Schedule([&local_counter, &counter]() {
          counter.DecrementCount();
          local_counter.DecrementCount();
        });
      }
      local_counter.Wait();
    });
  }
  counter.Wait();
}

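// Requests handlers with different priorities and checks that the pool
// reports active handlers sorted by priority, both before and after one of
// the handlers is released.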
TEST(RunHandlerUtilTest, PrioritySchedulingTest) {
  int num_threads = 2;
  std::unique_ptr<RunHandlerPool> pool(
      new RunHandlerPool(num_threads, num_threads));

  RunOptions::Experimental::RunHandlerPoolOptions options =
      RunOptions::Experimental::RunHandlerPoolOptions();
  options.set_priority(2);
  auto handler1 = pool->Get(/*step_id=*/1, /*timeout_in_ms=*/0, options);
  options.set_priority(1);
  auto handler2 = pool->Get(/*step_id=*/2, /*timeout_in_ms=*/0, options);
  options.set_priority(3);
  auto handler3 = pool->Get(/*step_id=*/3, /*timeout_in_ms=*/0, options);

  // The active requests should be ordered by priorities.
  std::vector<int64_t> sorted_active_list =
      pool->GetActiveHandlerPrioritiesForTesting();
  EXPECT_EQ(sorted_active_list.size(), 3);
  EXPECT_EQ(sorted_active_list[0], 3);
  EXPECT_EQ(sorted_active_list[1], 2);
  EXPECT_EQ(sorted_active_list[2], 1);

  handler1.reset();
  options.set_priority(5);
  auto handler4 = pool->Get(/*step_id=*/4, /*timeout_in_ms=*/0, options);
  options.set_priority(4);
  auto handler5 = pool->Get(/*step_id=*/5, /*timeout_in_ms=*/0, options);
  sorted_active_list = pool->GetActiveHandlerPrioritiesForTesting();
  EXPECT_EQ(sorted_active_list.size(), 4);
  EXPECT_EQ(sorted_active_list[0], 5);
  EXPECT_EQ(sorted_active_list[1], 4);
  EXPECT_EQ(sorted_active_list[2], 3);
  EXPECT_EQ(sorted_active_list[3], 1);
}

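// Enqueues tasks into the blocking and non-blocking queues of a
// ThreadWorkSource and verifies that queue sizes are tracked correctly and
// that tasks are popped in FIFO order.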
TEST(RunHandlerThreadPool, EnqueueTask) {
  Eigen::MaxSizeVector<mutex> waiters_mu(2);
  waiters_mu.resize(2);
  Eigen::MaxSizeVector<internal::Waiter> waiters(2);
  waiters.resize(2);
  internal::RunHandlerThreadPool run_handler_thread_pool(
      /*num_blocking_threads=*/0, /*num_non_blocking_threads=*/0,
      Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
      &waiters);
  internal::ThreadWorkSource tws;

  int result = 0;
  std::function<void()> fn = [&result] { result = 1; };
  std::function<void()> fn2 = [&result] { result = 2; };
  run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/true, fn);
  EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/true), 1);
  run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/true, fn2);
  EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/true), 2);
  tws.PopBlockingTask().f->f();
  EXPECT_EQ(result, 1);
  tws.PopBlockingTask().f->f();
  EXPECT_EQ(result, 2);

  run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/false, fn);
  EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/false), 1);
  run_handler_thread_pool.AddWorkToQueue(&tws, /*is_blocking=*/false, fn2);
  EXPECT_EQ(tws.TaskQueueSize(/*is_blocking=*/false), 2);
  tws.PopNonBlockingTask(0, true).f->f();
  EXPECT_EQ(result, 1);
  tws.PopNonBlockingTask(0, true).f->f();
  EXPECT_EQ(result, 2);
}

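// Exercises RunHandlerThreadPool::FindTask: round-robin search across
// handlers, the searching-range boundaries, the preference of blocking
// threads for blocking work, and the restriction of non-blocking threads to
// non-blocking work.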
TEST(RunHandlerThreadPool, FindTask) {
  Eigen::MaxSizeVector<mutex> waiters_mu(2);
  waiters_mu.resize(2);
  Eigen::MaxSizeVector<internal::Waiter> waiters(2);
  waiters.resize(2);
  internal::RunHandlerThreadPool run_handler_thread_pool(
      /*num_blocking_threads=*/1, /*num_non_blocking_threads=*/0,
      Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
      &waiters);

  Eigen::MaxSizeVector<internal::ThreadWorkSource*> thread_work_sources(5);
  thread_work_sources.resize(5);
  for (int i = 0; i < 5; ++i) {
    thread_work_sources[i] = new internal::ThreadWorkSource();
  }

  {
    // The thread should search for tasks across handlers in a round-robin
    // fashion.
    int result = -1;
    run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
                                           /*is_blocking=*/true,
                                           [&result] { result = 2; });
    run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
                                           /*is_blocking=*/true,
                                           [&result] { result = 2; });
    run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
                                           /*is_blocking=*/true,
                                           [&result] { result = 3; });
    run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
                                           /*is_blocking=*/true,
                                           [&result] { result = 3; });

    const auto find_blocking_task_from_all_handlers =
        [&](bool* task_from_blocking_queue, internal::Task* t) {
          internal::ThreadWorkSource* tws;
          *t = run_handler_thread_pool.FindTask(
              /*searching_range_start=*/0, /*searching_range_end=*/5,
              /*thread_id=*/0,
              /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
              /*may_steal_blocking_work=*/true, thread_work_sources,
              task_from_blocking_queue, &tws);
        };
    bool task_from_blocking_queue;
    internal::Task t;
    find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
    EXPECT_EQ(task_from_blocking_queue, true);
    t.f->f();
    EXPECT_EQ(result, 2);

    find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
    EXPECT_EQ(task_from_blocking_queue, true);
    t.f->f();
    EXPECT_EQ(result, 3);

    find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
    EXPECT_EQ(task_from_blocking_queue, true);
    t.f->f();
    EXPECT_EQ(result, 2);

    find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
    EXPECT_EQ(task_from_blocking_queue, true);
    t.f->f();
    EXPECT_EQ(result, 3);
  }

  {
    // A task outside the searching range cannot be found.
    int result = -1;
    run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
                                           /*is_blocking=*/true,
                                           [&result] { result = 3; });

    const auto find_blocking_task_from_range =
        [&](bool* task_from_blocking_queue, internal::Task* t, int range_start,
            int range_end) {
          internal::ThreadWorkSource* tws;
          *t = run_handler_thread_pool.FindTask(
              range_start, range_end,
              /*thread_id=*/0,
              /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
              /*may_steal_blocking_work=*/true, thread_work_sources,
              task_from_blocking_queue, &tws);
        };

    bool task_from_blocking_queue;
    internal::Task t;
    find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 3);
    EXPECT_EQ(t.f, nullptr);

    // Clean up the queue.
    find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
  }

  {
    // The thread should search from the start of the range if its current
    // index is smaller than range_start.
    int result = -1;
    run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
                                           /*is_blocking=*/true,
                                           [&result] { result = 2; });
    run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
                                           /*is_blocking=*/true,
                                           [&result] { result = 3; });

    const auto find_blocking_task_from_range =
        [&](bool* task_from_blocking_queue, internal::Task* t, int range_start,
            int range_end) {
          internal::ThreadWorkSource* tws;
          *t = run_handler_thread_pool.FindTask(
              range_start, range_end,
              /*thread_id=*/0,
              /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
              /*may_steal_blocking_work=*/true, thread_work_sources,
              task_from_blocking_queue, &tws);
        };
    bool task_from_blocking_queue;
    internal::Task t;
    find_blocking_task_from_range(&task_from_blocking_queue, &t, 3, 5);
    EXPECT_EQ(task_from_blocking_queue, true);
    t.f->f();
    EXPECT_EQ(result, 3);

    find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
    EXPECT_EQ(task_from_blocking_queue, true);
    t.f->f();
    EXPECT_EQ(result, 2);
  }

  {
    // The thread should search within the range even if its current index is
    // larger than searching_range_end.
    int result = -1;
    run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
                                           /*is_blocking=*/true,
                                           [&result] { result = 2; });

    const auto find_blocking_task_from_range =
        [&](bool* task_from_blocking_queue, internal::Task* t, int range_start,
            int range_end) {
          internal::ThreadWorkSource* tws;
          *t = run_handler_thread_pool.FindTask(
              range_start, range_end,
              /*thread_id=*/0,
              /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
              /*may_steal_blocking_work=*/true, thread_work_sources,
              task_from_blocking_queue, &tws);
        };
    bool task_from_blocking_queue;
    // Move the thread's current search index to 3.
    internal::Task t;
    find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
    EXPECT_EQ(task_from_blocking_queue, true);
    t.f->f();
    EXPECT_EQ(result, 2);

    // Search in a smaller range.
    run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
                                           /*is_blocking=*/true,
                                           [&result] { result = 2; });
    run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
                                           /*is_blocking=*/true,
                                           [&result] { result = 3; });
    find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 3);
    EXPECT_EQ(task_from_blocking_queue, true);
    t.f->f();
    EXPECT_EQ(result, 2);

    // Clean up the queue.
    find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
    EXPECT_EQ(task_from_blocking_queue, true);
    t.f->f();
    EXPECT_EQ(result, 3);
  }

  {
    // Blocking threads should prefer tasks from the blocking queue.
    int result = -1;
    run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
                                           /*is_blocking=*/false,
                                           [&result] { result = 2; });
    run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
                                           /*is_blocking=*/true,
                                           [&result] { result = 2; });
    const auto blocking_thread_find_task_from_all_handler =
        [&](bool* task_from_blocking_queue, internal::Task* t) {
          internal::ThreadWorkSource* tws;
          *t = run_handler_thread_pool.FindTask(
              /*searching_range_start=*/0, /*searching_range_end=*/5,
              /*thread_id=*/0,
              /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
              /*may_steal_blocking_work=*/true, thread_work_sources,
              task_from_blocking_queue, &tws);
        };
    bool task_from_blocking_queue;
    internal::Task t;
    blocking_thread_find_task_from_all_handler(&task_from_blocking_queue, &t);
    EXPECT_EQ(task_from_blocking_queue, true);
    t.f->f();
    EXPECT_EQ(result, 2);

    blocking_thread_find_task_from_all_handler(&task_from_blocking_queue, &t);
    EXPECT_EQ(task_from_blocking_queue, false);
    t.f->f();
    EXPECT_EQ(result, 2);
  }

  {
    // Non-blocking threads can only pick up non-blocking tasks.
    int result = -1;
    run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
                                           /*is_blocking=*/false,
                                           [&result] { result = 2; });
    run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
                                           /*is_blocking=*/true,
                                           [&result] { result = 2; });

    const auto find_task_from_all_handler = [&](bool* task_from_blocking_queue,
                                                internal::Task* t,
                                                bool is_blocking_thread) {
      internal::ThreadWorkSource* tws;
      *t = run_handler_thread_pool.FindTask(
          /*searching_range_start=*/0, /*searching_range_end=*/5,
          /*thread_id=*/0,
          /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
          is_blocking_thread, thread_work_sources, task_from_blocking_queue,
          &tws);
    };
    bool task_from_blocking_queue;
    internal::Task t;
    find_task_from_all_handler(&task_from_blocking_queue, &t,
                               /*is_blocking_thread=*/false);
    EXPECT_EQ(task_from_blocking_queue, false);
    t.f->f();
    EXPECT_EQ(result, 2);

    find_task_from_all_handler(&task_from_blocking_queue, &t,
                               /*is_blocking_thread=*/false);
    EXPECT_EQ(t.f, nullptr);

    // Clean up the queue.
    find_task_from_all_handler(&task_from_blocking_queue, &t,
                               /*is_blocking_thread=*/true);
  }

  {
    // There is a limit on the number of in-flight blocking tasks
    // (max_blocking_inflight).
    int result = -1;
    run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
                                           /*is_blocking=*/true,
                                           [&result] { result = 2; });

    const auto find_task_from_all_handler = [&](bool* task_from_blocking_queue,
                                                internal::Task* t,
                                                bool is_blocking_thread) {
      internal::ThreadWorkSource* tws;
      *t = run_handler_thread_pool.FindTask(
          /*searching_range_start=*/0, /*searching_range_end=*/5,
          /*thread_id=*/0,
          /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
          is_blocking_thread, thread_work_sources, task_from_blocking_queue,
          &tws);
    };

    bool task_from_blocking_queue;
    internal::Task t;
    find_task_from_all_handler(&task_from_blocking_queue, &t,
                               /*is_blocking_thread=*/false);
    EXPECT_EQ(task_from_blocking_queue, false);
    EXPECT_EQ(t.f, nullptr);

    // Clean up the queue.
    find_task_from_all_handler(&task_from_blocking_queue, &t,
                               /*is_blocking_thread=*/true);
  }

  for (int i = 0; i < 5; ++i) {
    delete thread_work_sources[i];
  }
}

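// With a single sub thread pool and a single blocking thread, work added to
// three ThreadWorkSources should be executed in round-robin order across the
// sources.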
TEST(RunHandlerThreadPool, RoundRobinExecution) {
  // Set up environment for 1 sub thread pool.
  setenv("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", "true", true);
  setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "1", true);
  setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE", "0", true);
  setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE", "1", true);

  Eigen::MaxSizeVector<mutex> waiters_mu(1);
  waiters_mu.resize(1);
  Eigen::MaxSizeVector<internal::Waiter> waiters(1);
  waiters.resize(1);
  internal::RunHandlerThreadPool* run_handler_thread_pool =
      new internal::RunHandlerThreadPool(
          /*num_blocking_threads=*/1, /*num_non_blocking_threads=*/0,
          Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
          &waiters);
  Eigen::MaxSizeVector<internal::ThreadWorkSource*> thread_work_sources(3);
  thread_work_sources.resize(3);
  internal::ThreadWorkSource tws[3];
  for (int i = 0; i < 3; ++i) {
    tws[i].SetWaiter(1, &waiters[0], &waiters_mu[0]);
    thread_work_sources[i] = &tws[i];
  }

  int result = 0;
  mutex mu;
  bool ok_to_execute = false;
  bool ok_to_validate = false;
  condition_variable function_start;
  condition_variable function_end;
  std::vector<std::function<void()>> fns;
  for (int i = 0; i < 3; ++i) {
    fns.push_back([&result, &mu, &function_start, &function_end,
                   &ok_to_execute, &ok_to_validate, i] {
      mutex_lock l(mu);
      while (!ok_to_execute) {
        function_start.wait(l);
      }
      result = i;
      ok_to_execute = false;
      ok_to_validate = true;
      function_end.notify_one();
    });
    run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
                                            fns[i]);
    run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
                                            fns[i]);
  }
  run_handler_thread_pool->Start();
  run_handler_thread_pool->SetThreadWorkSources(
      /*tid=*/0, /*start_request_idx=*/0, /*version=*/1, thread_work_sources);

  // Validate that tasks are executed in round-robin order across the work
  // sources.
  mutex_lock l(mu);
  for (int round = 0; round < 2; ++round) {
    for (int i = 0; i < 3; ++i) {
      ok_to_execute = true;
      function_start.notify_one();
      while (!ok_to_validate) {
        function_end.wait(l);
      }
      ok_to_validate = false;
      EXPECT_EQ(result, i);
    }
  }

  delete run_handler_thread_pool;
}

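// With two sub thread pools, each owning half of the requests, a thread
// should first drain the requests assigned to its own sub thread pool in
// round-robin order and only then pick up tasks from the remaining requests.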
TEST(RunHandlerThreadPool, MultipleSubThreadPool) {
  // Set up environment for 2 sub thread pools.
  setenv("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", "true", true);
  setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "2", true);
  setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE", "0,0.5",
         true);
  setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE", "0.5,1",
         true);

  Eigen::MaxSizeVector<mutex> waiters_mu(2);
  waiters_mu.resize(2);
  Eigen::MaxSizeVector<internal::Waiter> waiters(2);
  waiters.resize(2);
  internal::RunHandlerThreadPool* run_handler_thread_pool =
      new internal::RunHandlerThreadPool(
          /*num_blocking_threads=*/2, /*num_non_blocking_threads=*/0,
          Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
          &waiters);
  Eigen::MaxSizeVector<internal::ThreadWorkSource*> thread_work_sources(4);
  thread_work_sources.resize(4);
  internal::ThreadWorkSource tws[4];
  for (int i = 0; i < 4; ++i) {
    tws[i].SetWaiter(1, &waiters[i / 2], &waiters_mu[i / 2]);
    thread_work_sources[i] = &tws[i];
  }

  int result = 0;
  mutex mu;
  bool ok_to_execute = false;
  bool ok_to_validate = false;
  condition_variable function_start;
  condition_variable function_end;

  std::vector<std::function<void()>> fns;
  for (int i = 0; i < 4; ++i) {
    fns.push_back([&result, &mu, &function_start, &function_end,
                   &ok_to_execute, &ok_to_validate, i] {
      mutex_lock l(mu);
      while (!ok_to_execute) {
        function_start.wait(l);
      }
      result = i;
      ok_to_execute = false;
      ok_to_validate = true;
      function_end.notify_one();
    });
    run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
                                            fns[i]);
    run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
                                            fns[i]);
  }
  run_handler_thread_pool->StartOneThreadForTesting();
  run_handler_thread_pool->SetThreadWorkSources(
      /*tid=*/0, /*start_request_idx=*/0, /*version=*/1, thread_work_sources);
  run_handler_thread_pool->SetThreadWorkSources(
      /*tid=*/1, /*start_request_idx=*/0, /*version=*/1, thread_work_sources);

  // Tasks from the requests in the thread's own sub thread pool are picked up
  // in round-robin order.
  mutex_lock l(mu);
  for (int round = 0; round < 2; ++round) {
    for (int i = 0; i < 2; ++i) {
      ok_to_execute = true;
      function_start.notify_one();
      while (!ok_to_validate) {
        function_end.wait(l);
      }
      ok_to_validate = false;
      EXPECT_EQ(result, i);
    }
  }

  // If there are no tasks from the requests in the thread's own sub thread
  // pool, tasks from the other requests are picked up.
  for (int i = 0; i < 2; ++i) {
    for (int round = 0; round < 2; ++round) {
      ok_to_execute = true;
      function_start.notify_one();
      while (!ok_to_validate) {
        function_end.wait(l);
      }
      ok_to_validate = false;
      EXPECT_EQ(result, i + 2);
    }
  }

  delete run_handler_thread_pool;
}

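// Helpers for the session-level tests below: a SessionOptions with two CPU
// devices, and a factory for a session created with those options.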
SessionOptions DefaultSessionOptions() {
  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  return options;
}

std::unique_ptr<Session> CreateSession() {
  return std::unique_ptr<Session>(NewSession(DefaultSessionOptions()));
}

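// Test fixture that builds a small two-device graph (y = A * x,
// y_neg = -y, z = Identity(y_neg)) and configures the run handler
// environment variables used by the session tests.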
class RunHandlerTest : public ::testing::Test {
 public:
  void Initialize(std::initializer_list<float> a_values) {
    Graph graph(OpRegistry::Global());

    Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
    test::FillValues<float>(&a_tensor, a_values);
    Node* a = test::graph::Constant(&graph, a_tensor);
    a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
    a_ = a->name();

    Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
    test::FillValues<float>(&x_tensor, {1, 1});
    Node* x = test::graph::Constant(&graph, x_tensor);
    x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
    x_ = x->name();

    // y = A * x
    Node* y = test::graph::Matmul(&graph, a, x, false, false);
    y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
    y_ = y->name();

    Node* y_neg = test::graph::Unary(&graph, "Neg", y);
    y_neg_ = y_neg->name();
    y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");

    Node* z = test::graph::Unary(&graph, "Identity", y_neg);
    z_ = z->name();
    z->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");

    graph.ToGraphDef(&def_);

    ASSERT_EQ(setenv("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", "2", true), 0);
    ASSERT_EQ(
        setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "8,8", true),
        0);
    ASSERT_EQ(setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
                     "0,0.4", true),
              0);
    ASSERT_EQ(setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
                     "0.4,1", true),
              0);
    ASSERT_EQ(setenv("TF_NUM_INTEROP_THREADS", "16", true), 0);
  }

  string a_;
  string x_;
  string y_;
  string y_neg_;
  string z_;
  GraphDef def_;
};

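// Runs the graph once through a session configured to use the run handler
// pool and checks that the fetched output of y = A * x is correct.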
TEST_F(RunHandlerTest, UseRunHandlerPoolEnableSubPool) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  EXPECT_EQ(OkStatus(), session->Create(def_));
  std::vector<std::pair<string, Tensor>> inputs;

  // Request two targets: one fetched output and one non-fetched target node.
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;

  // Prepare RunOptions to use the run handler pool.
  RunOptions run_options;
  run_options.mutable_experimental()->set_use_run_handler_pool(true);

  Status s = session->Run(run_options, inputs, output_names, target_nodes,
                          &outputs, nullptr);
  EXPECT_EQ(OkStatus(), s);

  ASSERT_EQ(1, outputs.size());
  // The output should be initialized and contain the correct value.
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));
}

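// Runs the same graph many times from several threads concurrently, all
// sharing the run handler pool, and checks that every run produces the
// expected result.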
TEST_F(RunHandlerTest, TestConcurrencyUseRunHandlerPool) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  EXPECT_EQ(OkStatus(), session->Create(def_));

  RunOptions run_options;
  run_options.mutable_experimental()->set_use_run_handler_pool(true);

  // Fill in the input and ask for the output.
  thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "test", 4);

  // Run the graph 1000 times in 4 different threads concurrently.
  std::vector<string> output_names = {y_ + ":0"};
  auto fn = [&session, output_names, run_options]() {
    for (int i = 0; i < 1000; ++i) {
      std::vector<std::pair<string, Tensor>> inputs;
      std::vector<Tensor> outputs;
      // Run the graph.
      Status s = session->Run(run_options, inputs, output_names, {}, &outputs,
                              nullptr);
      EXPECT_EQ(OkStatus(), s);
      ASSERT_EQ(1, outputs.size());
      auto mat = outputs[0].matrix<float>();
      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
    }
  };

  for (int i = 0; i < 4; ++i) {
    tp->Schedule(fn);
  }

  // Deleting the thread pool waits for all scheduled functions to finish.
  delete tp;
}

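// Same as UseRunHandlerPoolEnableSubPool, but also sets a priority on the
// run handler pool options.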
TEST_F(RunHandlerTest, UseRunHandlerPoolEnableSubPoolWithPriority) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  EXPECT_EQ(OkStatus(), session->Create(def_));
  std::vector<std::pair<string, Tensor>> inputs;

  // Request two targets: one fetched output and one non-fetched target node.
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;

  // Prepare RunOptions to use the run handler pool with a priority.
  RunOptions run_options;
  run_options.mutable_experimental()->set_use_run_handler_pool(true);
  run_options.mutable_experimental()
      ->mutable_run_handler_pool_options()
      ->set_priority(1);

  Status s = session->Run(run_options, inputs, output_names, target_nodes,
                          &outputs, nullptr);
  EXPECT_EQ(OkStatus(), s);

  ASSERT_EQ(1, outputs.size());
  // The output should be initialized and contain the correct value.
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));
}

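// Same as TestConcurrencyUseRunHandlerPool, but each run cycles through
// different priorities.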
TEST_F(RunHandlerTest, TestConcurrencyUseRunHandlerPoolWithPriority) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  EXPECT_EQ(OkStatus(), session->Create(def_));

  // Fill in the input and ask for the output.
  thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "test", 4);

  // Run the graph 1000 times in 4 different threads concurrently.
  std::vector<string> output_names = {y_ + ":0"};
  auto fn = [&session, output_names]() {
    for (int i = 0; i < 1000; ++i) {
      RunOptions run_options;
      run_options.mutable_experimental()->set_use_run_handler_pool(true);
      run_options.mutable_experimental()
          ->mutable_run_handler_pool_options()
          ->set_priority(i % 4);
      std::vector<std::pair<string, Tensor>> inputs;
      std::vector<Tensor> outputs;
      // Run the graph.
      Status s = session->Run(run_options, inputs, output_names, {}, &outputs,
                              nullptr);
      EXPECT_EQ(OkStatus(), s);
      ASSERT_EQ(1, outputs.size());
      auto mat = outputs[0].matrix<float>();
      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
    }
  };

  for (int i = 0; i < 4; ++i) {
    tp->Schedule(fn);
  }

  // Deleting the thread pool waits for all scheduled functions to finish.
  delete tp;
}

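// Exhausts the pool's handler slots and verifies that Get() with a non-zero
// timeout fails by returning nullptr, while Get() with no timeout blocks
// until a handler is released.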
TEST_F(RunHandlerTest, TestWaitTimeout) {
  std::unique_ptr<RunHandlerPool> pool(new RunHandlerPool(1, 1));

  // Get all available handlers in the pool.
  std::vector<std::unique_ptr<RunHandler>> blocking_handles;
  const int32_t kMaxConcurrentHandlers = 128;  // Copied from run_handler.cc.
  blocking_handles.reserve(kMaxConcurrentHandlers);
  for (int i = 0; i < kMaxConcurrentHandlers; ++i) {
    blocking_handles.push_back(pool->Get(i));
  }

  // A subsequent request with a non-zero timeout will fail by returning
  // nullptr.
  auto null_handle = pool->Get(128, 1);
  EXPECT_EQ(null_handle.get(), nullptr);

  // A subsequent request with no timeout will succeed once a blocking handle
  // is returned.
  auto tp = std::make_unique<thread::ThreadPool>(Env::Default(), "test", 4);
  std::atomic<int64_t> release_time;

  tp->Schedule([&blocking_handles, &release_time]() {
    Env::Default()->SleepForMicroseconds(5000);
    release_time = EnvTime::NowNanos();
    blocking_handles[0].reset();
  });

  auto next_handle = pool->Get(129, 0);
  EXPECT_GT(EnvTime::NowNanos(), release_time);
  EXPECT_NE(next_handle.get(), nullptr);
}

}  // namespace
}  // namespace tensorflow