/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_

#include <stddef.h>

#include <deque>
#include <functional>
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/time/clock.h"
#include "absl/types/variant.h"
#include "absl/utility/utility.h"
#include "tensorflow/core/kernels/batching_util/batch_input_task.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"

namespace tensorflow {
namespace serving {
namespace internal {
template <typename TaskType>
class Queue;
}  // namespace internal
}  // namespace serving
}  // namespace tensorflow

namespace tensorflow {
namespace serving {

// A batch scheduler for server instances that service multiple request types
// (e.g. multiple machine-learned models, or multiple versions of a model served
// concurrently), or even multiple distinct tasks for a given request. The
// scheduler multiplexes batches of different kinds of tasks onto a fixed-size
// thread pool (each batch contains tasks of a single type), in a carefully
// controlled manner. A common configuration is to set the number of threads
// equal to the number of hardware accelerator units, in which case the
// scheduler takes care of multiplexing the task types onto the shared hardware,
// in a manner that is both fair and efficient.
//
// Semantically, SharedBatchScheduler behaves like having N instances of
// BasicBatchScheduler (see basic_batch_scheduler.h), one per task type. The
// difference is that under the covers there is a single shared thread pool,
// instead of N independent ones, with their sharing deliberately coordinated.
//
// SharedBatchScheduler does not implement the BatchScheduler API; rather, it
// presents an abstraction of "queues", where each queue corresponds to one type
// of task. Tasks submitted to a given queue are placed in their own batches,
// and cannot be mixed with other tasks. Queues can be added and deleted
// dynamically, to accommodate e.g. versions of a model being brought up and
// down over the lifetime of a server.
//
// The batch thread pool round-robins through the queues, running one batch
// from a queue and then moving to the next queue. Each queue behaves like a
// BasicBatchScheduler instance, in the sense that it has maximum batch size and
// timeout parameters, which govern when a batch is eligible to be processed.
//
// Each queue is independently configured with a maximum size (in terms of the
// maximum number of batches worth of enqueued tasks). For online serving, it is
// recommended that the queue sizes be configured such that the sum of the sizes
// of the active queues roughly equals the number of batch threads. (The idea is
// that if all threads become available at roughly the same time, there will be
// enough enqueued work for them to take on, but no more.)
//
// If queue sizes are configured in the manner suggested above, the maximum time
// a task can spend in a queue before being placed in a batch and assigned to a
// thread for processing, is the greater of:
//  - the maximum time to process one batch of tasks from any active queue
//  - the configured timeout parameter for the task's queue (which can be 0)
//
// For bulk processing jobs and throughput-oriented benchmarks, you may want to
// set the maximum queue size to a large value.
//
// TODO(b/26539183): Support queue servicing policies other than round-robin.
// E.g. let each queue specify a "share" (an int >= 1), so e.g. with queues A
// and B having shares 1 and 2 respectively, the servicing pattern is ABBABB...
//
// PERFORMANCE TUNING: See README.md.
//
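// Example use (a minimal sketch; `MyTask`, the option values and the callback
// body below are illustrative assumptions, not requirements of this API):
//
//   SharedBatchScheduler<MyTask>::Options options;
//   options.num_batch_threads = 4;
//   std::shared_ptr<SharedBatchScheduler<MyTask>> scheduler;
//   TF_RETURN_IF_ERROR(
//       SharedBatchScheduler<MyTask>::Create(options, &scheduler));
//
//   SharedBatchScheduler<MyTask>::QueueOptions queue_options;
//   queue_options.input_batch_size_limit = 128;
//   queue_options.batch_timeout_micros = 1000;
//   std::unique_ptr<BatchScheduler<MyTask>> queue;
//   TF_RETURN_IF_ERROR(scheduler->AddQueue(
//       queue_options,
//       [](std::unique_ptr<Batch<MyTask>> batch) { /* process `batch` */ },
//       &queue));
//
//   std::unique_ptr<MyTask> task = ...;      // created elsewhere
//   TF_RETURN_IF_ERROR(queue->Schedule(&task));  // ownership moves on success
//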
template <typename TaskType>
class SharedBatchScheduler
    : public std::enable_shared_from_this<SharedBatchScheduler<TaskType>> {
 public:
  using BatchTaskHandleUniquePtr =
      std::unique_ptr<Batch<internal::BatchInputTaskHandle<TaskType>>>;
  using BatchTaskUniqueptr = std::unique_ptr<Batch<TaskType>>;
  using BatchUniquePtr =
      absl::variant<BatchTaskUniqueptr, BatchTaskHandleUniquePtr>;
  // TODO(b/25089730): Tune defaults based on best practices as they develop.
  struct Options {
    // The name to use for the pool of batch threads.
    string thread_pool_name = {"batch_threads"};

    // The number of threads to use to process batches.
    // Must be >= 1, and should be tuned carefully.
    int num_batch_threads = port::MaxParallelism();

    // The environment to use.
    // (Typically only overridden by test code.)
    Env* env = Env::Default();
  };
  // Ownership is shared between the caller of Create() and any queues created
  // via AddQueue().
  static Status Create(
      const Options& options,
      std::shared_ptr<SharedBatchScheduler<TaskType>>* scheduler);

  ~SharedBatchScheduler();

  // Adds a queue to which tasks may be submitted. The returned queue implements
  // the BatchScheduler API. Each queue has its own set of scheduling options,
  // and its own callback to process batches of tasks submitted to the queue.
  //
  // The returned queue's destructor blocks until all tasks submitted to it have
  // been processed.
  struct QueueOptions {
    // The size limit of an input batch to the queue.
    //
    // If `enable_large_batch_splitting` is true, `input_batch_size_limit`
    // should be greater than or equal to `max_execution_batch_size`; otherwise
    // `input_batch_size_limit` should be equal to `max_execution_batch_size`.
    size_t input_batch_size_limit = 1000;

    // If a task has been enqueued for this amount of time (in microseconds),
    // and a thread is available, the scheduler will immediately form a batch
    // from enqueued tasks and assign the batch to the thread for processing,
    // even if the batch's size is below 'input_batch_size_limit'.
    //
    // This parameter offers a way to bound queue latency, so that a task isn't
    // stuck in the queue indefinitely waiting for enough tasks to arrive to
    // make a full batch. (The latency bound is given in the class documentation
    // above.)
    //
    // The goal is to smooth out batch sizes under low request rates, and thus
    // avoid latency spikes.
    int64_t batch_timeout_micros = 0;

    // The maximum allowable number of enqueued (accepted by Schedule() but
    // not yet being processed on a batch thread) tasks, in terms of batches.
    // If this limit is reached, Schedule() will return an UNAVAILABLE error.
    // See the class documentation above for guidelines on how to tune this
    // parameter.
    //
    // Must be positive, or else an invalid-argument error will be returned at
    // queue creation time.
    size_t max_enqueued_batches = 10;

    // If true, the queue implementation splits one input batch task into
    // subtasks (as specified by `split_input_task_func` below) and fits the
    // subtasks into different batches.
    //
    // For usage of `split_input_task_func`, please see its comment.
    bool enable_large_batch_splitting = false;

    // `input_task`: a unit of task to be split.
    // `first_output_task_size`: the task size of the first output.
    // `max_execution_batch_size`: the maximum size of each batch.
    // `output_tasks`: a list of output tasks after the split.
    //
    // REQUIRED:
    // 1) All `output_tasks` should be non-empty tasks.
    // 2) The sizes of `output_tasks` add up to the size of `input_task`.
    //
    // NOTE:
    // Instantiations of `TaskType` may vary, so it's up to the caller to define
    // how (e.g., which members to access) to split input tasks.
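    //
    // For example (an illustrative sketch, not a requirement of this API):
    // given an `input_task` of size 10, `first_output_task_size` = 4 and
    // `max_execution_batch_size` = 6, a valid split produces `output_tasks`
    // of sizes {4, 6}: the first subtask fills the remaining slots of the
    // open batch, and the remainder fits within a single new batch.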
    std::function<Status(std::unique_ptr<TaskType>* input_task,
                         int first_output_task_size, int input_batch_size_limit,
                         std::vector<std::unique_ptr<TaskType>>* output_tasks)>
        split_input_task_func;

    // If true, batch input tasks are split lazily after dequeue and not on the
    // critical path of enqueue operations.
    //
    // Must be false if `enable_large_batch_splitting` is false; otherwise an
    // error will be returned at queue creation time.
    //
    // TODO(b/194294263):
    // Make `enable_lazy_split` a template parameter of the queue, and adapt
    // `batches_` and `task_handle_batches_` into one deque of
    // tensorflow::serving::Batch.
    bool enable_lazy_split = false;

    // The maximum size of each enqueued batch (i.e., in `batches_`).
    //
    // The scheduler may form batches of any size between 1 and this number
    // (inclusive). If there is a need to quantize the batch sizes, i.e. only
    // submit batches whose size is in a small set of allowed sizes, that can be
    // done by adding padding in the process-batch callback.
    size_t max_execution_batch_size = 1000;
  };
  Status AddQueue(const QueueOptions& options,
                  std::function<void(std::unique_ptr<Batch<TaskType>>)>
                      process_batch_callback,
                  std::unique_ptr<BatchScheduler<TaskType>>* queue);

 private:
  explicit SharedBatchScheduler(const Options& options);

  void GetNextWorkItem_Locked(internal::Queue<TaskType>** queue_for_batch_out,
                              BatchUniquePtr* batch_to_process_out)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // The code executed in 'batch_threads_'. Obtains a batch to process from the
  // queue pointed to by 'next_queue_to_schedule_', and processes it. If that
  // queue declines to provide a batch to process, moves on to the next queue.
  // If no queues provide a batch to process, just sleeps briefly and exits.
  void ThreadLogic();

  // Called by `AddQueue`.
  Status AddQueueAfterRewritingOptions(
      const QueueOptions& options,
      std::function<void(std::unique_ptr<Batch<TaskType>>)>
          process_batch_callback,
      std::unique_ptr<BatchScheduler<TaskType>>* queue);

  // Returns true iff the batch pointer held in `batch_to_process` is null,
  // i.e. there is no batch to process.
  static bool BatchExists(const BatchUniquePtr& batch_to_process);

  const Options options_;

  mutex mu_;

  // A list of queues. (We use std::list instead of std::vector to ensure that
  // iterators are not invalidated by adding/removing elements. It also offers
  // efficient removal of elements from the middle.)
  using QueueList = std::list<std::unique_ptr<internal::Queue<TaskType>>>;

  // All "active" queues, i.e. ones that either:
  //  - have not been removed, or
  //  - have been removed but are not yet empty.
  QueueList queues_ TF_GUARDED_BY(mu_);

  // An iterator over 'queues_', pointing to the queue from which the next
  // available batch thread should grab work.
  typename QueueList::iterator next_queue_to_schedule_ TF_GUARDED_BY(mu_);

  // Used by idle batch threads to wait for work to enter the system. Notified
  // whenever a batch becomes schedulable.
  condition_variable schedulable_batch_cv_;

  // Threads that process batches obtained from the queues.
  std::vector<std::unique_ptr<PeriodicFunction>> batch_threads_;

  TF_DISALLOW_COPY_AND_ASSIGN(SharedBatchScheduler);
};

//////////
// Implementation details follow. API users need not read.

namespace internal {

// A task queue for SharedBatchScheduler. Accepts tasks and accumulates them
// into batches, and dispenses those batches to be processed via a "pull"
// interface. The queue's behavior is governed by maximum batch size, timeout
// and maximum queue length parameters; see their documentation in
// SharedBatchScheduler.
//
// The queue is implemented as a deque of batches, with these invariants:
//  - The number of batches is between 1 and 'options_.max_enqueued_batches'.
//  - The back-most batch is open; the rest are closed.
//
// Submitted tasks are added to the open batch. If that batch doesn't have room
// but the queue isn't full, then that batch is closed and a new open batch is
// started.
//
// Batch pull requests are handled by dequeuing the front-most batch if it is
// closed. If the front-most batch is open (i.e. the queue contains only one
// batch) and has reached the timeout, it is immediately closed and returned;
// otherwise no batch is returned for the request.
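//
// For example (illustrative only): with 'options_.max_enqueued_batches' = 3,
// the deque might hold [closed (full), closed (full), open (partially full)];
// a pull request dequeues the front closed batch, while newly submitted tasks
// keep going into the open batch at the back.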
template <typename TaskType>
class Queue {
 public:
  using ProcessBatchCallback =
      std::function<void(std::unique_ptr<Batch<TaskType>>)>;
  using SchedulableBatchCallback = std::function<void()>;
  using SplitInputTaskIntoSubtasksCallback = std::function<Status(
      std::unique_ptr<TaskType>* input_task, int open_batch_remaining_slot,
      int max_execution_batch_size,
      std::vector<std::unique_ptr<TaskType>>* output_tasks)>;
  Queue(const typename SharedBatchScheduler<TaskType>::QueueOptions& options,
        Env* env, ProcessBatchCallback process_batch_callback,
        SchedulableBatchCallback schedulable_batch_callback);

  // Illegal to destruct unless the queue is empty.
  ~Queue();

  // Submits a task to the queue, with the same semantics as
  // BatchScheduler::Schedule().
  Status Schedule(std::unique_ptr<TaskType>* task);

  // Enqueues `task` as it is, OR splits it inline (eagerly) to form batches to
  // be processed by `Queue<TaskType>::ProcessBatch`.
  Status ScheduleWithoutOrEagerSplit(std::unique_ptr<TaskType>* task);

  // Enqueues `task` along with the batch-queue metadata.
  // Batches are formed by the time `ScheduleWithLazySplit` returns; each
  // element in the deque evaluates to a batch to be processed after it's
  // dequeued (outside the mutex-protected region).
  Status ScheduleWithLazySplit(std::unique_ptr<TaskType>* task);

  // Returns the number of enqueued tasks, with the same semantics as
  // BatchScheduler::NumEnqueuedTasks().
  size_t NumEnqueuedTasks() const;

  // Returns the queue capacity, with the same semantics as
  // BatchScheduler::SchedulingCapacity().
  size_t SchedulingCapacity() const;

  // Returns the maximum allowed size of tasks submitted to the queue.
  size_t max_task_size() const { return options_.input_batch_size_limit; }

  // Returns the maximum allowed size of tasks to be executed.
  // The returned value is less than or equal to the maximum allowed input
  // size provided by the caller of the batch scheduler.
  size_t max_execution_batch_size() const { return max_execution_batch_size_; }

  // Called by a thread that is ready to process a batch, to request one from
  // this queue. Either returns a batch that is ready to be processed, or
  // nullptr if the queue declines to schedule a batch at this time. If it
  // returns a batch, the batch is guaranteed to be closed.
  typename SharedBatchScheduler<TaskType>::BatchUniquePtr ScheduleBatch();

  // A variant of `ScheduleBatch`.
  // Batches are guaranteed to be formed at task enqueue time.
  std::unique_ptr<Batch<TaskType>> ScheduleBatchWithEagerSplit();

  // Processes a batch that has been returned earlier by ScheduleBatch().
  void ProcessBatch(std::unique_ptr<Batch<TaskType>> batch);

  // Determines whether the queue is empty, i.e. has no tasks waiting or being
  // processed.
  bool IsEmpty() const;

  // Marks the queue closed, and waits until it is empty.
  void CloseAndWaitUntilEmpty();

  bool closed() const TF_NO_THREAD_SAFETY_ANALYSIS { return closed_.load(); }

 private:
  // Computes the max_execution_batch_size of the queue based on queue options.
  static size_t GetMaxExecutionBatchSize(
      const typename SharedBatchScheduler<TaskType>::QueueOptions& options) {
    // If `enable_large_batch_splitting` is true, returns
    // `max_execution_batch_size` configured by user options directly; returns
    // `input_batch_size_limit` otherwise.
    //
    // Note `input_batch_size_limit` is used for backward compatibility:
    // users may not specify `max_execution_batch_size` explicitly.
    if (options.enable_large_batch_splitting) {
      return options.max_execution_batch_size;
    } else {
      return options.input_batch_size_limit;
    }
  }

  // Same as IsEmpty(), but assumes the caller already holds a lock on 'mu_'.
  bool IsEmptyInternal() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Closes the open batch residing at the back of the deque, and inserts a
  // fresh open batch behind it.
  void StartNewBatch() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Splits `input_task` into `output_tasks`, using
  // `options_.split_input_task_func`.
  Status SplitInputBatchIntoSubtasks(
      std::unique_ptr<TaskType>* input_task,
      std::vector<std::unique_ptr<TaskType>>* output_tasks)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Determines whether the open batch residing at the back of 'batches_' is
  // currently schedulable.
  bool IsOpenBatchSchedulable() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // A variant of `IsOpenBatchSchedulable`; used when batches are formed at
  // task enqueue time, and the open batch is `batches_.back()`.
  bool IsOpenBatchSchedulableAfterEagerSplit() const
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Same as SchedulingCapacity(), but assumes the caller already holds a
  // lock on 'mu_'.
  size_t SchedulingCapacityInternal() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns true if the queue doesn't have capacity for this task.
  //
  // `task` must outlive this method.
  bool BatchTaskExceedQueueCapacity(TaskType* task) const
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // The task size of the last batch in the queue.
  size_t tail_batch_task_size() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns the number of enqueued batches.
  int64 num_enqueued_batches() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  const typename SharedBatchScheduler<TaskType>::QueueOptions options_;

  // The environment to use.
  Env* env_;

  // The maximum batch size to be executed by `Queue::ProcessBatch`.
  // See the comments on QueueOptions and the helper function
  // `GetMaxExecutionBatchSize` for more details on what it means.
  const size_t max_execution_batch_size_;

  // A callback invoked to process a batch of work units. Always invoked
  // from a batch thread.
  ProcessBatchCallback process_batch_callback_;

  // A callback invoked to notify the scheduler that a new batch has become
  // schedulable.
  SchedulableBatchCallback schedulable_batch_callback_;

  mutable mutex mu_;

  // Whether this queue is closed, i.e. no longer accepting new tasks. This
  // variable is monotonic: it starts as false, and then at some point gets set
  // to true and remains true for the duration of this object's life.
  std::atomic<bool> closed_ TF_GUARDED_BY(mu_){false};

  // The enqueued batches.
  // Each element is a batch to be dequeued and processed by
  // `Queue<TaskType>::ProcessBatch`.
  //
  // Used iff `QueueOptions.enable_lazy_split` is false.
  std::deque<std::unique_ptr<Batch<TaskType>>> batches_ TF_GUARDED_BY(mu_);

  // The enqueued batches.
  //
  // Each element corresponds to the `task` enqueued in `Queue::Schedule`; the
  // element could be split and processed in batches at dequeue time.
  //
  // Used iff `QueueOptions.enable_lazy_split` is true.
  std::deque<std::unique_ptr<Batch<BatchInputTaskHandle<TaskType>>>>
      task_handle_batches_ TF_GUARDED_BY(mu_);

  // The counter of the TraceMe context ids.
  uint64 traceme_context_id_counter_ TF_GUARDED_BY(mu_) = 0;

  // The time at which the first task was added to the open (back-most) batch
  // in 'batches_'. Valid iff that batch contains at least one task.
  uint64 open_batch_start_time_micros_ TF_GUARDED_BY(mu_);

  // Whether this queue contains a batch that is eligible to be scheduled.
  // Used to keep track of when to call 'schedulable_batch_callback_'.
  bool schedulable_batch_ TF_GUARDED_BY(mu_) = false;

  // The number of batches currently being processed by batch threads.
  // Incremented in ScheduleBatch() and decremented in ProcessBatch().
  int num_batches_being_processed_ TF_GUARDED_BY(mu_) = 0;

  // Used by CloseAndWaitUntilEmpty() to wait until the queue is empty, for
  // the case in which the queue is not empty when CloseAndWaitUntilEmpty()
  // starts. When ProcessBatch() dequeues the last batch and makes the queue
  // empty, if 'empty_notification_' is non-null it calls
  // 'empty_notification_->Notify()'.
  Notification* empty_notification_ TF_GUARDED_BY(mu_) = nullptr;

  TF_DISALLOW_COPY_AND_ASSIGN(Queue);
};

// A RAII-style object that points to a Queue and implements
// the BatchScheduler API. To be handed out to clients who call AddQueue().
template <typename TaskType>
class QueueHandle : public BatchScheduler<TaskType> {
 public:
  QueueHandle(std::shared_ptr<SharedBatchScheduler<TaskType>> scheduler,
              Queue<TaskType>* queue);
  ~QueueHandle() override;

  Status Schedule(std::unique_ptr<TaskType>* task) override;
  size_t NumEnqueuedTasks() const override;
  size_t SchedulingCapacity() const override;

  size_t max_task_size() const override { return queue_->max_task_size(); }

 private:
  // The scheduler that owns 'queue_'.
  std::shared_ptr<SharedBatchScheduler<TaskType>> scheduler_;

  // The queue this handle wraps. Owned by 'scheduler_', which keeps it alive at
  // least until this class's destructor closes it.
  Queue<TaskType>* queue_;

  TF_DISALLOW_COPY_AND_ASSIGN(QueueHandle);
};

}  // namespace internal

template <typename TaskType>
Status SharedBatchScheduler<TaskType>::Create(
    const Options& options,
    std::shared_ptr<SharedBatchScheduler<TaskType>>* scheduler) {
  if (options.num_batch_threads < 1) {
    return errors::InvalidArgument("num_batch_threads must be positive; was ",
                                   options.num_batch_threads);
  }
  scheduler->reset(new SharedBatchScheduler<TaskType>(options));
  return OkStatus();
}

template <typename TaskType>
SharedBatchScheduler<TaskType>::~SharedBatchScheduler() {
  // Wait until the batch threads finish clearing out and deleting the closed
  // queues.
  for (;;) {
    {
      mutex_lock l(mu_);
      if (queues_.empty()) {
        break;
      }
    }
    const int64_t kSleepTimeMicros = 100;
    options_.env->SleepForMicroseconds(kSleepTimeMicros);
  }
  // Delete the batch threads before allowing state that the threads may access
  // (e.g. 'mu_') to be deleted.
  batch_threads_.clear();
}

template <typename TaskType>
Status SharedBatchScheduler<TaskType>::AddQueue(
    const QueueOptions& options,
    std::function<void(std::unique_ptr<Batch<TaskType>>)>
        process_batch_callback,
    std::unique_ptr<BatchScheduler<TaskType>>* queue) {
  QueueOptions rewrite_options = options;
  if ((!rewrite_options.enable_large_batch_splitting) &&
      rewrite_options.max_enqueued_batches == 0) {
    // Many existing models (with very low QPS) rely on this option being > 0.
    // Rewrite it to 1, retaining the old behavior, to allow such models to
    // continue to work.
    //
    // Note, technically an invalid-argument error should be returned, but
    // that may break such models.
    rewrite_options.max_enqueued_batches = 1;
  }
  return AddQueueAfterRewritingOptions(rewrite_options, process_batch_callback,
                                       queue);
}

template <typename TaskType>
Status SharedBatchScheduler<TaskType>::AddQueueAfterRewritingOptions(
    const QueueOptions& options,
    std::function<void(std::unique_ptr<Batch<TaskType>>)>
        process_batch_callback,
    std::unique_ptr<BatchScheduler<TaskType>>* queue) {
  if (options.input_batch_size_limit == 0) {
    return errors::InvalidArgument(
        "input_batch_size_limit must be positive; was ",
        options.input_batch_size_limit);
  }
  if (options.batch_timeout_micros < 0) {
    return errors::InvalidArgument(
        "batch_timeout_micros must be non-negative; was ",
        options.batch_timeout_micros);
  }
  if (options.max_enqueued_batches == 0) {
    return errors::InvalidArgument(
        "max_enqueued_batches must be positive; was ",
        options.max_enqueued_batches);
  }

  if (options.enable_large_batch_splitting &&
      options.split_input_task_func == nullptr) {
    return errors::InvalidArgument(
        "split_input_task_func must be specified when split_input_task is "
        "true: ",
        options.enable_large_batch_splitting);
  }

  if (options.enable_lazy_split && (!options.enable_large_batch_splitting)) {
    return errors::InvalidArgument(
        "enable_lazy_split should be enabled only if "
        "enable_large_batch_splitting is enabled.");
  }

  if (options.enable_large_batch_splitting &&
      (options.input_batch_size_limit < options.max_execution_batch_size)) {
    return errors::InvalidArgument(
        "When enable_large_batch_splitting is true, input_batch_size_limit "
        "must be "
        "greater than or equal to max_execution_batch_size.",
        options.enable_large_batch_splitting, options.input_batch_size_limit,
        options.max_execution_batch_size);
  }

  auto schedulable_batch_callback = [this] {
    mutex_lock l(mu_);
    schedulable_batch_cv_.notify_one();
  };
  auto internal_queue =
      std::unique_ptr<internal::Queue<TaskType>>(new internal::Queue<TaskType>(
          options, options_.env, process_batch_callback,
          schedulable_batch_callback));
  auto handle = std::unique_ptr<BatchScheduler<TaskType>>(
      new internal::QueueHandle<TaskType>(this->shared_from_this(),
                                          internal_queue.get()));
  {
    mutex_lock l(mu_);
    queues_.push_back(std::move(internal_queue));
    if (next_queue_to_schedule_ == queues_.end()) {
      next_queue_to_schedule_ = queues_.begin();
    }
  }
  *queue = std::move(handle);
  return OkStatus();
}

template <typename TaskType>
SharedBatchScheduler<TaskType>::SharedBatchScheduler(const Options& options)
    : options_(options), next_queue_to_schedule_(queues_.end()) {
  // Kick off the batch threads.
  PeriodicFunction::Options periodic_fn_options;
  periodic_fn_options.thread_name_prefix =
      strings::StrCat(options.thread_pool_name, "_");
  for (int i = 0; i < options.num_batch_threads; ++i) {
    std::unique_ptr<PeriodicFunction> thread(new PeriodicFunction(
        [this] { this->ThreadLogic(); },
        0 /* function invocation interval time */, periodic_fn_options));
    batch_threads_.push_back(std::move(thread));
  }
}

template <typename TaskType>
bool SharedBatchScheduler<TaskType>::BatchExists(
    const BatchUniquePtr& batch_to_process) {
  if (absl::holds_alternative<BatchTaskUniqueptr>(batch_to_process)) {
    return absl::get<BatchTaskUniqueptr>(batch_to_process) == nullptr;
  }
  return absl::get<BatchTaskHandleUniquePtr>(batch_to_process) == nullptr;
}

template <typename TaskType>
void SharedBatchScheduler<TaskType>::GetNextWorkItem_Locked(
    internal::Queue<TaskType>** queue_for_batch_out,
    BatchUniquePtr* batch_to_process_out) {
  BatchUniquePtr batch_to_process;
  internal::Queue<TaskType>* queue_for_batch = nullptr;
  const int num_queues = queues_.size();
  for (int num_queues_tried = 0;
       (BatchExists(batch_to_process)) && num_queues_tried < num_queues;
       ++num_queues_tried) {
    DCHECK(next_queue_to_schedule_ != queues_.end());

    // If a closed queue responds to ScheduleBatch() with nullptr, the queue
    // will never yield any further batches so we can drop it. To avoid a
    // race, we take a snapshot of the queue's closedness state *before*
    // calling ScheduleBatch().
    const bool queue_closed = (*next_queue_to_schedule_)->closed();

    // Ask '*next_queue_to_schedule_' if it wants us to process a batch.
    batch_to_process = (*next_queue_to_schedule_)->ScheduleBatch();

    if (!BatchExists(batch_to_process)) {
      queue_for_batch = next_queue_to_schedule_->get();
    }

    // Advance 'next_queue_to_schedule_'.
    if (queue_closed && (*next_queue_to_schedule_)->IsEmpty() &&
        (BatchExists(batch_to_process))) {
      // We've encountered a closed queue with no work to do. Drop it.
      DCHECK_NE(queue_for_batch, next_queue_to_schedule_->get());
      next_queue_to_schedule_ = queues_.erase(next_queue_to_schedule_);
    } else {
      ++next_queue_to_schedule_;
    }
    if (next_queue_to_schedule_ == queues_.end() && !queues_.empty()) {
      // We've hit the end. Wrap to the first queue.
      next_queue_to_schedule_ = queues_.begin();
    }
  }
  *queue_for_batch_out = queue_for_batch;
  *batch_to_process_out = std::move(batch_to_process);
}

template <typename TaskType>
void SharedBatchScheduler<TaskType>::ThreadLogic() {
  // A batch to process next (or nullptr if no work to do).
  BatchUniquePtr batch_to_process;
  // The queue with which 'batch_to_process' is associated.
  internal::Queue<TaskType>* queue_for_batch = nullptr;
  {
    mutex_lock l(mu_);
    while (true) {
      GetNextWorkItem_Locked(&queue_for_batch, &batch_to_process);
      if (!BatchExists(batch_to_process)) {
        break;
      }
      // We couldn't find any work to do. Wait until a new batch becomes
      // schedulable, or some time has elapsed, before checking again.
      const int64_t kTimeoutMillis =
          1;  // The smallest accepted granule of time.
      WaitForMilliseconds(&l, &schedulable_batch_cv_, kTimeoutMillis);
      if (queues_.empty()) return;
    }
  }

  std::unique_ptr<Batch<TaskType>> batch_to_schedule;
  if (absl::holds_alternative<BatchTaskHandleUniquePtr>(batch_to_process)) {
    // The corresponding `queue_for_batch` must be created with
    // `enable_lazy_split=true`.
    BatchTaskHandleUniquePtr ptr =
        std::move(absl::get<BatchTaskHandleUniquePtr>(batch_to_process));
    batch_to_schedule = std::make_unique<Batch<TaskType>>();
    std::vector<std::unique_ptr<internal::BatchInputTaskHandle<TaskType>>>
        task_handles = ptr->RemoveAllTasks();

    // TODO(b/194294263):
    // Handle the batch-kernel callback properly when lazy split returns an
    // error.
    for (int i = 0; i < task_handles.size(); i++) {
      batch_to_schedule->AddTask(std::move(task_handles[i]->GetSplitTask()));
    }
    batch_to_schedule->Close();

  } else {
    // The corresponding `queue_for_batch` must be created with
    // `enable_lazy_split=false`.
    batch_to_schedule =
        std::move(absl::get<BatchTaskUniqueptr>(batch_to_process));
  }

  queue_for_batch->ProcessBatch(std::move(batch_to_schedule));
}

namespace internal {

template <typename TaskType>
Queue<TaskType>::Queue(
    const typename SharedBatchScheduler<TaskType>::QueueOptions& options,
    Env* env, ProcessBatchCallback process_batch_callback,
    SchedulableBatchCallback schedulable_batch_callback)
    : options_(options),
      env_(env),
      max_execution_batch_size_(GetMaxExecutionBatchSize(options_)),
      process_batch_callback_(process_batch_callback),
      schedulable_batch_callback_(schedulable_batch_callback) {
  // Set the higher 32 bits of traceme_context_id_counter_ to be the creation
  // time of the queue. This prevents the batches in different queues from
  // having the same traceme_context_id_counter_.
  traceme_context_id_counter_ = absl::GetCurrentTimeNanos() << 32;
  // Create an initial, open batch.
  if (options_.enable_lazy_split) {
    task_handle_batches_.emplace_back(
        new Batch<BatchInputTaskHandle<TaskType>>);
  } else {
    batches_.emplace_back(new Batch<TaskType>);
  }
}

template <typename TaskType>
Queue<TaskType>::~Queue() {
  mutex_lock l(mu_);
  DCHECK(IsEmptyInternal());

  // Close the (empty) open batch, so its destructor doesn't block.
  if (options_.enable_lazy_split) {
    task_handle_batches_.back()->Close();
  } else {
    batches_.back()->Close();
  }
}

template <typename TaskType>
Status Queue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
  if ((*task)->size() > options_.input_batch_size_limit) {
    return errors::InvalidArgument("Task size ", (*task)->size(),
                                   " is larger than maximum input batch size ",
                                   options_.input_batch_size_limit);
  }
  if (options_.enable_lazy_split) {
    return ScheduleWithLazySplit(std::move(task));
  }
  return ScheduleWithoutOrEagerSplit(std::move(task));
}

template <typename TaskType>
Status Queue<TaskType>::ScheduleWithLazySplit(std::unique_ptr<TaskType>* task) {
  profiler::TraceMe trace_me([task] {
    return profiler::TraceMeEncode(
        "ScheduleWithLazySplit",
        {{"batching_input_task_size", (*task)->size()}});
  });
  // The max size to be enqueued.
  const int max_execution_batch_size = options_.max_execution_batch_size;

  bool notify_of_schedulable_batch = false;
  {
    mutex_lock l(mu_);

    DCHECK(!closed_);

    if (BatchTaskExceedQueueCapacity((*task).get())) {
      return errors::Unavailable(
          "The batch scheduling queue to which this task was submitted is "
          "full");
    }
    const int64 open_batch_capacity =
        max_execution_batch_size - this->tail_batch_task_size();

    auto input_batch = std::make_shared<BatchInputTask<TaskType>>(
        std::move(*task), open_batch_capacity, max_execution_batch_size,
        options_.split_input_task_func);
    std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>> task_handles;

    input_batch->ToTaskHandles(&task_handles);

    for (int i = 0; i < task_handles.size(); ++i) {
      if (task_handle_batches_.back()->size() + task_handles[i]->size() >
          options_.max_execution_batch_size) {
        StartNewBatch();
      }
      if (task_handle_batches_.back()->empty()) {
        open_batch_start_time_micros_ = env_->NowMicros();
      }
      profiler::TraceMeProducer trace_me(
          [&task_handles, i] {
            return profiler::TraceMeEncode("ScheduleOutputTask",
                                           {{"size", task_handles[i]->size()}});
          },
          profiler::ContextType::kSharedBatchScheduler,
          task_handle_batches_.back()->traceme_context_id());

      task_handle_batches_.back()->AddTask(std::move(task_handles[i]));
    }

    if (!schedulable_batch_) {
      if (batches_.size() > 1 || IsOpenBatchSchedulable()) {
        schedulable_batch_ = true;
        notify_of_schedulable_batch = true;
      }
    }
  }
  // TODO(b/194294263):
  // Add unit tests to verify that `schedulable_batch_callback_` could be
  // triggered when batches are scheduled.
  if (notify_of_schedulable_batch) {
    schedulable_batch_callback_();
  }

  return OkStatus();
}

// TODO(b/194294263):
// Merge `ScheduleWithoutOrEagerSplit` and `ScheduleWithLazySplit` into
// `Schedule`.
template <typename TaskType>
Status Queue<TaskType>::ScheduleWithoutOrEagerSplit(
    std::unique_ptr<TaskType>* task) {
  const bool large_batch_splitting = options_.enable_large_batch_splitting;
  profiler::TraceMe trace_me([task, large_batch_splitting] {
    return profiler::TraceMeEncode(
        large_batch_splitting ? "ScheduleWithEagerSplit"
                              : "ScheduleWithoutSplit",
        {{"batching_input_task_size", (*task)->size()}});
  });

  bool notify_of_schedulable_batch = false;
  {
    mutex_lock l(mu_);

    DCHECK(!closed_);

    // TODO(b/161857471):
    // Add test coverage for when concurrent incoming batches arrive and
    // use up all queue capacity.
    if (BatchTaskExceedQueueCapacity((*task).get())) {
      return errors::Unavailable(
          "The batch scheduling queue to which this task was submitted is "
          "full");
    }

    const int64_t open_batch_remaining_slot =
        max_execution_batch_size() - batches_.back()->size();

    const int64_t input_task_size = (*task)->size();

    std::vector<std::unique_ptr<TaskType>> output_tasks;

    if (input_task_size <= open_batch_remaining_slot ||
        !large_batch_splitting) {
      // This is the fast path when input doesn't need to be split.
      output_tasks.push_back(std::move(*task));
    } else {
      TF_RETURN_IF_ERROR(SplitInputBatchIntoSubtasks(task, &output_tasks));
    }

    for (int i = 0; i < output_tasks.size(); ++i) {
      if (batches_.back()->size() + output_tasks[i]->size() >
          max_execution_batch_size()) {
        StartNewBatch();
      }
      if (batches_.back()->empty()) {
        open_batch_start_time_micros_ = env_->NowMicros();
      }
      profiler::TraceMeProducer trace_me(
          [&output_tasks, i] {
            return profiler::TraceMeEncode("ScheduleOutputTask",
                                           {{"size", output_tasks[i]->size()}});
          },
          profiler::ContextType::kSharedBatchScheduler,
          batches_.back()->traceme_context_id());
      batches_.back()->AddTask(std::move(output_tasks[i]));
    }

    if (!schedulable_batch_) {
      if (batches_.size() > 1 || IsOpenBatchSchedulable()) {
        schedulable_batch_ = true;
        notify_of_schedulable_batch = true;
      }
    }
  }

  if (notify_of_schedulable_batch) {
    schedulable_batch_callback_();
  }

  return OkStatus();
}

template <typename TaskType>
size_t Queue<TaskType>::NumEnqueuedTasks() const {
  size_t num_enqueued_tasks = 0;
  mutex_lock l(mu_);
  if (options_.enable_lazy_split) {
    for (const auto& batch : task_handle_batches_) {
      num_enqueued_tasks += batch->num_tasks();
    }
    return num_enqueued_tasks;
  }

  for (const auto& batch : batches_) {
    num_enqueued_tasks += batch->num_tasks();
  }
  return num_enqueued_tasks;
}

template <typename TaskType>
size_t Queue<TaskType>::SchedulingCapacity() const {
  mutex_lock l(mu_);
  return SchedulingCapacityInternal();
}

template <typename TaskType>
size_t Queue<TaskType>::SchedulingCapacityInternal() const {
  const int64 num_new_batches_schedulable =
      static_cast<int64_t>(options_.max_enqueued_batches) -
      this->num_enqueued_batches();
  const int64 execution_batch_size_limit = max_execution_batch_size();
  const int64 open_batch_capacity =
      execution_batch_size_limit - this->tail_batch_task_size();
  // Note the returned value is guaranteed to be non-negative, since an
  // enqueue operation can only happen if the queue has enough capacity.
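  //
  // For example (a purely illustrative sketch with assumed values): with
  // max_enqueued_batches = 10, 3 batches currently enqueued,
  // execution_batch_size_limit = 1000 and a tail batch holding 200 task units,
  // the capacity is (10 - 3) * 1000 + (1000 - 200) = 7800.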
  return (num_new_batches_schedulable * execution_batch_size_limit) +
         open_batch_capacity;
}

template <typename TaskType>
bool Queue<TaskType>::BatchTaskExceedQueueCapacity(TaskType* task) const {
  // Queue creation requires that `enable_large_batch_splitting` is true
  // when `enable_lazy_split` is true, so this covers both eager split and
  // lazy split.
  if (options_.enable_large_batch_splitting) {
    return task->size() > SchedulingCapacityInternal();
  }

  // NOTE: the capacity check below is loose, and is retained for backward
  // compatibility that would otherwise have been broken by the merge of the
  // no-split and eager-split paths.
  // There are existing clients/models that rely on the loose check
  // and could get errors after the merge. Retaining the old behavior
  // allows such models to continue to work.
  //
  // We need to revisit/remove this check after we fix model configs.
  bool batch_task_exceed_queue_capacity = false;
  if (batches_.back()->size() + task->size() >
      options_.input_batch_size_limit) {
    if (batches_.size() >= options_.max_enqueued_batches) {
      batch_task_exceed_queue_capacity = true;
    }
  }
  return batch_task_exceed_queue_capacity;
}

template <typename TaskType>
std::unique_ptr<Batch<TaskType>>
Queue<TaskType>::ScheduleBatchWithEagerSplit() {
  // The batch to schedule, which we may populate below. (If left as nullptr,
  // that means we are electing not to schedule a batch at this time.)
  std::unique_ptr<Batch<TaskType>> batch_to_schedule;

  {
    mutex_lock l(mu_);

    // Consider closing the open batch at this time, to schedule it.
    if (batches_.size() == 1 && IsOpenBatchSchedulable()) {
      StartNewBatch();
    }

    if (batches_.size() >= 2) {
      // There is at least one closed batch that is ready to be scheduled.
      ++num_batches_being_processed_;
      batch_to_schedule = std::move(batches_.front());
      batches_.pop_front();
    } else {
      schedulable_batch_ = false;
    }
  }

  return batch_to_schedule;
}

template <typename TaskType>
typename SharedBatchScheduler<TaskType>::BatchUniquePtr
Queue<TaskType>::ScheduleBatch() {
  if (!options_.enable_lazy_split) {
    return ScheduleBatchWithEagerSplit();
  }
  // The batch to schedule, which we may populate below. (If left as nullptr,
  // that means we are electing not to schedule a batch at this time.)
  std::unique_ptr<Batch<BatchInputTaskHandle<TaskType>>>
      task_handles_to_schedule;

  {
    mutex_lock l(mu_);

    // Consider closing the open batch at this time, to schedule it.
    if (task_handle_batches_.size() == 1 && IsOpenBatchSchedulable()) {
      StartNewBatch();
    }

    if (task_handle_batches_.size() >= 2) {
      // There is at least one closed batch that is ready to be scheduled.
      ++num_batches_being_processed_;
      task_handles_to_schedule = std::move(task_handle_batches_.front());
      task_handle_batches_.pop_front();
    } else {
      schedulable_batch_ = false;
    }
  }

  return std::move(task_handles_to_schedule);
}

template <typename TaskType>
void Queue<TaskType>::ProcessBatch(std::unique_ptr<Batch<TaskType>> batch) {
  profiler::TraceMeConsumer trace_me(
      [&] {
        return profiler::TraceMeEncode(
            "ProcessBatch", {{"batch_size_before_padding", batch->size()},
                             {"_r", 2} /*root_event*/});
      },
      profiler::ContextType::kSharedBatchScheduler,
      batch->traceme_context_id());
  process_batch_callback_(std::move(batch));

  {
    mutex_lock l(mu_);
    --num_batches_being_processed_;
    if (empty_notification_ != nullptr && IsEmptyInternal()) {
      empty_notification_->Notify();
    }
  }
}

template <typename TaskType>
bool Queue<TaskType>::IsEmpty() const {
  mutex_lock l(mu_);
  return IsEmptyInternal();
}

template <typename TaskType>
void Queue<TaskType>::CloseAndWaitUntilEmpty() {
  Notification empty;
  {
    mutex_lock l(mu_);
    closed_ = true;
    if (IsEmptyInternal()) {
      empty.Notify();
    } else {
      // Arrange for ProcessBatch() to notify when the queue becomes empty.
      empty_notification_ = &empty;
    }
  }
  empty.WaitForNotification();
}

template <typename TaskType>
bool Queue<TaskType>::IsEmptyInternal() const {
  if (options_.enable_lazy_split) {
    return num_batches_being_processed_ == 0 &&
           task_handle_batches_.size() == 1 &&
           task_handle_batches_.back()->empty();
  }
  return num_batches_being_processed_ == 0 && batches_.size() == 1 &&
         batches_.back()->empty();
}

template <typename TaskType>
void Queue<TaskType>::StartNewBatch() {
  if (options_.enable_lazy_split) {
    task_handle_batches_.back()->Close();
    task_handle_batches_.emplace_back(new Batch<BatchInputTaskHandle<TaskType>>(
        ++traceme_context_id_counter_));
    return;
  }
  batches_.back()->Close();
  batches_.emplace_back(new Batch<TaskType>(++traceme_context_id_counter_));
}

template <typename TaskType>
Status Queue<TaskType>::SplitInputBatchIntoSubtasks(
    std::unique_ptr<TaskType>* input_task,
    std::vector<std::unique_ptr<TaskType>>* output_tasks) {
  const int open_batch_remaining_slot =
      max_execution_batch_size() - this->tail_batch_task_size();
  return options_.split_input_task_func(
      std::move(input_task), open_batch_remaining_slot,
      max_execution_batch_size(), std::move(output_tasks));
}

template <typename TaskType>
bool Queue<TaskType>::IsOpenBatchSchedulableAfterEagerSplit() const {
  Batch<TaskType>* open_batch = batches_.back().get();
  if (open_batch->empty()) {
    return false;
  }
  return closed_ || open_batch->size() >= max_execution_batch_size() ||
         env_->NowMicros() >=
             open_batch_start_time_micros_ + options_.batch_timeout_micros;
}

template <typename TaskType>
bool Queue<TaskType>::IsOpenBatchSchedulable() const {
  if (!options_.enable_lazy_split) {
    return IsOpenBatchSchedulableAfterEagerSplit();
  }
  Batch<BatchInputTaskHandle<TaskType>>* open_batch =
      task_handle_batches_.back().get();
  if (open_batch->empty()) {
    return false;
  }
  return closed_ || open_batch->size() >= max_execution_batch_size() ||
         env_->NowMicros() >=
             open_batch_start_time_micros_ + options_.batch_timeout_micros;
}

template <typename TaskType>
size_t Queue<TaskType>::tail_batch_task_size() const {
  if (options_.enable_lazy_split) {
    return task_handle_batches_.back()->size();
  }

  return batches_.back()->size();
}

template <typename TaskType>
int64 Queue<TaskType>::num_enqueued_batches() const {
  if (options_.enable_lazy_split) {
    return task_handle_batches_.size();
  }
  return batches_.size();
}

template <typename TaskType>
QueueHandle<TaskType>::QueueHandle(
    std::shared_ptr<SharedBatchScheduler<TaskType>> scheduler,
    Queue<TaskType>* queue)
    : scheduler_(scheduler), queue_(queue) {}

template <typename TaskType>
QueueHandle<TaskType>::~QueueHandle() {
  queue_->CloseAndWaitUntilEmpty();
}

template <typename TaskType>
Status QueueHandle<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
  return queue_->Schedule(task);
}

template <typename TaskType>
size_t QueueHandle<TaskType>::NumEnqueuedTasks() const {
  return queue_->NumEnqueuedTasks();
}

template <typename TaskType>
size_t QueueHandle<TaskType>::SchedulingCapacity() const {
  return queue_->SchedulingCapacity();
}

}  // namespace internal

}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_