/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/run_handler.h"

#include <algorithm>
#include <cmath>
#include <list>
#include <memory>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/run_handler_util.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/numa.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace {
// LINT.IfChange
static constexpr int32_t kMaxConcurrentHandlers = 128;
// LINT.ThenChange(//tensorflow/core/framework/run_handler_test.cc)

typedef typename internal::RunHandlerEnvironment::Task Task;
typedef Eigen::RunQueue<Task, 1024> Queue;

}  // namespace

namespace internal {
RunHandlerEnvironment::RunHandlerEnvironment(
    Env* env, const ThreadOptions& thread_options, const string& name)
    : env_(env), thread_options_(thread_options), name_(name) {}

RunHandlerEnvironment::EnvThread* RunHandlerEnvironment::CreateThread(
    std::function<void()> f, const std::string& thread_name) {
  return env_->StartThread(thread_options_, thread_name, [=]() {
    // Set the processor flag to flush denormals to zero.
    port::ScopedFlushDenormal flush;
    // Set the processor rounding mode to ROUND TO NEAREST.
    port::ScopedSetRound round(FE_TONEAREST);
    if (thread_options_.numa_node != port::kNUMANoAffinity) {
      port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
    }
    f();
  });
}

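// Wraps the closure into a Task. When the event collector is enabled, a
// unique trace id is recorded at schedule time so that ExecuteTask() can
// attribute the closure's execution back to where it was scheduled.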
RunHandlerEnvironment::Task RunHandlerEnvironment::CreateTask(
    std::function<void()> f) {
  uint64 id = 0;
  if (tracing::EventCollector::IsEnabled()) {
    id = tracing::GetUniqueArg();
    tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
  }
  return Task{
      std::unique_ptr<TaskImpl>(new TaskImpl{
          std::move(f),
          Context(ContextKind::kThread),
          id,
      }),
  };
}

void RunHandlerEnvironment::ExecuteTask(const Task& t) {
  WithContext wc(t.f->context);
  tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
                               t.f->trace_id);
  t.f->f();
}

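// Parks the calling thread on `waiter` for at most `max_sleep_micros`
// microseconds. The waiter is linked into the doubly-linked LIFO list rooted
// at `queue_head` (guarded by `mutex`), the thread then blocks on the
// waiter's own condition variable, and on wake-up the waiter is unlinked
// again unless a notifier has already removed it.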
void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex,
                  int max_sleep_micros) {
  {
    mutex_lock l(*mutex);
    CHECK_EQ(waiter->next, waiter);  // Crash OK.
    CHECK_EQ(waiter->prev, waiter);  // Crash OK.

    // Add waiter to the LIFO queue
    waiter->prev = queue_head;
    waiter->next = queue_head->next;
    waiter->next->prev = waiter;
    waiter->prev->next = waiter;
  }
  {
    mutex_lock l(waiter->mu);
    // Wait on the condition variable
    waiter->cv.wait_for(l, std::chrono::microseconds(max_sleep_micros));
  }

  mutex_lock l(*mutex);
  // Remove waiter from the LIFO queue. Note that even when a waiter wakes up
  // due to a notification, we cannot conclude that it is no longer in the
  // queue: a thread preempted right before notifying may resume only after
  // the waiter has been re-added.
  if (waiter->next != waiter) {
    CHECK(waiter->prev != waiter);  // Crash OK.
    waiter->next->prev = waiter->prev;
    waiter->prev->next = waiter->next;
    waiter->next = waiter;
    waiter->prev = waiter;
  } else {
    CHECK_EQ(waiter->prev, waiter);  // Crash OK.
  }
}

ThreadWorkSource::ThreadWorkSource()
    : non_blocking_work_sharding_factor_(
          static_cast<int32>(ParamFromEnvWithDefault(
              "TF_RUN_HANDLER_NUM_OF_NON_BLOCKING_QUEUES", 1))),
      non_blocking_work_queues_(non_blocking_work_sharding_factor_),
      blocking_inflight_(0),
      non_blocking_inflight_(0),
      traceme_id_(0),
      version_(0),
      sub_thread_pool_waiter_(nullptr) {
  queue_waiters_.next = &queue_waiters_;
  queue_waiters_.prev = &queue_waiters_;
  for (int i = 0; i < NonBlockingWorkShardingFactor(); ++i) {
    non_blocking_work_queues_.emplace_back(new NonBlockingQueue());
  }
}

ThreadWorkSource::~ThreadWorkSource() {
  for (int i = 0; i < non_blocking_work_queues_.size(); ++i) {
    delete non_blocking_work_queues_[i];
  }
}

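// Pushes `t` onto the blocking queue or one of the sharded non-blocking
// queues and wakes up at most one waiting thread. If the queue is full, the
// displaced task is returned to the caller, which executes it inline (see
// RunHandlerThreadPool::AddWorkToQueue).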
Task ThreadWorkSource::EnqueueTask(Task t, bool is_blocking) {
  mutex* mu = nullptr;
  Queue* task_queue = nullptr;
  thread_local int64_t closure_counter = 0;

  if (!is_blocking) {
    int queue_index = ++closure_counter % non_blocking_work_sharding_factor_;
    task_queue = &(non_blocking_work_queues_[queue_index]->queue);
    mu = &non_blocking_work_queues_[queue_index]->queue_op_mu;
  } else {
    task_queue = &blocking_work_queue_;
    mu = &blocking_queue_op_mu_;
  }

  {
    mutex_lock l(*mu);
    // For a given queue, only one thread can call PushFront.
    t = task_queue->PushFront(std::move(t));
  }

  Waiter* w = nullptr;
  static const bool use_sub_thread_pool =
      ParamFromEnvBoolWithDefault("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false);

  Waiter* waiter_queue;
  mutex* waiter_queue_mu;
  if (use_sub_thread_pool) {
    // When multiple sub thread pools are in use, free threads wait on the
    // sub thread pool waiting queues, so wake up threads from those queues.
    // The waiting queues are defined in RunHandlerPool.
    // Get the waiter_queue and the corresponding mutex. Note that the thread
    // work source may change afterwards if a new request arrives or an old
    // request finishes.
    tf_shared_lock lock(run_handler_waiter_mu_);
    waiter_queue = sub_thread_pool_waiter_;
    waiter_queue_mu = sub_thread_pool_waiter_mu_;
  } else {
    waiter_queue = &queue_waiters_;
    waiter_queue_mu = &waiters_mu_;
  }
  {
    mutex_lock l(*waiter_queue_mu);
    if (waiter_queue->next != waiter_queue) {
      // Remove waiter from the LIFO queue
      w = waiter_queue->next;

      CHECK(w->prev != w);  // Crash OK.
      CHECK(w->next != w);  // Crash OK.

      w->next->prev = w->prev;
      w->prev->next = w->next;

      // Use `w->next == w` to indicate that the waiter has been removed
      // from the queue.
      w->next = w;
      w->prev = w;
    }
  }
  if (w != nullptr) {
    // We call notify_one() without holding any lock, so notifications can be
    // missed. The wake-up logic is best effort: a waiting thread will wake up
    // on its own within a short period of time if a notification is missed.
    w->cv.notify_one();
  }
  VLOG(3) << "Added " << (is_blocking ? "inter" : "intra") << " work from "
          << traceme_id_.load(std::memory_order_relaxed);
  return t;
}

Task ThreadWorkSource::PopBlockingTask() {
  return blocking_work_queue_.PopBack();
}

Task ThreadWorkSource::PopNonBlockingTask(int start_index,
                                          bool search_from_all_queue) {
  Task t;
  unsigned sharding_factor = NonBlockingWorkShardingFactor();
  for (unsigned j = 0; j < sharding_factor; ++j) {
    t = non_blocking_work_queues_[(start_index + j) % sharding_factor]
            ->queue.PopBack();
    if (t.f) {
      return t;
    }
    if (!search_from_all_queue) {
      break;
    }
  }
  return t;
}

void ThreadWorkSource::WaitForWork(int max_sleep_micros) {
  thread_local Waiter waiter;
  WaitOnWaiter(&waiter, &queue_waiters_, &waiters_mu_, max_sleep_micros);
}

int ThreadWorkSource::TaskQueueSize(bool is_blocking) {
  if (is_blocking) {
    return blocking_work_queue_.Size();
  } else {
    unsigned total_size = 0;
    for (int i = 0; i < non_blocking_work_sharding_factor_; ++i) {
      total_size += non_blocking_work_queues_[i]->queue.Size();
    }
    return total_size;
  }
}

int64_t ThreadWorkSource::GetTracemeId() {
  return traceme_id_.load(std::memory_order_relaxed);
}

void ThreadWorkSource::SetTracemeId(int64_t value) { traceme_id_ = value; }

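// Points this work source at the waiter list (and its mutex) of a sub thread
// pool. The check is done under a shared lock first so that the common cases
// (waiter unchanged, or a stale version) avoid taking the exclusive lock.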
void ThreadWorkSource::SetWaiter(uint64 version, Waiter* waiter, mutex* mutex) {
  {
    tf_shared_lock lock(run_handler_waiter_mu_);
    // Most requests do not change sub thread pool on recomputation. Checking
    // under a shared lock avoids taking the exclusive lock and reduces
    // contention.
    if (sub_thread_pool_waiter_ == waiter) {
      return;
    }
    // If the current version is newer, there is no need to update.
    if (version_ > version) {
      return;
    }
  }

  mutex_lock l(run_handler_waiter_mu_);
  sub_thread_pool_waiter_ = waiter;
  sub_thread_pool_waiter_mu_ = mutex;
  version_ = version;
}

int64_t ThreadWorkSource::GetInflightTaskCount(bool is_blocking) {
  std::atomic<int64_t>* counter =
      is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
  return counter->load(std::memory_order_relaxed);
}

void ThreadWorkSource::IncrementInflightTaskCount(bool is_blocking) {
  std::atomic<int64_t>* counter =
      is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
  counter->fetch_add(1, std::memory_order_relaxed);
}

void ThreadWorkSource::DecrementInflightTaskCount(bool is_blocking) {
  std::atomic<int64_t>* counter =
      is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
  counter->fetch_sub(1, std::memory_order_relaxed);
}

unsigned ThreadWorkSource::NonBlockingWorkShardingFactor() {
  return non_blocking_work_sharding_factor_;
}

std::string ThreadWorkSource::ToString() {
  return strings::StrCat("traceme_id = ", GetTracemeId(),
                         ", inter queue size = ", TaskQueueSize(true),
                         ", inter inflight = ", GetInflightTaskCount(true),
                         ", intra queue size = ", TaskQueueSize(false),
                         ", intra inflight = ", GetInflightTaskCount(false));
}

RunHandlerThreadPool::RunHandlerThreadPool(
    int num_blocking_threads, int num_non_blocking_threads, Env* env,
    const ThreadOptions& thread_options, const string& name,
    Eigen::MaxSizeVector<mutex>* waiters_mu,
    Eigen::MaxSizeVector<Waiter>* queue_waiters)
    : num_threads_(num_blocking_threads + num_non_blocking_threads),
      num_blocking_threads_(num_blocking_threads),
      num_non_blocking_threads_(num_non_blocking_threads),
      thread_data_(num_threads_),
      env_(env, thread_options, name),
      name_(name),
      waiters_mu_(waiters_mu),
      queue_waiters_(queue_waiters),
      use_sub_thread_pool_(ParamFromEnvBoolWithDefault(
          "TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false)),
      num_threads_in_sub_thread_pool_(ParamFromEnvWithDefault(
          "TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL",
          std::vector<int>({num_blocking_threads / 2,
                            num_blocking_threads - num_blocking_threads / 2}))),
      sub_thread_pool_start_request_percentage_(ParamFromEnvWithDefault(
          "TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
          std::vector<double>({0, 0.4}))),
      sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
          "TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
          std::vector<double>({0.4, 1}))) {
  thread_data_.resize(num_threads_);
  VLOG(1) << "Creating RunHandlerThreadPool " << name << " with "
          << num_blocking_threads_ << " blocking threads and "
          << num_non_blocking_threads_ << " non-blocking threads.";
}

RunHandlerThreadPool::~RunHandlerThreadPool() {
  VLOG(1) << "Exiting RunHandlerThreadPool " << name_;

  cancelled_ = true;
  for (size_t i = 0; i < thread_data_.size(); ++i) {
    {
      mutex_lock l(thread_data_[i].mu);
      thread_data_[i].sources_not_empty.notify_all();
    }
    thread_data_[i].thread.reset();
  }
}

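// Spawns the worker threads. Threads with index < num_blocking_threads_ may
// run blocking (inter-op) as well as non-blocking (intra-op) work; the
// remaining threads run non-blocking work only. Each thread is also assigned
// a sub thread pool id derived from num_threads_in_sub_thread_pool_.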
void RunHandlerThreadPool::Start() {
  cancelled_ = false;
  int num_blocking_threads = num_blocking_threads_;
  for (int i = 0; i < num_threads_; i++) {
    int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1;
    for (int j = 0; j < num_threads_in_sub_thread_pool_.size(); ++j) {
      if (i < num_threads_in_sub_thread_pool_[j]) {
        sub_thread_pool_id = j;
        break;
      }
    }
    thread_data_[i].sub_thread_pool_id = sub_thread_pool_id;
    const bool is_blocking_thread = (i < num_blocking_threads);
    // Blocking threads handle both inter- and intra-op workloads;
    // non-blocking threads handle intra-op workloads only; and the
    // sub thread pool is only provided for blocking threads.
    // Name the threads accordingly.
    thread_data_[i].thread.reset(env_.CreateThread(
        [this, is_blocking_thread, i, sub_thread_pool_id]() {
          WorkerLoop(i, is_blocking_thread);
        },
        is_blocking_thread
            ? strings::StrCat(name_, "_blocking_thread_", sub_thread_pool_id)
            : strings::StrCat(name_, "_non_blocking_thread")));
  }
}

void RunHandlerThreadPool::StartOneThreadForTesting() {
  cancelled_ = false;
  thread_data_[0].sub_thread_pool_id = 0;
  thread_data_[0].thread.reset(
      env_.CreateThread([this]() { WorkerLoop(0, true); }, name_));
}

void RunHandlerThreadPool::AddWorkToQueue(ThreadWorkSource* tws,
                                          bool is_blocking,
                                          std::function<void()> fn) {
  Task t = env_.CreateTask(std::move(fn));
  t = tws->EnqueueTask(std::move(t), is_blocking);
  if (t.f) {
    VLOG(3) << "Running " << (is_blocking ? "inter" : "intra") << " work for "
            << tws->GetTracemeId();
    env_.ExecuteTask(t);
  }
}

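// Publishes the set of active thread work sources to thread `tid`. Updates
// carry a monotonically increasing version so that a stale update arriving
// late is dropped; the worker picks up the newest published set at the top of
// its loop (see WorkerLoop).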
// TODO(donglin) Change the task steal order to be round-robin such that if
// an attempt to steal a task from request i fails, then attempt to steal a
// task from the next request in terms of the arrival time. This approach may
// provide better performance due to less lock retention. The drawback is that
// the profiler will be a bit harder to read.
void RunHandlerThreadPool::SetThreadWorkSources(
    int tid, int start_request_idx, uint64 version,
    const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources) {
  mutex_lock l(thread_data_[tid].mu);
  if (version > thread_data_[tid].new_version) {
    thread_data_[tid].new_version = version;
  } else {
    // A newer version has already been published. No need to update.
    return;
  }
  thread_data_[tid].new_thread_work_sources->resize(0);
  if (use_sub_thread_pool_) {
    for (int i = 0; i < thread_work_sources.size(); ++i) {
      thread_data_[tid].new_thread_work_sources->emplace_back(
          thread_work_sources[i]);
    }
  } else {
    thread_data_[tid].new_thread_work_sources->emplace_back(
        thread_work_sources[start_request_idx]);
    // The number of shards for the queue. Threads in each shard will
    // prioritize different thread_work_sources. Increasing the number of
    // shards can decrease contention in the queue. For example, when
    // num_shards == 1, thread_work_sources are ordered as start_request_idx,
    // 0, 1, 2, 3, 4, ... for all threads. When num_shards == 2,
    // thread_work_sources are ordered as start_request_idx, 0, 2, 4, ..., 1,
    // 3, 5, ... for half of the threads and start_request_idx, 1, 3, 5, ...,
    // 0, 2, 4, ... for the other half of the threads.
    static const int num_shards =
        ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
    int token = tid % num_shards;
    for (int i = 0; i < num_shards; ++i) {
      for (int j = token; j < thread_work_sources.size(); j += num_shards) {
        if (j != start_request_idx) {
          thread_data_[tid].new_thread_work_sources->emplace_back(
              thread_work_sources[j]);
        }
      }
      token = (token + 1) % num_shards;
    }
    thread_data_[tid].sources_not_empty.notify_all();
  }
}

RunHandlerThreadPool::PerThread* RunHandlerThreadPool::GetPerThread() {
  thread_local RunHandlerThreadPool::PerThread per_thread_;
  RunHandlerThreadPool::PerThread* pt = &per_thread_;
  return pt;
}

int RunHandlerThreadPool::CurrentThreadId() const {
  const PerThread* pt = const_cast<RunHandlerThreadPool*>(this)->GetPerThread();
  if (pt->pool == this) {
    return pt->thread_id;
  } else {
    return -1;
  }
}

int RunHandlerThreadPool::NumThreads() const { return num_threads_; }

int RunHandlerThreadPool::NumBlockingThreads() const {
  return num_blocking_threads_;
}

int RunHandlerThreadPool::NumNonBlockingThreads() const {
  return num_non_blocking_threads_;
}

RunHandlerThreadPool::ThreadData::ThreadData()
    : new_version(0),
      current_index(0),
      new_thread_work_sources(
          new Eigen::MaxSizeVector<ThreadWorkSource*>(static_cast<int32>(
              ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
                                      kMaxConcurrentHandlers)))),
      current_version(0),
      current_thread_work_sources(
          new Eigen::MaxSizeVector<ThreadWorkSource*>(static_cast<int32>(
              ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
                                      kMaxConcurrentHandlers)))) {}

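// Searches requests in [searching_range_start, searching_range_end), resuming
// from the thread's saved position (current_index) so that successive calls
// rotate through the requests. Blocking tasks are preferred when
// may_steal_blocking_work is set and the request's blocking inflight count is
// below max_blocking_inflight; otherwise a non-blocking task is popped. On
// success, reports which kind of queue the task came from and its owning
// ThreadWorkSource.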
Task RunHandlerThreadPool::FindTask(
    int searching_range_start, int searching_range_end, int thread_id,
    int sub_thread_pool_id, int max_blocking_inflight,
    bool may_steal_blocking_work,
    const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
    bool* task_from_blocking_queue, ThreadWorkSource** tws) {
  Task t;
  int current_index = thread_data_[thread_id].current_index;
  *task_from_blocking_queue = false;

  for (int i = 0; i < searching_range_end - searching_range_start; ++i) {
    if (current_index >= searching_range_end ||
        current_index < searching_range_start) {
      current_index = searching_range_start;
    }
    *tws = thread_work_sources[current_index];
    ++current_index;

    // For blocking threads, search for blocking tasks first.
    if (may_steal_blocking_work &&
        (*tws)->GetInflightTaskCount(true) < max_blocking_inflight) {
      t = (*tws)->PopBlockingTask();
      if (t.f) {
        *task_from_blocking_queue = true;
        break;
      }
    }

    // Search for non-blocking tasks.
    t = (*tws)->PopNonBlockingTask(thread_id, true);
    if (t.f) {
      break;
    }
  }
  thread_data_[thread_id].current_index = current_index;
  return t;
}

// Main worker thread loop.
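// Each iteration refreshes the thread's view of the active thread work
// sources, looks for a task (restricted to the thread's own sub thread pool
// first when sub thread pools are enabled), executes it while tracking the
// per-request inflight count, and otherwise waits for new work.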
void RunHandlerThreadPool::WorkerLoop(int thread_id,
                                      bool may_steal_blocking_work) {
  PerThread* pt = GetPerThread();
  pt->pool = this;
  pt->thread_id = thread_id;
  static constexpr int32_t kMaxBlockingInflight = 10;

  while (!cancelled_) {
    Task t;
    ThreadWorkSource* tws = nullptr;
    bool task_from_blocking_queue = true;
    int sub_thread_pool_id;
    // Get the current thread work sources.
    {
      mutex_lock l(thread_data_[thread_id].mu);
      if (thread_data_[thread_id].current_version <
          thread_data_[thread_id].new_version) {
        thread_data_[thread_id].current_version =
            thread_data_[thread_id].new_version;
        thread_data_[thread_id].current_thread_work_sources.swap(
            thread_data_[thread_id].new_thread_work_sources);
      }
    }
    Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
        thread_data_[thread_id].current_thread_work_sources.get();
    if (use_sub_thread_pool_) {
      sub_thread_pool_id = thread_data_[thread_id].sub_thread_pool_id;
      int active_requests = thread_work_sources->size();
      if (may_steal_blocking_work) {
        // Each thread first looks for tasks from requests that belong to its
        // sub thread pool.
        int search_range_start =
            active_requests *
            sub_thread_pool_start_request_percentage_[sub_thread_pool_id];
        int search_range_end =
            active_requests *
            sub_thread_pool_end_request_percentage_[sub_thread_pool_id];
        search_range_end =
            std::min(active_requests,
                     std::max(search_range_end, search_range_start + 1));

        t = FindTask(search_range_start, search_range_end, thread_id,
                     sub_thread_pool_id, kMaxBlockingInflight,
                     /*may_steal_blocking_work=*/true, *thread_work_sources,
                     &task_from_blocking_queue, &tws);
        if (!t.f) {
          // Search all requests if the thread cannot find tasks from
          // requests that belong to its own sub thread pool.
          t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
                       kMaxBlockingInflight,
                       /*may_steal_blocking_work=*/true, *thread_work_sources,
                       &task_from_blocking_queue, &tws);
        }
      } else {
        // Non-blocking threads always search across all pending requests.
        t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
                     kMaxBlockingInflight,
                     /*may_steal_blocking_work=*/false, *thread_work_sources,
                     &task_from_blocking_queue, &tws);
      }
    } else {
      // TODO(chaox): Refactor the following code to share the logic with
      // FindTask.
      for (int i = 0; i < thread_work_sources->size(); ++i) {
        tws = (*thread_work_sources)[i];
        // We want a smallish number of inter-op threads since otherwise there
        // will be contention in PropagateOutputs. This is a best-effort
        // policy.
        if (may_steal_blocking_work &&
            tws->GetInflightTaskCount(true) < kMaxBlockingInflight) {
          t = tws->PopBlockingTask();
          if (t.f) {
            break;
          }
        }
        if (i == 0) {
          // Always look for any work from the "primary" work source.
          // This way when we wake up a thread for a new closure we are
          // guaranteed it can be worked on.
          t = tws->PopNonBlockingTask(thread_id, true);
          if (t.f) {
            task_from_blocking_queue = false;
            break;
          }
        } else {
          t = tws->PopNonBlockingTask(thread_id, false);
          if (t.f) {
            task_from_blocking_queue = false;
            break;
          }
        }
      }
    }
    if (t.f) {
      profiler::TraceMe activity(
          [=] {
            return strings::StrCat(task_from_blocking_queue ? "inter" : "intra",
                                   " #id = ", tws->GetTracemeId(), " ",
                                   thread_id, "#");
          },
          profiler::TraceMeLevel::kInfo);
      VLOG(2) << "Running " << (task_from_blocking_queue ? "inter" : "intra")
              << " work from " << tws->GetTracemeId();
      tws->IncrementInflightTaskCount(task_from_blocking_queue);
      env_.ExecuteTask(t);
      tws->DecrementInflightTaskCount(task_from_blocking_queue);
    } else {
      profiler::TraceMe activity(
          [=] {
            return strings::StrCat("Sleeping#thread_id=", thread_id, "#");
          },
          profiler::TraceMeLevel::kInfo);
      if (VLOG_IS_ON(4)) {
        for (int i = 0; i < thread_work_sources->size(); ++i) {
          VLOG(4) << "source id " << i << " "
                  << (*thread_work_sources)[i]->ToString();
        }
      }
      if (use_sub_thread_pool_) {
        WaitForWorkInSubThreadPool(may_steal_blocking_work, sub_thread_pool_id);
      } else {
        WaitForWork(may_steal_blocking_work, thread_id, kMaxBlockingInflight);
      }
    }
  }
}

void RunHandlerThreadPool::WaitForWorkInSubThreadPool(bool is_blocking,
                                                      int sub_thread_pool_id) {
  const int kMaxSleepMicros = 250;

  // The non-blocking thread will just sleep.
  if (!is_blocking) {
    Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
    return;
  }

  thread_local Waiter waiter;
  WaitOnWaiter(&waiter, &(*queue_waiters_)[sub_thread_pool_id],
               &(*waiters_mu_)[sub_thread_pool_id], kMaxSleepMicros);
}

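// Wait path used when sub thread pools are disabled. Blocking threads block
// on the per-thread condition variable until at least one work source has
// been published, then do a short timed wait on the primary work source's
// waiter list; non-blocking threads simply sleep for a fixed interval.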
void RunHandlerThreadPool::WaitForWork(bool is_blocking, int thread_id,
                                       int32_t max_blocking_inflight) {
  const int kMaxSleepMicros = 250;

  // The non-blocking thread will just sleep.
  if (!is_blocking) {
    Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
    return;
  }

  ThreadWorkSource* tws = nullptr;
  {
    mutex_lock l(thread_data_[thread_id].mu);
    if (thread_data_[thread_id].new_version >
        thread_data_[thread_id].current_version) {
      thread_data_[thread_id].current_thread_work_sources.swap(
          thread_data_[thread_id].new_thread_work_sources);
      thread_data_[thread_id].current_version =
          thread_data_[thread_id].new_version;
    }
    Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
        thread_data_[thread_id].current_thread_work_sources.get();
    while (!cancelled_ && thread_work_sources->empty()) {
      // Wait until a new request arrives.
      thread_data_[thread_id].sources_not_empty.wait(l);
      if (thread_data_[thread_id].new_version >
          thread_data_[thread_id].current_version) {
        thread_data_[thread_id].current_thread_work_sources.swap(
            thread_data_[thread_id].new_thread_work_sources);
        thread_data_[thread_id].current_version =
            thread_data_[thread_id].new_version;
        thread_work_sources =
            thread_data_[thread_id].current_thread_work_sources.get();
      }
    }
    if (cancelled_) {
      return;
    }
    tws = (*thread_work_sources)[0];
  }

  if (tws->GetInflightTaskCount(true) >= max_blocking_inflight) {
    // Sleep to reduce contention in PropagateOutputs.
    Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
  }
  tws->WaitForWork(kMaxSleepMicros);
}

}  // namespace internal

// Contains the concrete implementation of the RunHandler.
// The externally visible RunHandler class simply forwards the work to this one.
class RunHandler::Impl {
 public:
  explicit Impl(RunHandlerPool::Impl* pool_impl);

  ~Impl() {}

  thread::ThreadPoolInterface* thread_pool_interface() {
    return thread_pool_interface_.get();
  }

  // Returns the time (in microseconds since the Unix epoch) at which the
  // handler was requested via RunHandlerPool::Get().
  uint64 start_time_us() const { return start_time_us_; }
  int64_t step_id() const { return step_id_; }
  void ScheduleInterOpClosure(std::function<void()> fn);
  void ScheduleIntraOpClosure(std::function<void()> fn);

  void Reset(int64_t step_id,
             const RunOptions::Experimental::RunHandlerPoolOptions& options);

  RunHandlerPool::Impl* pool_impl() { return pool_impl_; }

  internal::ThreadWorkSource* tws() { return &tws_; }

  int64_t priority() { return options_.priority(); }

 private:
  class ThreadPoolInterfaceWrapper : public thread::ThreadPoolInterface {
   public:
    explicit ThreadPoolInterfaceWrapper(Impl* run_handler_impl)
        : run_handler_impl_(run_handler_impl) {}
    ~ThreadPoolInterfaceWrapper() override {}
    void Schedule(std::function<void()> fn) override;
    int NumThreads() const override;
    int CurrentThreadId() const override;

   private:
    RunHandler::Impl* run_handler_impl_ = nullptr;
  };

  RunHandlerPool::Impl* pool_impl_;  // NOT OWNED.
  uint64 start_time_us_;
  int64_t step_id_;
  std::unique_ptr<thread::ThreadPoolInterface> thread_pool_interface_;
  internal::ThreadWorkSource tws_;
  RunOptions::Experimental::RunHandlerPoolOptions options_;
};

// Contains shared state across all run handlers present in the pool. Also
// responsible for pool management decisions.
// This class is thread safe.
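//
// A minimal usage sketch (illustrative only; integration with the session
// run path is omitted, and `step_id` / `options` are placeholders):
//
//   RunHandlerPool pool(/*num_inter_op_threads=*/16,
//                       /*num_intra_op_threads=*/16);
//   auto handler = pool.Get(step_id, /*timeout_in_ms=*/0, options);
//   handler->ScheduleInterOpClosure([]() { /* inter-op work */ });
//   thread::ThreadPoolInterface* intra = handler->AsIntraThreadPoolInterface();
//   // The handler is returned to the pool when `handler` goes out of scope.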
class RunHandlerPool::Impl {
 public:
  explicit Impl(int num_inter_op_threads, int num_intra_op_threads)
      : max_handlers_(static_cast<int32>(ParamFromEnvWithDefault(
            "TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS", kMaxConcurrentHandlers))),
        waiters_mu_(
            ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
        queue_waiters_(
            ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
        run_handler_thread_pool_(new internal::RunHandlerThreadPool(
            num_inter_op_threads, num_intra_op_threads, Env::Default(),
            ThreadOptions(), "tf_run_handler_pool", &waiters_mu_,
            &queue_waiters_)),
        iterations_(0),
        version_(0),
        sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
            "TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
            std::vector<double>({1}))) {
    VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_;
    free_handlers_.reserve(max_handlers_);
    handlers_.reserve(max_handlers_);
    for (int i = 0; i < max_handlers_; ++i) {
      handlers_.emplace_back(new RunHandler::Impl(this));
      free_handlers_.push_back(handlers_.back().get());
    }
    queue_waiters_.resize(
        ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
    waiters_mu_.resize(
        ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
    for (auto& queue_waiter : queue_waiters_) {
      queue_waiter.next = &queue_waiter;
      queue_waiter.prev = &queue_waiter;
    }
    run_handler_thread_pool_->Start();
  }

  ~Impl() {
    // Sanity check that all handlers have been returned back to the pool
    // before destruction.
    DCHECK_EQ(handlers_.size(), max_handlers_);
    DCHECK_EQ(free_handlers_.size(), handlers_.size());
    DCHECK_EQ(sorted_active_handlers_.size(), 0);
    // Stop the threads in run_handler_thread_pool_ before freeing other
    // pointers. Otherwise a thread may try to access a pointer after the
    // pointer has been freed.
    run_handler_thread_pool_.reset();
  }

  internal::RunHandlerThreadPool* run_handler_thread_pool() {
    return run_handler_thread_pool_.get();
  }

  bool has_free_handler() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return !free_handlers_.empty();
  }

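  // Returns a free handler for `step_id`, blocking until one is available. If
  // `timeout_in_ms` is non-zero and expires first, nullptr is returned. The
  // handler is inserted into sorted_active_handlers_ by priority, and pool
  // stats are recomputed so that worker threads see the new set of thread
  // work sources.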
  std::unique_ptr<RunHandler> Get(
      int64_t step_id, int64_t timeout_in_ms,
      const RunOptions::Experimental::RunHandlerPoolOptions& options)
      TF_LOCKS_EXCLUDED(mu_) {
    thread_local std::unique_ptr<
        Eigen::MaxSizeVector<internal::ThreadWorkSource*>>
        thread_work_sources =
            std::unique_ptr<Eigen::MaxSizeVector<internal::ThreadWorkSource*>>(
                new Eigen::MaxSizeVector<internal::ThreadWorkSource*>(
                    static_cast<int32>(ParamFromEnvWithDefault(
                        "TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
                        kMaxConcurrentHandlers))));
    uint64 version;
    int num_active_requests;
    RunHandler::Impl* handler_impl;
    {
      mutex_lock l(mu_);
      if (!has_free_handler()) {
        profiler::TraceMe activity(
            [&] {
              return strings::StrCat("WaitingForHandler#step_id=", step_id,
                                     "#");
            },
            profiler::TraceMeLevel::kInfo);
        TRACESTRING(
            strings::StrCat("RunHandlerPool::Impl::Get waiting for a handler "
                            "with timeout in milliseconds ",
                            timeout_in_ms));
        if (timeout_in_ms == 0) {
          mu_.Await(Condition(this, &Impl::has_free_handler));
        } else if (!mu_.AwaitWithDeadline(
                       Condition(this, &Impl::has_free_handler),
                       EnvTime::NowNanos() + timeout_in_ms * 1000 * 1000)) {
          return nullptr;
        }
      }
      // Remove the last entry from free_handlers_ and insert it into
      // sorted_active_handlers_ according to its priority.
      handler_impl = free_handlers_.back();
      handler_impl->Reset(step_id, options);
      free_handlers_.pop_back();

      num_active_requests = sorted_active_handlers_.size() + 1;
      thread_work_sources->resize(num_active_requests);
      int priority = options.priority();
      auto it = sorted_active_handlers_.cbegin();
      bool new_handler_inserted = false;
      for (int i = 0; i < num_active_requests; ++i) {
        if (!new_handler_inserted && (it == sorted_active_handlers_.cend() ||
                                      priority > (*it)->priority())) {
          sorted_active_handlers_.insert(it, handler_impl);
          new_handler_inserted = true;
          // Point to the newly added handler.
          --it;
        }
        (*thread_work_sources)[i] = (*it)->tws();
        ++it;
      }
      version = ++version_;
    }
    RecomputePoolStats(num_active_requests, version, *thread_work_sources);
    return WrapUnique<RunHandler>(new RunHandler(handler_impl));
  }
896 
ReleaseHandler(RunHandler::Impl * handler)897   void ReleaseHandler(RunHandler::Impl* handler) TF_LOCKS_EXCLUDED(mu_) {
898     mutex_lock l(mu_);
899     DCHECK_GT(sorted_active_handlers_.size(), 0);
900 
901     CHECK_EQ(handler->tws()->TaskQueueSize(true), 0);   // Crash OK.
902     CHECK_EQ(handler->tws()->TaskQueueSize(false), 0);  // Crash OK.
903 
904     uint64 now = tensorflow::EnvTime::NowMicros();
905     double elapsed = (now - handler->start_time_us()) / 1000.0;
906     time_hist_.Add(elapsed);
907 
908     // Erase from and update sorted_active_handlers_. Add it to the end of
909     // free_handlers_.
910     auto iter = std::find(sorted_active_handlers_.begin(),
911                           sorted_active_handlers_.end(), handler);
912     DCHECK(iter != sorted_active_handlers_.end())
913         << "Unexpected handler: " << handler
914         << " is being requested for release";
915 
916     // Remove this handler from this list and add it to the list of free
917     // handlers.
918     sorted_active_handlers_.erase(iter);
919     free_handlers_.push_back(handler);
920     DCHECK_LE(free_handlers_.size(), max_handlers_);
921     LogInfo();
922 
923     // We do not recompute pool stats during release. The side effect is that
924     // there may be empty thread work sources in the queue. However, any new
925     // requests will trigger recomputation.
926   }
927 
GetActiveHandlerPrioritiesForTesting()928   std::vector<int64_t> GetActiveHandlerPrioritiesForTesting()
929       TF_LOCKS_EXCLUDED(mu_) {
930     mutex_lock l(mu_);
931     std::vector<int64_t> ret;
932     for (const auto& handler_impl : sorted_active_handlers_) {
933       ret.push_back(handler_impl->priority());
934     }
935     return ret;
936   }
937 
938  private:
939   void RecomputePoolStats(
940       int num_active_requests, uint64 version,
941       const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
942           thread_work_sources);
943 
944   void LogInfo() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
945 
946   // Maximum number of handlers pre-created during pool construction time. The
947   // number has been chosen expecting each handler might at least want 1
948   // inter-op thread for execution (during compute intensive workloads like
949   // inference).
950   const int max_handlers_;
951 
952   Eigen::MaxSizeVector<mutex> waiters_mu_;
953   Eigen::MaxSizeVector<internal::Waiter> queue_waiters_;
954 
955   std::unique_ptr<internal::RunHandlerThreadPool> run_handler_thread_pool_;
956   // Thread compatible part used only by lock under RunHandlerPool.
957   // Handlers are sorted by start time.
958   // TODO(azaks): sort by the remaining latency budget.
959   // TODO(chaox): Consider other data structure for maintaining the sorted
960   // active handlers if the searching overhead(currently O(n)) becomes the
961   // bottleneck.
962   std::list<RunHandler::Impl*> sorted_active_handlers_ TF_GUARDED_BY(mu_);
963   std::vector<RunHandler::Impl*> free_handlers_ TF_GUARDED_BY(mu_);
964   std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ TF_GUARDED_BY(mu_);
965 
966   // Histogram of elapsed runtime of every handler (in ms).
967   histogram::Histogram time_hist_ TF_GUARDED_BY(mu_);
968 
969   int64_t iterations_ TF_GUARDED_BY(mu_);
970   mutex mu_;
971   int64_t version_ TF_GUARDED_BY(mu_);
972   const std::vector<double> sub_thread_pool_end_request_percentage_;
973 };
974 
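// Maps each active request to a sub thread pool based on its position in the
// priority-sorted request list and the configured end-request percentages,
// then republishes the thread work sources to every worker thread with a
// per-thread start request index chosen by
// ChooseRequestsWithExponentialDistribution (see run_handler_util.h).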
void RunHandlerPool::Impl::RecomputePoolStats(
    int num_active_requests, uint64 version,
    const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
        thread_work_sources) {
  if (num_active_requests == 0) return;

  int sub_thread_pool_id = 0;
  for (int i = 0; i < num_active_requests; ++i) {
    while (
        sub_thread_pool_id <
            sub_thread_pool_end_request_percentage_.size() - 1 &&
        i >= num_active_requests *
                 sub_thread_pool_end_request_percentage_[sub_thread_pool_id]) {
      sub_thread_pool_id++;
    }
    thread_work_sources[i]->SetWaiter(version,
                                      &queue_waiters_[sub_thread_pool_id],
                                      &waiters_mu_[sub_thread_pool_id]);
  }

  int num_threads = run_handler_thread_pool()->NumThreads();
  int num_blocking_threads = run_handler_thread_pool()->NumBlockingThreads();
  int num_non_blocking_threads = num_threads - num_blocking_threads;

  std::vector<int> request_idx_list = ChooseRequestsWithExponentialDistribution(
      num_active_requests, num_blocking_threads);
  for (int i = 0; i < num_blocking_threads; ++i) {
    VLOG(2) << "Set work for tid=" << i
            << " with start_request_idx=" << request_idx_list[i];
    run_handler_thread_pool()->SetThreadWorkSources(
        i, request_idx_list[i], version, thread_work_sources);
  }

  request_idx_list = ChooseRequestsWithExponentialDistribution(
      num_active_requests, num_non_blocking_threads);
  for (int i = 0; i < num_non_blocking_threads; ++i) {
    VLOG(2) << "Set work for tid=" << (i + num_blocking_threads)
            << " with start_request_idx=" << request_idx_list[i];
    run_handler_thread_pool()->SetThreadWorkSources(
        i + num_blocking_threads, request_idx_list[i], version,
        thread_work_sources);
  }
}

void RunHandlerPool::Impl::LogInfo() {
  if (iterations_++ % 50000 == 10 && VLOG_IS_ON(1)) {
    int num_active_requests = sorted_active_handlers_.size();
    VLOG(1) << "Printing time histogram: " << time_hist_.ToString();
    VLOG(1) << "Active session runs: " << num_active_requests;
    uint64 now = tensorflow::Env::Default()->NowMicros();
    string times_str = "";
    string ids_str = "";
    auto it = sorted_active_handlers_.cbegin();
    for (int i = 0; i < num_active_requests; ++i) {
      if (i > 0) {
        times_str += " ";
        ids_str += " ";
      }

      times_str +=
          strings::StrCat((now - (*it)->start_time_us()) / 1000.0, " ms.");
      ids_str += strings::StrCat((*it)->tws()->GetTracemeId());
      ++it;
    }
    VLOG(1) << "Elapsed times are: " << times_str;
    VLOG(1) << "Step ids are: " << ids_str;
  }
}

// It is important that the returned value satisfies:
// CurrentThreadId() in [0, NumThreads).
int RunHandler::Impl::ThreadPoolInterfaceWrapper::NumThreads() const {
  return run_handler_impl_->pool_impl_->run_handler_thread_pool()->NumThreads();
}

int RunHandler::Impl::ThreadPoolInterfaceWrapper::CurrentThreadId() const {
  return run_handler_impl_->pool_impl_->run_handler_thread_pool()
      ->CurrentThreadId();
}

void RunHandler::Impl::ThreadPoolInterfaceWrapper::Schedule(
    std::function<void()> fn) {
  return run_handler_impl_->ScheduleIntraOpClosure(std::move(fn));
}

RunHandler::Impl::Impl(RunHandlerPool::Impl* pool_impl)
    : pool_impl_(pool_impl) {
  thread_pool_interface_.reset(new ThreadPoolInterfaceWrapper(this));
  Reset(0, RunOptions::Experimental::RunHandlerPoolOptions());
}

void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) {
  VLOG(3) << "Scheduling inter work for " << tws()->GetTracemeId();
  pool_impl_->run_handler_thread_pool()->AddWorkToQueue(tws(), true,
                                                        std::move(fn));
}

void RunHandler::Impl::ScheduleIntraOpClosure(std::function<void()> fn) {
  VLOG(3) << "Scheduling intra work for " << tws()->GetTracemeId();
  pool_impl_->run_handler_thread_pool()->AddWorkToQueue(tws(), false,
                                                        std::move(fn));
}

void RunHandler::Impl::Reset(
    int64_t step_id,
    const RunOptions::Experimental::RunHandlerPoolOptions& options) {
  start_time_us_ = tensorflow::Env::Default()->NowMicros();
  step_id_ = step_id;
  options_ = options;
  tws_.SetTracemeId(step_id);
}

RunHandlerPool::RunHandlerPool(int num_inter_op_threads)
    : impl_(new Impl(num_inter_op_threads, 0)) {}

RunHandlerPool::RunHandlerPool(int num_inter_op_threads,
                               int num_intra_op_threads)
    : impl_(new Impl(num_inter_op_threads, num_intra_op_threads)) {}

RunHandlerPool::~RunHandlerPool() {}

std::unique_ptr<RunHandler> RunHandlerPool::Get(
    int64_t step_id, int64_t timeout_in_ms,
    const RunOptions::Experimental::RunHandlerPoolOptions& options) {
  return impl_->Get(step_id, timeout_in_ms, options);
}

std::vector<int64_t> RunHandlerPool::GetActiveHandlerPrioritiesForTesting()
    const {
  return impl_->GetActiveHandlerPrioritiesForTesting();
}

RunHandler::RunHandler(Impl* impl) : impl_(impl) {}

void RunHandler::ScheduleInterOpClosure(std::function<void()> fn) {
  impl_->ScheduleInterOpClosure(std::move(fn));
}

thread::ThreadPoolInterface* RunHandler::AsIntraThreadPoolInterface() {
  return impl_->thread_pool_interface();
}

RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); }

}  // namespace tensorflow