xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/run_handler_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #include "tensorflow/core/framework/run_handler.h"
19 
20 #include <memory>
21 #include <vector>
22 
23 #define EIGEN_USE_THREADS
24 #include "absl/memory/memory.h"
25 #include "absl/synchronization/barrier.h"
26 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/tensor_testutil.h"
29 #include "tensorflow/core/graph/graph.h"
30 #include "tensorflow/core/graph/testlib.h"
31 #include "tensorflow/core/lib/core/blocking_counter.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/platform/env.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/test.h"
36 #include "tensorflow/core/public/session.h"
37 #include "tensorflow/core/public/session_options.h"
38 
39 namespace tensorflow {
40 namespace {
41 
// Checks that every handler handed out by a RunHandlerPool runs all of the
// inter-op and intra-op closures scheduled on it. The closures of one handler
// (index 2) additionally rendezvous on a barrier sized to num_threads, so the
// test only finishes if that handler's inter-op closures run concurrently.
TEST(RunHandlerUtilTest, TestBasicScheduling) {
  int num_threads = 2;
  int num_handlers = 10;

  std::unique_ptr<RunHandlerPool> handler_pool(
      new RunHandlerPool(num_threads, num_threads));

  // RunHandler should always be able to run num_threads inter closures.
  absl::Barrier barrier(num_threads);

  // Each handler schedules num_threads inter-op and num_threads intra-op
  // closures; every closure decrements this counter exactly once.
  BlockingCounter all_closures_done(2 * num_handlers * num_threads);

  thread::ThreadPool request_threads(Env::Default(), "test", num_handlers);
  for (int handler_idx = 0; handler_idx < num_handlers; ++handler_idx) {
    request_threads.Schedule([&all_closures_done, &barrier, &handler_pool,
                              handler_idx, num_threads]() {
      auto handler = handler_pool->Get(handler_idx);
      // Tracks only the closures scheduled through this handler, so that the
      // handler is not released while they are still running.
      BlockingCounter handler_closures_done(2 * num_threads);
      auto intra_thread_pool = handler->AsIntraThreadPoolInterface();

      const auto inter_op_closure = [&handler_closures_done,
                                     &all_closures_done, &barrier,
                                     handler_idx]() {
        if (handler_idx == 2) {
          barrier.Block();
        }
        all_closures_done.DecrementCount();
        handler_closures_done.DecrementCount();
      };
      const auto intra_op_closure = [&handler_closures_done,
                                     &all_closures_done]() {
        all_closures_done.DecrementCount();
        handler_closures_done.DecrementCount();
      };

      for (int remaining = num_threads; remaining > 0; --remaining) {
        handler->ScheduleInterOpClosure(inter_op_closure);
        intra_thread_pool->Schedule(intra_op_closure);
      }
      handler_closures_done.Wait();
    });
  }
  all_closures_done.Wait();
}
80 
// Checks that the pool reports its active handlers ordered by descending
// priority, both after the initial acquisitions and after one handler has been
// released and two more have been acquired.
TEST(RunHandlerUtilTest, PrioritySchedulingTest) {
  int num_threads = 2;
  std::unique_ptr<RunHandlerPool> pool(
      new RunHandlerPool(num_threads, num_threads));

  RunOptions::Experimental::RunHandlerPoolOptions options =
      RunOptions::Experimental::RunHandlerPoolOptions();
  options.set_priority(2);
  auto handler1 = pool->Get(/*step_id=*/1, /*timeout_in_ms=*/0, options);
  options.set_priority(1);
  auto handler2 = pool->Get(/*step_id=*/2, /*timeout_in_ms=*/0, options);
  options.set_priority(3);
  auto handler3 = pool->Get(/*step_id=*/3, /*timeout_in_ms=*/0, options);

  // The active requests should be ordered by priorities. Comparing whole
  // vectors (rather than checking the size and then indexing) avoids reading
  // past the end of the list when the size expectation does not hold.
  std::vector<int64_t> sorted_active_list =
      pool->GetActiveHandlerPrioritiesForTesting();
  const std::vector<int64_t> expected_initial = {3, 2, 1};
  EXPECT_EQ(sorted_active_list, expected_initial);

  // Release the priority-2 handler, then add priorities 5 and 4.
  handler1.reset();
  options.set_priority(5);
  auto handler4 = pool->Get(/*step_id=*/4, /*timeout_in_ms=*/0, options);
  options.set_priority(4);
  auto handler5 = pool->Get(/*step_id=*/5, /*timeout_in_ms=*/0, options);
  sorted_active_list = pool->GetActiveHandlerPrioritiesForTesting();
  const std::vector<int64_t> expected_after_update = {5, 4, 3, 1};
  EXPECT_EQ(sorted_active_list, expected_after_update);
}
115 
// Checks that work added to a ThreadWorkSource through the pool is counted in
// the matching (blocking / non-blocking) queue and is popped in the order it
// was added.
TEST(RunHandlerThreadPool, EnqueueTask) {
  Eigen::MaxSizeVector<mutex> waiters_mu(2);
  waiters_mu.resize(2);
  Eigen::MaxSizeVector<internal::Waiter> waiters(2);
  waiters.resize(2);
  // No worker threads are created, so queued tasks only run when the test
  // pops and invokes them itself.
  internal::RunHandlerThreadPool pool_under_test(
      /*num_blocking_threads=*/0, /*num_non_blocking_threads=*/0,
      Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
      &waiters);
  internal::ThreadWorkSource work_source;

  int last_value = 0;
  std::function<void()> set_one = [&last_value] { last_value = 1; };
  std::function<void()> set_two = [&last_value] { last_value = 2; };

  // Pops the next task from the requested queue and runs it.
  const auto run_next = [&work_source](bool is_blocking) {
    if (is_blocking) {
      work_source.PopBlockingTask().f->f();
    } else {
      work_source.PopNonBlockingTask(0, true).f->f();
    }
  };

  // Same scenario for the blocking queue first, then the non-blocking queue.
  for (const bool is_blocking : {true, false}) {
    pool_under_test.AddWorkToQueue(&work_source, is_blocking, set_one);
    EXPECT_EQ(work_source.TaskQueueSize(is_blocking), 1);
    pool_under_test.AddWorkToQueue(&work_source, is_blocking, set_two);
    EXPECT_EQ(work_source.TaskQueueSize(is_blocking), 2);

    run_next(is_blocking);
    EXPECT_EQ(last_value, 1);
    run_next(is_blocking);
    EXPECT_EQ(last_value, 2);
  }
}
148 
TEST(RunHandlerThreadPool,FindTask)149 TEST(RunHandlerThreadPool, FindTask) {
150   Eigen::MaxSizeVector<mutex> waiters_mu(2);
151   waiters_mu.resize(2);
152   Eigen::MaxSizeVector<internal::Waiter> waiters(2);
153   waiters.resize(2);
154   internal::RunHandlerThreadPool run_handler_thread_pool(
155       /*num_blocking_threads=*/1, /*num_non_blocking_threads=*/0,
156       Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
157       &waiters);
158 
159   Eigen::MaxSizeVector<internal::ThreadWorkSource*> thread_work_sources(5);
160   thread_work_sources.resize(5);
161   for (int i = 0; i < 5; ++i) {
162     thread_work_sources[i] = new internal::ThreadWorkSource();
163   }
164 
165   {
166     // The thread should search the task following round robin fashion.
167     int result = -1;
168     run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
169                                            /*is_blocking=*/true,
170                                            [&result] { result = 2; });
171     run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
172                                            /*is_blocking=*/true,
173                                            [&result] { result = 2; });
174     run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
175                                            /*is_blocking=*/true,
176                                            [&result] { result = 3; });
177     run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
178                                            /*is_blocking=*/true,
179                                            [&result] { result = 3; });
180 
181     const auto find_blocking_task_from_all_handlers =
182         [&](bool* task_from_blocking_queue, internal::Task* t) {
183           internal::ThreadWorkSource* tws;
184           *t = run_handler_thread_pool.FindTask(
185               /*searching_range_start=*/0, /*searching_range_end=*/5,
186               /*thread_id=*/0,
187               /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
188               /*may_steal_blocking_work=*/true, thread_work_sources,
189               task_from_blocking_queue, &tws);
190         };
191     bool task_from_blocking_queue;
192     internal::Task t;
193     find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
194     EXPECT_EQ(task_from_blocking_queue, true);
195     t.f->f();
196     EXPECT_EQ(result, 2);
197 
198     find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
199     EXPECT_EQ(task_from_blocking_queue, true);
200     t.f->f();
201     EXPECT_EQ(result, 3);
202 
203     find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
204     EXPECT_EQ(task_from_blocking_queue, true);
205     t.f->f();
206     EXPECT_EQ(result, 2);
207 
208     find_blocking_task_from_all_handlers(&task_from_blocking_queue, &t);
209     EXPECT_EQ(task_from_blocking_queue, true);
210     t.f->f();
211     EXPECT_EQ(result, 3);
212   }
213 
214   {
215     // Task out of searching range cannot be found.
216     int result = -1;
217     run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
218                                            /*is_blocking=*/true,
219                                            [&result] { result = 3; });
220 
221     const auto find_blocking_task_from_range =
222         [&](bool* task_from_blocking_queue, internal::Task* t, int range_start,
223             int range_end) {
224           internal::ThreadWorkSource* tws;
225           *t = run_handler_thread_pool.FindTask(
226               range_start, range_end,
227               /*thread_id=*/0,
228               /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
229               /*may_steal_blocking_work=*/true, thread_work_sources,
230               task_from_blocking_queue, &tws);
231         };
232 
233     bool task_from_blocking_queue;
234     internal::Task t;
235     find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 3);
236     EXPECT_EQ(t.f, nullptr);
237 
238     // Clean up the queue.
239     find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
240   }
241 
242   {
243     // The thread should search from start range if the current index is
244     // smaller.
245     int result = -1;
246     run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
247                                            /*is_blocking=*/true,
248                                            [&result] { result = 2; });
249     run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
250                                            /*is_blocking=*/true,
251                                            [&result] { result = 3; });
252 
253     const auto find_blocking_task_from_range =
254         [&](bool* task_from_blocking_queue, internal::Task* t, int range_start,
255             int range_end) {
256           internal::ThreadWorkSource* tws;
257           *t = run_handler_thread_pool.FindTask(
258               range_start, range_end,
259               /*thread_id=*/0,
260               /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
261               /*may_steal_blocking_work=*/true, thread_work_sources,
262               task_from_blocking_queue, &tws);
263         };
264     bool task_from_blocking_queue;
265     internal::Task t;
266     find_blocking_task_from_range(&task_from_blocking_queue, &t, 3, 5);
267     EXPECT_EQ(task_from_blocking_queue, true);
268     t.f->f();
269     EXPECT_EQ(result, 3);
270 
271     find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
272     EXPECT_EQ(task_from_blocking_queue, true);
273     t.f->f();
274     EXPECT_EQ(result, 2);
275   }
276 
277   {
278     // The thread should search within the range even if the current index
279     // is larger than searching_range_end;
280     int result = -1;
281     run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
282                                            /*is_blocking=*/true,
283                                            [&result] { result = 2; });
284 
285     const auto find_blocking_task_from_range =
286         [&](bool* task_from_blocking_queue, internal::Task* t, int range_start,
287             int range_end) {
288           internal::ThreadWorkSource* tws;
289           *t = run_handler_thread_pool.FindTask(
290               range_start, range_end,
291               /*thread_id=*/0,
292               /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
293               /*may_steal_blocking_work=*/true, thread_work_sources,
294               task_from_blocking_queue, &tws);
295         };
296     bool task_from_blocking_queue;
297     // Make the current index to be 3.
298     internal::Task t;
299     find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
300     EXPECT_EQ(task_from_blocking_queue, true);
301     t.f->f();
302     EXPECT_EQ(result, 2);
303 
304     // Search in a smaller range.
305     run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
306                                            /*is_blocking=*/true,
307                                            [&result] { result = 2; });
308     run_handler_thread_pool.AddWorkToQueue(thread_work_sources[3],
309                                            /*is_blocking=*/true,
310                                            [&result] { result = 3; });
311     find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 3);
312     EXPECT_EQ(task_from_blocking_queue, true);
313     t.f->f();
314     EXPECT_EQ(result, 2);
315 
316     // Clean up the queue.
317     find_blocking_task_from_range(&task_from_blocking_queue, &t, 0, 5);
318     EXPECT_EQ(task_from_blocking_queue, true);
319     t.f->f();
320     EXPECT_EQ(result, 3);
321   }
322 
323   {
324     // We prefer blocking task for blocking threads.
325     int result = -1;
326     run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
327                                            /*is_blocking=*/false,
328                                            [&result] { result = 2; });
329     run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
330                                            /*is_blocking=*/true,
331                                            [&result] { result = 2; });
332     const auto blocking_thread_find_task_from_all_handler =
333         [&](bool* task_from_blocking_queue, internal::Task* t) {
334           internal::ThreadWorkSource* tws;
335           *t = run_handler_thread_pool.FindTask(
336               /*searching_range_start=*/0, /*searching_range_end=*/5,
337               /*thread_id=*/0,
338               /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
339               /*may_steal_blocking_work=*/true, thread_work_sources,
340               task_from_blocking_queue, &tws);
341         };
342     bool task_from_blocking_queue;
343     internal::Task t;
344     blocking_thread_find_task_from_all_handler(&task_from_blocking_queue, &t);
345     EXPECT_EQ(task_from_blocking_queue, true);
346     t.f->f();
347     EXPECT_EQ(result, 2);
348 
349     blocking_thread_find_task_from_all_handler(&task_from_blocking_queue, &t);
350     EXPECT_EQ(task_from_blocking_queue, false);
351     t.f->f();
352     EXPECT_EQ(result, 2);
353   }
354 
355   {
356     // Nonblocking threads can only pick up non-blocking task.
357     int result = -1;
358     run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
359                                            /*is_blocking=*/false,
360                                            [&result] { result = 2; });
361     run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
362                                            /*is_blocking=*/true,
363                                            [&result] { result = 2; });
364 
365     const auto find_task_from_all_handler = [&](bool* task_from_blocking_queue,
366                                                 internal::Task* t,
367                                                 bool is_blocking_thread) {
368       internal::ThreadWorkSource* tws;
369       *t = run_handler_thread_pool.FindTask(
370           /*searching_range_start=*/0, /*searching_range_end=*/5,
371           /*thread_id=*/0,
372           /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
373           is_blocking_thread, thread_work_sources, task_from_blocking_queue,
374           &tws);
375     };
376     bool task_from_blocking_queue;
377     internal::Task t;
378     find_task_from_all_handler(&task_from_blocking_queue, &t,
379                                /*is_blocking_thread=*/false);
380     EXPECT_EQ(task_from_blocking_queue, false);
381     t.f->f();
382     EXPECT_EQ(result, 2);
383 
384     find_task_from_all_handler(&task_from_blocking_queue, &t,
385                                /*is_blocking_thread=*/false);
386     EXPECT_EQ(t.f, nullptr);
387 
388     // Clean up the queue.
389     find_task_from_all_handler(&task_from_blocking_queue, &t,
390                                /*is_blocking_thread=*/true);
391   }
392 
393   {
394     // There is a limit for max_blocking_inflight requests.
395     int result = -1;
396     run_handler_thread_pool.AddWorkToQueue(thread_work_sources[2],
397                                            /*is_blocking=*/true,
398                                            [&result] { result = 2; });
399 
400     const auto find_task_from_all_handler = [&](bool* task_from_blocking_queue,
401                                                 internal::Task* t,
402                                                 bool is_blocking_thread) {
403       internal::ThreadWorkSource* tws;
404       *t = run_handler_thread_pool.FindTask(
405           /*searching_range_start=*/0, /*searching_range_end=*/5,
406           /*thread_id=*/0,
407           /*sub_thread_pool_id=*/0, /*max_blocking_inflight=*/10,
408           is_blocking_thread, thread_work_sources, task_from_blocking_queue,
409           &tws);
410     };
411 
412     bool task_from_blocking_queue;
413     internal::Task t;
414     find_task_from_all_handler(&task_from_blocking_queue, &t,
415                                /*is_blocking_thread=*/false);
416     EXPECT_EQ(task_from_blocking_queue, false);
417     EXPECT_EQ(t.f, nullptr);
418 
419     // Clean up the queue.
420     find_task_from_all_handler(&task_from_blocking_queue, &t,
421                                /*is_blocking_thread=*/true);
422   }
423 
424   for (int i = 0; i < 5; ++i) {
425     delete thread_work_sources[i];
426   }
427 }
428 
TEST(RunHandlerThreadPool,RoundRobinExecution)429 TEST(RunHandlerThreadPool, RoundRobinExecution) {
430   // Set up environment for 1 sub thread pool.
431   setenv("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", "true", true);
432   setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "1", true);
433   setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE", "0", true);
434   setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE", "1", true);
435 
436   Eigen::MaxSizeVector<mutex> waiters_mu(1);
437   waiters_mu.resize(1);
438   Eigen::MaxSizeVector<internal::Waiter> waiters(1);
439   waiters.resize(1);
440   internal::RunHandlerThreadPool* run_handler_thread_pool =
441       new internal::RunHandlerThreadPool(
442           /*num_blocking_threads=*/1, /*num_non_blocking_threads=*/0,
443           Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
444           &waiters);
445   Eigen::MaxSizeVector<internal::ThreadWorkSource*> thread_work_sources(3);
446   thread_work_sources.resize(3);
447   internal::ThreadWorkSource tws[3];
448   for (int i = 0; i < 3; ++i) {
449     tws[i].SetWaiter(1, &waiters[0], &waiters_mu[0]);
450     thread_work_sources[i] = &tws[i];
451   }
452 
453   int result = 0;
454   mutex mu;
455   bool ok_to_execute = false;
456   bool ok_to_validate = false;
457   condition_variable function_start;
458   condition_variable function_end;
459   std::vector<std::function<void()>> fns;
460   for (int i = 0; i < 3; ++i) {
461     fns.push_back([&result, &mu, &function_start, &function_end, &ok_to_execute,
462                    &ok_to_validate, i] {
463       mutex_lock l(mu);
464       while (!ok_to_execute) {
465         function_start.wait(l);
466       }
467       result = i;
468       ok_to_execute = false;
469       ok_to_validate = true;
470       function_end.notify_one();
471     });
472     run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
473                                             fns[i]);
474     run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
475                                             fns[i]);
476   }
477   run_handler_thread_pool->Start();
478   run_handler_thread_pool->SetThreadWorkSources(
479       /*tid=*/0, /*start_request_idx=*/0, /*version=*/1, thread_work_sources);
480 
481   // Validate the execution should be roundrobin.
482   mutex_lock l(mu);
483   for (int round = 0; round < 2; ++round) {
484     for (int i = 0; i < 3; ++i) {
485       ok_to_execute = true;
486       function_start.notify_one();
487       while (!ok_to_validate) {
488         function_end.wait(l);
489       }
490       ok_to_validate = false;
491       EXPECT_EQ(result, i);
492     }
493   }
494 
495   delete run_handler_thread_pool;
496 }
497 
TEST(RunHandlerThreadPool,MultipleSubThreadPool)498 TEST(RunHandlerThreadPool, MultipleSubThreadPool) {
499   // Set up environment for 2 sub thread pools.
500   setenv("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", "true", true);
501   setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "2", true);
502   setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE", "0,0.5",
503          true);
504   setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE", "0.5,1",
505          true);
506 
507   Eigen::MaxSizeVector<mutex> waiters_mu(2);
508   waiters_mu.resize(2);
509   Eigen::MaxSizeVector<internal::Waiter> waiters(2);
510   waiters.resize(2);
511   internal::RunHandlerThreadPool* run_handler_thread_pool =
512       new internal::RunHandlerThreadPool(
513           /*num_blocking_threads=*/2, /*num_non_blocking_threads=*/0,
514           Env::Default(), ThreadOptions(), "tf_run_handler_pool", &waiters_mu,
515           &waiters);
516   Eigen::MaxSizeVector<internal::ThreadWorkSource*> thread_work_sources(4);
517   thread_work_sources.resize(4);
518   internal::ThreadWorkSource tws[4];
519   for (int i = 0; i < 4; ++i) {
520     tws[i].SetWaiter(1, &waiters[i / 2], &waiters_mu[i / 2]);
521     thread_work_sources[i] = &tws[i];
522   }
523 
524   int result = 0;
525   mutex mu;
526   bool ok_to_execute = false;
527   bool ok_to_validate = false;
528   condition_variable function_start;
529   condition_variable function_end;
530 
531   std::vector<std::function<void()>> fns;
532   for (int i = 0; i < 4; ++i) {
533     fns.push_back([&result, &mu, &function_start, &function_end, &ok_to_execute,
534                    &ok_to_validate, i] {
535       mutex_lock l(mu);
536       while (!ok_to_execute) {
537         function_start.wait(l);
538       }
539       result = i;
540       ok_to_execute = false;
541       ok_to_validate = true;
542       function_end.notify_one();
543     });
544     run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
545                                             fns[i]);
546     run_handler_thread_pool->AddWorkToQueue(&tws[i], /*is_blocking=*/true,
547                                             fns[i]);
548   }
549   run_handler_thread_pool->StartOneThreadForTesting();
550   run_handler_thread_pool->SetThreadWorkSources(
551       /*tid=*/0, /*start_request_idx=*/0, /*version=*/1, thread_work_sources);
552   run_handler_thread_pool->SetThreadWorkSources(
553       /*tid=*/1, /*start_request_idx=*/0, /*version=*/1, thread_work_sources);
554 
555   // Pick task from the given sub thread pool requests in a round robin fashion.
556   mutex_lock l(mu);
557   for (int round = 0; round < 2; ++round) {
558     for (int i = 0; i < 2; ++i) {
559       ok_to_execute = true;
560       function_start.notify_one();
561       while (!ok_to_validate) {
562         function_end.wait(l);
563       }
564       ok_to_validate = false;
565       EXPECT_EQ(result, i);
566     }
567   }
568 
569   // Pick task from any task if there is no tasks from the requests in the sub
570   // thread pool.
571   for (int i = 0; i < 2; ++i) {
572     for (int round = 0; round < 2; ++round) {
573       ok_to_execute = true;
574       function_start.notify_one();
575       while (!ok_to_validate) {
576         function_end.wait(l);
577       }
578       ok_to_validate = false;
579       EXPECT_EQ(result, i + 2);
580     }
581   }
582 
583   delete run_handler_thread_pool;
584 }
585 
DefaultSessionOptions()586 SessionOptions DefaultSessionOptions() {
587   SessionOptions options;
588   (*options.config.mutable_device_count())["CPU"] = 2;
589   return options;
590 }
591 
CreateSession()592 std::unique_ptr<Session> CreateSession() {
593   return std::unique_ptr<Session>(NewSession(DefaultSessionOptions()));
594 }
595 
596 class RunHandlerTest : public ::testing::Test {
597  public:
Initialize(std::initializer_list<float> a_values)598   void Initialize(std::initializer_list<float> a_values) {
599     Graph graph(OpRegistry::Global());
600 
601     Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
602     test::FillValues<float>(&a_tensor, a_values);
603     Node* a = test::graph::Constant(&graph, a_tensor);
604     a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
605     a_ = a->name();
606 
607     Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
608     test::FillValues<float>(&x_tensor, {1, 1});
609     Node* x = test::graph::Constant(&graph, x_tensor);
610     x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
611     x_ = x->name();
612 
613     // y = A * x
614     Node* y = test::graph::Matmul(&graph, a, x, false, false);
615     y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
616     y_ = y->name();
617 
618     Node* y_neg = test::graph::Unary(&graph, "Neg", y);
619     y_neg_ = y_neg->name();
620     y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
621 
622     Node* z = test::graph::Unary(&graph, "Identity", y_neg);
623     z_ = z->name();
624     z->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
625 
626     graph.ToGraphDef(&def_);
627 
628     ASSERT_EQ(setenv("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", "2", true), 0);
629     ASSERT_EQ(
630         setenv("TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL", "8,8", true),
631         0);
632     ASSERT_EQ(setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
633                      "0,0.4", true),
634               0);
635     ASSERT_EQ(setenv("TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
636                      "0.4,1", true),
637               0);
638     ASSERT_EQ(setenv("TF_NUM_INTEROP_THREADS", "16", true), 0);
639   }
640 
641   string a_;
642   string x_;
643   string y_;
644   string y_neg_;
645   string z_;
646   GraphDef def_;
647 };
648 
// Runs the fixture graph once with the run-handler pool enabled and checks the
// fetched value of y = A * x.
TEST_F(RunHandlerTest, UseRunHandlerPoolEnableSubPool) {
  // Initialize() uses fatal assertions; propagate them to this test body.
  ASSERT_NO_FATAL_FAILURE(Initialize({3, 2, -1, 0}));
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  ASSERT_EQ(OkStatus(), session->Create(def_));
  std::vector<std::pair<string, Tensor>> inputs;

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;

  // Prepare RunOptions; no RunMetadata is collected (nullptr is passed).
  RunOptions run_options;
  run_options.mutable_experimental()->set_use_run_handler_pool(true);

  Status s = session->Run(run_options, inputs, output_names, target_nodes,
                          &outputs, nullptr);
  EXPECT_EQ(OkStatus(), s);

  ASSERT_EQ(1, outputs.size());
  // The first output should be initialized and have the correct output.
  // Check initialization before viewing the tensor as a matrix.
  ASSERT_TRUE(outputs[0].IsInitialized());
  auto mat = outputs[0].matrix<float>();
  // A = [[3, 2], [-1, 0]], x = [1, 1]^T, so y(0, 0) = 3 + 2.
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));
}
676 
// Runs the fixture graph 1000 times from each of 4 threads concurrently with
// the run-handler pool enabled and checks every result.
TEST_F(RunHandlerTest, TestConcurrencyUseRunHandlerPool) {
  // Initialize() uses fatal assertions; propagate them to this test body.
  ASSERT_NO_FATAL_FAILURE(Initialize({1, 2, 3, 4}));
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  EXPECT_EQ(OkStatus(), session->Create(def_));

  RunOptions run_options;
  run_options.mutable_experimental()->set_use_run_handler_pool(true);

  // Run the graph 1000 times in 4 different threads concurrently.
  std::vector<string> output_names = {y_ + ":0"};
  auto fn = [&session, output_names, run_options]() {
    for (int i = 0; i < 1000; ++i) {
      std::vector<std::pair<string, Tensor>> inputs;
      std::vector<Tensor> outputs;
      // Run the graph
      Status s = session->Run(run_options, inputs, output_names, {}, &outputs,
                              nullptr);
      EXPECT_EQ(OkStatus(), s);
      // A fatal failure here only ends this worker's loop.
      ASSERT_EQ(1, outputs.size());
      auto mat = outputs[0].matrix<float>();
      // A = [[1, 2], [3, 4]], x = [1, 1]^T, so y(0, 0) = 1 + 2.
      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
    }
  };

  {
    // The pool lives in its own scope: leaving the scope destroys it, which
    // waits for the scheduled functions to finish while `session` is alive.
    thread::ThreadPool workers(Env::Default(), "test", 4);
    for (int i = 0; i < 4; ++i) {
      workers.Schedule(fn);
    }
  }
}
712 
// Same as UseRunHandlerPoolEnableSubPool, but the request carries an explicit
// run-handler priority of 1.
TEST_F(RunHandlerTest, UseRunHandlerPoolEnableSubPoolWithPriority) {
  // Initialize() uses fatal assertions; propagate them to this test body.
  ASSERT_NO_FATAL_FAILURE(Initialize({3, 2, -1, 0}));
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  ASSERT_EQ(OkStatus(), session->Create(def_));
  std::vector<std::pair<string, Tensor>> inputs;

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;

  // Prepare RunOptions; no RunMetadata is collected (nullptr is passed).
  RunOptions run_options;
  run_options.mutable_experimental()->set_use_run_handler_pool(true);
  run_options.mutable_experimental()
      ->mutable_run_handler_pool_options()
      ->set_priority(1);

  Status s = session->Run(run_options, inputs, output_names, target_nodes,
                          &outputs, nullptr);
  EXPECT_EQ(OkStatus(), s);

  ASSERT_EQ(1, outputs.size());
  // The first output should be initialized and have the correct output.
  // Check initialization before viewing the tensor as a matrix.
  ASSERT_TRUE(outputs[0].IsInitialized());
  auto mat = outputs[0].matrix<float>();
  // A = [[3, 2], [-1, 0]], x = [1, 1]^T, so y(0, 0) = 3 + 2.
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));
}
743 
// Runs the fixture graph 1000 times from each of 4 threads concurrently with
// the run-handler pool enabled, cycling the request priority through 0..3.
TEST_F(RunHandlerTest, TestConcurrencyUseRunHandlerPoolWithPriority) {
  // Initialize() uses fatal assertions; propagate them to this test body.
  ASSERT_NO_FATAL_FAILURE(Initialize({1, 2, 3, 4}));
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  EXPECT_EQ(OkStatus(), session->Create(def_));

  // Run the graph 1000 times in 4 different threads concurrently.
  std::vector<string> output_names = {y_ + ":0"};
  auto fn = [&session, output_names]() {
    for (int i = 0; i < 1000; ++i) {
      RunOptions run_options;
      run_options.mutable_experimental()->set_use_run_handler_pool(true);
      run_options.mutable_experimental()
          ->mutable_run_handler_pool_options()
          ->set_priority(i % 4);
      std::vector<std::pair<string, Tensor>> inputs;
      std::vector<Tensor> outputs;
      // Run the graph
      Status s = session->Run(run_options, inputs, output_names, {}, &outputs,
                              nullptr);
      EXPECT_EQ(OkStatus(), s);
      // A fatal failure here only ends this worker's loop.
      ASSERT_EQ(1, outputs.size());
      auto mat = outputs[0].matrix<float>();
      // A = [[1, 2], [3, 4]], x = [1, 1]^T, so y(0, 0) = 1 + 2.
      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
    }
  };

  {
    // The pool lives in its own scope: leaving the scope destroys it, which
    // waits for the scheduled functions to finish while `session` is alive.
    thread::ThreadPool workers(Env::Default(), "test", 4);
    for (int i = 0; i < 4; ++i) {
      workers.Schedule(fn);
    }
  }
}
781 
// Checks RunHandlerPool::Get() once every handler is taken: a request with a
// non-zero timeout returns nullptr, and a request with no timeout returns a
// handler only after one of the held handlers has been released.
TEST_F(RunHandlerTest, TestWaitTimeout) {
  std::unique_ptr<RunHandlerPool> pool(new RunHandlerPool(1, 1));

  // Take every handler the pool can hand out concurrently.
  std::vector<std::unique_ptr<RunHandler>> blocking_handles;
  const int32_t kMaxConcurrentHandlers = 128;  // Copied from run_handler.cc.
  blocking_handles.reserve(kMaxConcurrentHandlers);
  for (int i = 0; i < kMaxConcurrentHandlers; ++i) {
    blocking_handles.push_back(pool->Get(i));
  }

  // A subsequent request with a non-zero timeout will fail by returning
  // nullptr.
  auto null_handle = pool->Get(/*step_id=*/kMaxConcurrentHandlers,
                               /*timeout_in_ms=*/1);
  EXPECT_EQ(null_handle.get(), nullptr);

  // A subsequent request with no timeout will succeed once the blocking handle
  // is returned.
  // Declared before the thread pool (so it outlives the scheduled task) and
  // initialized to 0, meaning "no release recorded yet".
  std::atomic<int64_t> release_time(0);
  auto tp = std::make_unique<thread::ThreadPool>(Env::Default(), "test", 4);

  tp->Schedule([&blocking_handles, &release_time]() {
    Env::Default()->SleepForMicroseconds(5000);
    // Record the time before releasing, so that the waiter always observes a
    // timestamp taken no later than the release itself.
    release_time = EnvTime::NowNanos();
    blocking_handles[0].reset();
  });

  auto next_handle = pool->Get(/*step_id=*/kMaxConcurrentHandlers + 1,
                               /*timeout_in_ms=*/0);
  const int64_t released_at = release_time.load();
  const int64_t acquired_at = static_cast<int64_t>(EnvTime::NowNanos());
  // The handler must not have been handed out before the release happened.
  EXPECT_GT(released_at, 0);
  EXPECT_GT(acquired_at, released_at);
  EXPECT_NE(next_handle.get(), nullptr);
}
813 
814 }  // namespace
815 }  // namespace tensorflow
816