/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_

#include <algorithm>
#include <atomic>
#include <deque>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <unordered_map>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"

namespace tensorflow {
namespace serving {
namespace internal {
template <typename TaskType>
class ASBSBatch;

template <typename TaskType>
class ASBSQueue;
}  // namespace internal

// Shared batch scheduler designed to minimize latency. The scheduler keeps
// track of a number of queues (one per model or model version) which are
// continuously enqueuing requests. The scheduler groups the requests into
// batches which it periodically sends off for processing (see
// shared_batch_scheduler.h for more details). AdaptiveSharedBatchScheduler
// (ASBS) prioritizes batches primarily by age (i.e. the age of the batch's
// oldest request) along with a configurable preference for scheduling larger
// batches first.
//
// ASBS tries to keep the system busy by maintaining an adjustable number of
// concurrently processed batches.  If a new batch is created, and the number
// of in-flight batches is below the target, the next (i.e. oldest) batch is
// immediately scheduled.  Similarly, when a batch finishes processing, the
// target is rechecked, and another batch may be scheduled.  To avoid the need
// to carefully tune the target for each workload, model type, platform, etc.,
// it is dynamically adjusted in order to provide the lowest average latency.
//
// Some potential use cases:
// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing
//   involves serial processing by a device, from a latency perspective it is
//   desirable to keep the device evenly loaded, avoiding the need to wait for
//   the device to process prior batches.
// CPU utilization - If the batch processing is CPU-dominated, latency gains
//   can be reaped when the system is underutilized by increasing the
//   processing rate, backing the rate off as load increases to avoid overload.
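//
// Example usage (a minimal sketch, not a drop-in snippet; `MyTask` stands in
// for a user-defined BatchTask subclass and `CreateTask()` for a user-defined
// factory, neither of which is part of this header):
//
//   AdaptiveSharedBatchScheduler<MyTask>::Options options;
//   options.num_batch_threads = 4;
//   std::shared_ptr<AdaptiveSharedBatchScheduler<MyTask>> scheduler;
//   TF_RETURN_IF_ERROR(
//       AdaptiveSharedBatchScheduler<MyTask>::Create(options, &scheduler));
//
//   AdaptiveSharedBatchScheduler<MyTask>::QueueOptions queue_options;
//   std::unique_ptr<BatchScheduler<MyTask>> queue;
//   TF_RETURN_IF_ERROR(scheduler->AddQueue(
//       queue_options,
//       [](std::unique_ptr<Batch<MyTask>> batch) {
//         // Process the batch and propagate results to its tasks.
//       },
//       &queue));
//
//   std::unique_ptr<MyTask> task = CreateTask();
//   TF_RETURN_IF_ERROR(queue->Schedule(&task));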

template <typename TaskType>
class AdaptiveSharedBatchScheduler
    : public std::enable_shared_from_this<
          AdaptiveSharedBatchScheduler<TaskType>> {
 public:
  ~AdaptiveSharedBatchScheduler() {
    // Finish processing batches before destroying other class members.
    if (owned_batch_thread_pool_) {
      delete batch_thread_pool_;
    }
  }

  struct Options {
    // The name to use for the pool of batch threads.
    string thread_pool_name = {"batch_threads"};
    // Number of batch processing threads - the maximum value of
    // in_flight_batches_limit_.  It is recommended that this value be set by
    // running the system under load, observing the learned value for
    // in_flight_batches_limit_, and setting this maximum to ~ 2x the value.
    // Under low load, in_flight_batches_limit_ has no substantial effect on
    // latency and therefore undergoes a random walk.  Unreasonably large
    // values for num_batch_threads allow for a large in_flight_batches_limit_,
    // which will harm latency for some time once load increases again.
    int64_t num_batch_threads = port::MaxParallelism();
    // You can pass a ThreadPool directly rather than the above two
    // parameters.  If given, the above two parameters are ignored.  Ownership
    // of the threadpool is not transferred.
    thread::ThreadPool* thread_pool = nullptr;

    // Lower bound for in_flight_batches_limit_.  As discussed above, can be
    // used to minimize the damage caused by the random walk under low load.
    int64_t min_in_flight_batches_limit = 1;
    // Although batch selection is primarily based on age, this parameter
    // specifies a preference for larger batches.  A full batch will be
    // scheduled before an older, nearly empty batch as long as the age gap is
    // less than full_batch_scheduling_boost_micros.  The optimal value for
    // this parameter should be on the order of the batch processing latency,
    // but must be chosen carefully, as too large a value will harm tail
    // latency.
    int64_t full_batch_scheduling_boost_micros = 0;
    // The environment to use (typically only overridden by test code).
    Env* env = Env::Default();
    // Initial limit for number of batches being concurrently processed.
    // Non-integer values correspond to probabilistic limits - i.e. a value of
    // 3.2 results in an actual cap of 3 80% of the time, and 4 20% of the
    // time.
    double initial_in_flight_batches_limit = 3;
    // Number of batches between adjustments of in_flight_batches_limit.
    // Larger numbers will give less noisy latency measurements, but will be
    // less responsive to changes in workload.
    int64_t batches_to_average_over = 1000;

    // If true, schedule batches using a FIFO policy.
    // Requires that `full_batch_scheduling_boost_micros` is zero.
    // NOTE:
    // A separate parameter is introduced (rather than re-using
    // full_batch_scheduling_boost_micros == 0) to preserve backward
    // compatibility of the API.
    bool fifo_scheduling = false;
  };

  // Ownership is shared between the caller of Create() and any queues created
  // via AddQueue().
  static Status Create(
      const Options& options,
      std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler);

  struct QueueOptions {
    // Maximum size of a batch that's formed within
    // `ASBSQueue<TaskType>::Schedule`.
    int max_batch_size = 1000;
    // Maximum size of an input task, which is submitted to the queue by
    // calling `ASBSQueue<TaskType>::Schedule` and used to form batches.
    //
    // If specified, it should be larger than or equal to 'max_batch_size'.
    absl::optional<int> max_input_task_size = absl::nullopt;
    // Maximum number of tasks to add to a specific batch.
    absl::optional<int> max_tasks_per_batch = absl::nullopt;
    // Maximum number of enqueued (i.e. non-scheduled) batches.
    int max_enqueued_batches = 10;
    // Amount of time non-full batches must wait before becoming schedulable.
    // A non-zero value can improve performance by limiting the scheduling of
    // nearly empty batches.
    int64_t batch_timeout_micros = 0;
    // If not nullptr, split_input_task_func should split input_task into
    // multiple tasks, the first of which has size first_size and the
    // remaining not exceeding max_batch_size.  This function may acquire
    // ownership of input_task and should return a status indicating if the
    // split was successful.  Upon success, the caller can assume that all
    // output_tasks will be scheduled.  Including this option allows the
    // scheduler to pack batches better and should usually improve overall
    // throughput.
    std::function<Status(std::unique_ptr<TaskType>* input_task, int first_size,
                         int max_batch_size,
                         std::vector<std::unique_ptr<TaskType>>* output_tasks)>
        split_input_task_func;
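    // For example (an illustrative walkthrough of the contract above, not
    // additional API): with first_size = 5 and max_batch_size = 10, an input
    // task of size 25 could be split into output tasks of sizes [5, 10, 10].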
  };

  using BatchProcessor = std::function<void(std::unique_ptr<Batch<TaskType>>)>;

  // Adds a queue (and its callback) to be managed by this scheduler.
  Status AddQueue(const QueueOptions& options,
                  BatchProcessor process_batch_callback,
                  std::unique_ptr<BatchScheduler<TaskType>>* queue);

  double in_flight_batches_limit() {
    mutex_lock l(mu_);
    return in_flight_batches_limit_;
  }

 private:
  // Gives ASBSQueue access to AddBatch, MaybeScheduleClosedBatches,
  // RemoveQueue, and GetEnv.
  friend class internal::ASBSQueue<TaskType>;

  explicit AdaptiveSharedBatchScheduler(const Options& options);

  // Tracks processing latency and adjusts in_flight_batches_limit to minimize
  // latency.
  void CallbackWrapper(const internal::ASBSBatch<TaskType>* batch,
                       BatchProcessor callback, bool is_express);

  // Schedules a batch if in_flight_batches_limit_ is not met.
  void MaybeScheduleNextBatch() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Schedules a batch using the FIFO policy if in_flight_batches_limit_ is
  // not met.
  void MaybeScheduleNextBatchFIFO() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Schedules all closed batches in batches_ for which an idle thread is
  // available in batch_thread_pool_.
  // Batches scheduled this way are called express batches.
  // Express batches are not limited by in_flight_batches_limit_, and
  // their latencies will not affect in_flight_batches_limit_.
  void MaybeScheduleClosedBatches();

  void MaybeScheduleClosedBatchesLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  void MaybeScheduleClosedBatchesLockedFIFO() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  void MaybeAdjustInflightLimit() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Notifies scheduler of a non-empty batch which is eligible for processing.
  void AddBatch(const internal::ASBSBatch<TaskType>* batch);

  // Removes queue from scheduler.
  void RemoveQueue(const internal::ASBSQueue<TaskType>* queue);

  Env* GetEnv() const { return options_.env; }

  const Options options_;

  // Collection of batches added by AddBatch, ordered by age.  Owned by the
  // scheduler until they are released for processing.
  std::vector<const internal::ASBSBatch<TaskType>*> batches_ TF_GUARDED_BY(mu_);

  // Same as batches_, but used when options_.fifo_scheduling is true;
  // maintained in FIFO (i.e. age) order.  Owned by the scheduler until the
  // batches are released for processing.
  std::deque<const internal::ASBSBatch<TaskType>*> fifo_batches_
      TF_GUARDED_BY(mu_);

  // Unowned queues and callbacks added by AddQueue.
  std::unordered_map<const internal::ASBSQueue<TaskType>*, BatchProcessor>
      queues_and_callbacks_ TF_GUARDED_BY(mu_);

  mutex mu_;

  // Responsible for running the batch processing callbacks.
  thread::ThreadPool* batch_thread_pool_;

  bool owned_batch_thread_pool_ = false;

  // Limit on number of batches which can be concurrently processed.
  // Non-integer values correspond to probabilistic limits - i.e. a value of 3.2
  // results in an actual cap of 3 80% of the time, and 4 20% of the time.
  double in_flight_batches_limit_ TF_GUARDED_BY(mu_);

  // Number of regular batches currently being processed.
  int64_t in_flight_batches_ TF_GUARDED_BY(mu_) = 0;
  // Number of express batches currently being processed.
  int64_t in_flight_express_batches_ TF_GUARDED_BY(mu_) = 0;

  // RNG engine and distribution.
  std::default_random_engine rand_engine_;
  std::uniform_real_distribution<double> rand_double_;

  // Fields controlling the dynamic adjustment of in_flight_batches_limit_.
  // Number of batches since the last in_flight_batches_limit_ adjustment.
  int64_t batch_count_ TF_GUARDED_BY(mu_) = 0;

  struct DelayStats {
    // Sum of processing latency for batches counted by batch_count_.
    int64_t batch_latency_sum = 0;
    // Average batch latency for the previous value of
    // in_flight_batches_limit_.
    double last_avg_latency_ms = 0;
    // Did last_avg_latency_ms decrease from the previous last_avg_latency_ms?
    bool last_latency_decreased = false;
    // Current direction (+-) in which to adjust in_flight_batches_limit_.
    int step_direction = 1;
  };

  // Delay stats between the creation of a batch and the completion of that
  // batch.
  DelayStats batch_delay_stats_ TF_GUARDED_BY(mu_);

  // Max adjustment size (as a fraction of in_flight_batches_limit_).
  constexpr static double kMaxStepSizeMultiplier = 0.125;  // 1/8
  // Min adjustment size (as a fraction of in_flight_batches_limit_).
  constexpr static double kMinStepSizeMultiplier = 0.0078125;  // 1/128
  // Current adjustment size (as a fraction of in_flight_batches_limit_).
  double step_size_multiplier_ TF_GUARDED_BY(mu_) = kMaxStepSizeMultiplier;

  TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler);
};

//////////////////////////////////////////////////////////
// Implementation details follow. API users need not read.

namespace internal {
// Consolidates tasks into batches, passing them off to the
// AdaptiveSharedBatchScheduler for processing.
template <typename TaskType>
class ASBSQueue : public BatchScheduler<TaskType> {
 public:
  using QueueOptions =
      typename AdaptiveSharedBatchScheduler<TaskType>::QueueOptions;

  ASBSQueue(std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
            const QueueOptions& options);

  ~ASBSQueue() override;

  // Adds a task to the current batch. Fails if the task size is larger than
  // the batch size or if the current batch is full and this queue's number of
  // outstanding batches is at its maximum.
  Status Schedule(std::unique_ptr<TaskType>* task) override;

  // Number of tasks waiting to be scheduled.
  size_t NumEnqueuedTasks() const override;

  // Number of size-1 tasks which could currently be scheduled without failing.
  size_t SchedulingCapacity() const override;

  // Notifies queue that a batch is about to be scheduled; the queue should not
  // place any more tasks in this batch.
  void ReleaseBatch(const ASBSBatch<TaskType>* batch);

  size_t max_task_size() const override { return options_.max_batch_size; }

 private:
  // Number of size-1 tasks which could currently be scheduled without failing.
  size_t SchedulingCapacityLocked() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns a uint64 one greater than the one returned by the previous call.
  // Context ids are reused after std::numeric_limits<uint64>::max() is
  // exhausted.
  static uint64 NewTraceMeContextIdForBatch();

  std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_;
  const QueueOptions options_;
  // Owned by scheduler_.
  ASBSBatch<TaskType>* current_batch_ TF_GUARDED_BY(mu_) = nullptr;
  int64_t num_enqueued_batches_ TF_GUARDED_BY(mu_) = 0;
  int64_t num_enqueued_tasks_ TF_GUARDED_BY(mu_) = 0;
  mutable mutex mu_;
  TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue);
};

// Batch which remembers when and by whom it was created.
template <typename TaskType>
class ASBSBatch : public Batch<TaskType> {
 public:
  ASBSBatch(ASBSQueue<TaskType>* queue, int64_t creation_time_micros,
            int64_t batch_timeout_micros, uint64 traceme_context_id)
      : queue_(queue),
        creation_time_micros_(creation_time_micros),
        schedulable_time_micros_(creation_time_micros + batch_timeout_micros),
        traceme_context_id_(traceme_context_id) {}

  ~ASBSBatch() override {}

  ASBSQueue<TaskType>* queue() const { return queue_; }

  int64_t creation_time_micros() const { return creation_time_micros_; }

  int64_t schedulable_time_micros() const { return schedulable_time_micros_; }

  uint64 traceme_context_id() const { return traceme_context_id_; }

 private:
  ASBSQueue<TaskType>* queue_;
  const int64_t creation_time_micros_;
  const int64_t schedulable_time_micros_;
  const uint64 traceme_context_id_;
  TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch);
};
}  // namespace internal

// ---------------- AdaptiveSharedBatchScheduler ----------------

template <typename TaskType>
constexpr double AdaptiveSharedBatchScheduler<TaskType>::kMaxStepSizeMultiplier;

template <typename TaskType>
constexpr double AdaptiveSharedBatchScheduler<TaskType>::kMinStepSizeMultiplier;

template <typename TaskType>
Status AdaptiveSharedBatchScheduler<TaskType>::Create(
    const Options& options,
    std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler) {
  if (options.num_batch_threads < 1) {
    return errors::InvalidArgument("num_batch_threads must be positive; was ",
                                   options.num_batch_threads);
  }
  if (options.min_in_flight_batches_limit < 1) {
    return errors::InvalidArgument(
        "min_in_flight_batches_limit must be >= 1; was ",
        options.min_in_flight_batches_limit);
  }
  if (options.min_in_flight_batches_limit > options.num_batch_threads) {
    return errors::InvalidArgument(
        "min_in_flight_batches_limit (", options.min_in_flight_batches_limit,
        ") must be <= num_batch_threads (", options.num_batch_threads, ")");
  }
  if (options.full_batch_scheduling_boost_micros < 0) {
    return errors::InvalidArgument(
        "full_batch_scheduling_boost_micros can't be negative; was ",
        options.full_batch_scheduling_boost_micros);
  }
  if (options.initial_in_flight_batches_limit > options.num_batch_threads) {
    return errors::InvalidArgument(
        "initial_in_flight_batches_limit (",
        options.initial_in_flight_batches_limit,
        ") should not be larger than num_batch_threads (",
        options.num_batch_threads, ")");
  }
  if (options.initial_in_flight_batches_limit <
      options.min_in_flight_batches_limit) {
    return errors::InvalidArgument("initial_in_flight_batches_limit (",
                                   options.initial_in_flight_batches_limit,
                                   ") must be >= min_in_flight_batches_limit (",
                                   options.min_in_flight_batches_limit, ")");
  }
  if (options.batches_to_average_over < 1) {
    return errors::InvalidArgument(
        "batches_to_average_over should be "
        "greater than or equal to 1; was ",
        options.batches_to_average_over);
  }
  scheduler->reset(new AdaptiveSharedBatchScheduler<TaskType>(options));
  return OkStatus();
}

template <typename TaskType>
AdaptiveSharedBatchScheduler<TaskType>::AdaptiveSharedBatchScheduler(
    const Options& options)
    : options_(options),
      in_flight_batches_limit_(options.initial_in_flight_batches_limit),
      rand_double_(0.0, 1.0) {
  std::random_device device;
  rand_engine_.seed(device());
  if (options.thread_pool == nullptr) {
    owned_batch_thread_pool_ = true;
    batch_thread_pool_ = new thread::ThreadPool(
        GetEnv(), options.thread_pool_name, options.num_batch_threads);
  } else {
    owned_batch_thread_pool_ = false;
    batch_thread_pool_ = options.thread_pool;
  }
}

template <typename TaskType>
Status AdaptiveSharedBatchScheduler<TaskType>::AddQueue(
    const QueueOptions& options, BatchProcessor process_batch_callback,
    std::unique_ptr<BatchScheduler<TaskType>>* queue) {
  if (options.max_batch_size <= 0) {
    return errors::InvalidArgument("max_batch_size must be positive; was ",
                                   options.max_batch_size);
  }
  if (options.max_enqueued_batches <= 0) {
    return errors::InvalidArgument(
        "max_enqueued_batches must be positive; was ",
        options.max_enqueued_batches);
  }
  if (options.max_input_task_size.has_value()) {
    if (options.max_input_task_size.value() < options.max_batch_size) {
      return errors::InvalidArgument(
          "max_input_task_size must be larger than or equal to max_batch_size;"
          " got max_input_task_size as ",
          options.max_input_task_size.value(), " and max_batch_size as ",
          options.max_batch_size);
    }
  }
  internal::ASBSQueue<TaskType>* asbs_queue_raw;
  queue->reset(asbs_queue_raw = new internal::ASBSQueue<TaskType>(
                   this->shared_from_this(), options));
  mutex_lock l(mu_);
  queues_and_callbacks_[asbs_queue_raw] = process_batch_callback;
  return OkStatus();
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::AddBatch(
    const internal::ASBSBatch<TaskType>* batch) {
  mutex_lock l(mu_);
  if (options_.fifo_scheduling) {
    fifo_batches_.push_back(batch);
  } else {
    batches_.push_back(batch);
  }
  int64_t delay_micros =
      batch->schedulable_time_micros() - GetEnv()->NowMicros();
  if (delay_micros <= 0) {
    MaybeScheduleNextBatch();
    return;
  }
  // Try to schedule the batch once it becomes schedulable. Although the
  // scheduler waits for all batches to finish processing before allowing
  // itself to be deleted, MaybeScheduleNextBatch() is called in other places,
  // and therefore it's possible the scheduler could be deleted by the time
  // this closure runs. Grab a shared_ptr reference to prevent this from
  // happening.
  GetEnv()->SchedClosureAfter(
      delay_micros, [this, lifetime_preserver = this->shared_from_this()] {
        mutex_lock l(mu_);
        MaybeScheduleNextBatch();
      });
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::RemoveQueue(
    const internal::ASBSQueue<TaskType>* queue) {
  mutex_lock l(mu_);
  queues_and_callbacks_.erase(queue);
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleNextBatchFIFO() {
  const internal::ASBSBatch<TaskType>* batch = *fifo_batches_.begin();
  fifo_batches_.pop_front();
  // Queue may destroy itself after ReleaseBatch is called.
  batch->queue()->ReleaseBatch(batch);
  batch_thread_pool_->Schedule(std::bind(
      &AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper, this, batch,
      queues_and_callbacks_[batch->queue()], false /* is express */));
  in_flight_batches_++;
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<
    TaskType>::MaybeScheduleClosedBatchesLockedFIFO() {
  // Only schedule closed batches if we have spare capacity.
  int available_threads =
      static_cast<int>(options_.num_batch_threads - in_flight_batches_ -
                       in_flight_express_batches_);
  for (auto it = fifo_batches_.begin();
       it != fifo_batches_.end() && available_threads > 0;
       it = fifo_batches_.begin()) {
    if ((*it)->IsClosed()) {
      const internal::ASBSBatch<TaskType>* batch = *it;
      fifo_batches_.pop_front();
      batch->queue()->ReleaseBatch(batch);
      batch_thread_pool_->Schedule(
          std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper,
                    this, batch, queues_and_callbacks_[batch->queue()], true));
      in_flight_express_batches_++;
      available_threads--;
    } else {
      // Batches are FIFO, so stop iterating after finding the first non-closed
      // batch.
      break;
    }
  }
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleNextBatch() {
  bool batch_empty =
      options_.fifo_scheduling ? fifo_batches_.empty() : batches_.empty();
  if (batch_empty || in_flight_batches_ >= in_flight_batches_limit_) return;
  // Non-integer limit handled probabilistically.
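  // Illustrative walkthrough (a comment only, no additional logic): with
  // in_flight_batches_limit_ = 3.2 and in_flight_batches_ = 3, the gap is
  // 0.2, so the next batch is allowed to start only when the uniform draw
  // falls at or below 0.2, i.e. roughly 20% of the time.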
  if (in_flight_batches_limit_ - in_flight_batches_ < 1 &&
      rand_double_(rand_engine_) >
          in_flight_batches_limit_ - in_flight_batches_) {
    return;
  }

  if (options_.fifo_scheduling) {
    MaybeScheduleNextBatchFIFO();
    return;
  }

  auto best_it = batches_.end();
  double best_score = (std::numeric_limits<double>::max)();
  int64_t now_micros = GetEnv()->NowMicros();
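  // Lower score wins. Illustrative walkthrough (a comment only, no additional
  // logic): with full_batch_scheduling_boost_micros = 10000 and a queue
  // max_task_size of 100, a full batch (size 100) created at time t scores
  // t - 10000, so it is picked over a nearly empty batch up to ~10ms older.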
  for (auto it = batches_.begin(); it != batches_.end(); it++) {
    if ((*it)->schedulable_time_micros() > now_micros) continue;
    const double score =
        (*it)->creation_time_micros() -
        options_.full_batch_scheduling_boost_micros * (*it)->size() /
            static_cast<double>((*it)->queue()->max_task_size());
    if (best_it == batches_.end() || score < best_score) {
      best_score = score;
      best_it = it;
    }
  }
  // No schedulable batches.
  if (best_it == batches_.end()) return;
  const internal::ASBSBatch<TaskType>* batch = *best_it;
  batches_.erase(best_it);
  // Queue may destroy itself after ReleaseBatch is called.
  batch->queue()->ReleaseBatch(batch);
  batch_thread_pool_->Schedule(
      std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper, this,
                batch, queues_and_callbacks_[batch->queue()], false));
  in_flight_batches_++;
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleClosedBatches() {
  mutex_lock l(mu_);
  MaybeScheduleClosedBatchesLocked();
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<
    TaskType>::MaybeScheduleClosedBatchesLocked() {
  if (options_.fifo_scheduling) {
    MaybeScheduleClosedBatchesLockedFIFO();
    return;
  }
  // Only schedule closed batches if we have spare capacity.
  int available_threads =
      static_cast<int>(options_.num_batch_threads - in_flight_batches_ -
                       in_flight_express_batches_);
  for (auto it = batches_.begin();
       it != batches_.end() && available_threads > 0;) {
    if ((*it)->IsClosed()) {
      const internal::ASBSBatch<TaskType>* batch = *it;
      it = batches_.erase(it);
      batch->queue()->ReleaseBatch(batch);
      batch_thread_pool_->Schedule(
          std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper,
                    this, batch, queues_and_callbacks_[batch->queue()], true));
      in_flight_express_batches_++;
      available_threads--;
    } else {
      ++it;
    }
  }
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper(
    const internal::ASBSBatch<TaskType>* batch,
    AdaptiveSharedBatchScheduler<TaskType>::BatchProcessor callback,
    bool is_express) {
  profiler::TraceMeConsumer trace_me(
      [&] {
        return profiler::TraceMeEncode(
            "ProcessBatch", {{"batch_size_before_padding", batch->size()},
                             {"_r", 2} /*root_event*/});
      },
      profiler::ContextType::kAdaptiveSharedBatchScheduler,
      batch->traceme_context_id());
  const int64_t start_time = batch->creation_time_micros();
  callback(std::unique_ptr<Batch<TaskType>>(
      const_cast<internal::ASBSBatch<TaskType>*>(batch)));
  int64_t end_time = GetEnv()->NowMicros();
  mutex_lock l(mu_);
  if (is_express) {
    in_flight_express_batches_--;
    MaybeScheduleClosedBatchesLocked();
    return;
  }
  in_flight_batches_--;
  batch_count_++;
  batch_delay_stats_.batch_latency_sum += end_time - start_time;

  MaybeAdjustInflightLimit();

  MaybeScheduleNextBatch();
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::MaybeAdjustInflightLimit() {
  // Occasionally adjust in_flight_batches_limit_ to minimize average latency.
  // Although the optimal value may depend on the workload, the latency should
  // be a simple convex function of in_flight_batches_limit_, allowing us to
  // locate the global minimum relatively quickly.
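  // Illustrative walkthrough (a comment only, no additional logic): with
  // in_flight_batches_limit_ = 4.0 and step_size_multiplier_ = 1/8, the limit
  // moves by +/-0.5 per adjustment.  Two consecutive latency improvements
  // double the multiplier (capped at 1/8); an improvement right after a
  // reversal halves it (floored at 1/128); when latency worsens, the step
  // direction flips.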
  if (batch_count_ == options_.batches_to_average_over) {
    double current_avg_latency_ms =
        (batch_delay_stats_.batch_latency_sum / 1000.) / batch_count_;
    bool current_latency_decreased =
        current_avg_latency_ms < batch_delay_stats_.last_avg_latency_ms;
    if (current_latency_decreased) {
      // If latency improvement was because we're moving in the correct
      // direction, increase step_size so that we can get to the minimum faster.
      // If latency improvement was due to backtracking from a previous failure,
      // decrease step_size in order to refine our location.
      step_size_multiplier_ *=
          (batch_delay_stats_.last_latency_decreased ? 2 : 0.5);
      step_size_multiplier_ =
          std::min(step_size_multiplier_, kMaxStepSizeMultiplier);
      step_size_multiplier_ =
          std::max(step_size_multiplier_, kMinStepSizeMultiplier);
    } else {
      // Return (nearly) to previous position and confirm that latency is better
      // there before decreasing step size.
      batch_delay_stats_.step_direction = -batch_delay_stats_.step_direction;
    }
    in_flight_batches_limit_ += batch_delay_stats_.step_direction *
                                in_flight_batches_limit_ *
                                step_size_multiplier_;
    in_flight_batches_limit_ =
        std::min(in_flight_batches_limit_,
                 static_cast<double>(options_.num_batch_threads));
    in_flight_batches_limit_ =
        std::max(in_flight_batches_limit_,
                 static_cast<double>(options_.min_in_flight_batches_limit));
    batch_delay_stats_.last_avg_latency_ms = current_avg_latency_ms;
    batch_delay_stats_.last_latency_decreased = current_latency_decreased;
    batch_count_ = 0;
    batch_delay_stats_.batch_latency_sum = 0;
  }
}

// ---------------- ASBSQueue ----------------

namespace internal {
template <typename TaskType>
ASBSQueue<TaskType>::ASBSQueue(
    std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
    const QueueOptions& options)
    : scheduler_(scheduler), options_(options) {}

template <typename TaskType>
ASBSQueue<TaskType>::~ASBSQueue() {
  // Wait until the last batch has been scheduled.
  const int kSleepMicros = 1000;
  for (;;) {
    {
      mutex_lock l(mu_);
      if (num_enqueued_batches_ == 0) {
        break;
      }
    }
    scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros);
  }
  scheduler_->RemoveQueue(this);
}

template <typename TaskType>
Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
  size_t size = (*task)->size();
  if (options_.split_input_task_func == nullptr &&
      size > options_.max_batch_size) {
    return errors::InvalidArgument("Task size ", size,
                                   " is larger than maximum batch size ",
                                   options_.max_batch_size);
  }
  if (options_.max_input_task_size.has_value() &&
      (size > options_.max_input_task_size.value())) {
    return errors::InvalidArgument("Task size ", size,
                                   " is larger than max input task size ",
                                   options_.max_input_task_size.value());
  }

  std::vector<std::unique_ptr<TaskType>> tasks_to_schedule;
  std::vector<ASBSBatch<TaskType>*> new_batches;
  bool closed_batch = false;
  {
    mutex_lock l(mu_);
    if (size > SchedulingCapacityLocked()) {
      return errors::Unavailable("The batch scheduling queue is full");
    }

    int remaining_batch_size =
        current_batch_ == nullptr
            ? options_.max_batch_size
            : options_.max_batch_size - current_batch_->size();
    if (options_.split_input_task_func == nullptr ||
        size <= remaining_batch_size) {
      // Either we don't allow task splitting or the task fits within the
      // current batch.
      tasks_to_schedule.push_back(std::move(*task));
    } else {
      // Split the task in order to completely fill the current batch.
      // Beyond this point Schedule should not fail, as the caller has been
      // promised that all of the split tasks will be scheduled.
      TF_RETURN_IF_ERROR(options_.split_input_task_func(
          task, remaining_batch_size, options_.max_batch_size,
          &tasks_to_schedule));
    }
    for (auto& task : tasks_to_schedule) {
      // If the task can't fit within the current batch, close the batch off
      // and try to create another.
      if (current_batch_ &&
          current_batch_->size() + task->size() > options_.max_batch_size) {
        current_batch_->Close();
        closed_batch = true;
        current_batch_ = nullptr;
      }
      if (!current_batch_) {
        num_enqueued_batches_++;
        // batch.traceme_context_id connects TraceMeProducer and
        // TraceMeConsumer.
        // When multiple calls to "ASBS::Schedule" accumulate into one batch,
        // they are processed in the same batch and should share
        // traceme_context_id.
        current_batch_ = new ASBSBatch<TaskType>(
            this, scheduler_->GetEnv()->NowMicros(),
            options_.batch_timeout_micros, NewTraceMeContextIdForBatch());
        new_batches.push_back(current_batch_);
      }

      // Annotate each task (corresponding to one call of Schedule) with a
      // TraceMeProducer.
      profiler::TraceMeProducer trace_me(
          [task_size = task->size()] {
            return profiler::TraceMeEncode(
                "ASBSQueue::Schedule",
                {{"batching_input_task_size", task_size}});
          },
          profiler::ContextType::kAdaptiveSharedBatchScheduler,
          this->current_batch_->traceme_context_id());
      current_batch_->AddTask(std::move(task));
      num_enqueued_tasks_++;
      // If current_batch_ is now full, allow it to be processed immediately.
      bool reached_max_tasks =
          (options_.max_tasks_per_batch.has_value() &&
           current_batch_->num_tasks() >= options_.max_tasks_per_batch.value());
      if (current_batch_->size() == options_.max_batch_size ||
          reached_max_tasks) {
        current_batch_->Close();
        closed_batch = true;
        current_batch_ = nullptr;
      }
    }
  }
  // Scheduler functions must be called outside of the lock, since they may
  // call ReleaseBatch.
  for (auto* batch : new_batches) {
    scheduler_->AddBatch(batch);
  }
  if (closed_batch) {
    scheduler_->MaybeScheduleClosedBatches();
  }
  return OkStatus();
}

template <typename TaskType>
void ASBSQueue<TaskType>::ReleaseBatch(const ASBSBatch<TaskType>* batch) {
  mutex_lock l(mu_);
  num_enqueued_batches_--;
  num_enqueued_tasks_ -= batch->num_tasks();
  if (batch == current_batch_) {
    current_batch_->Close();
    current_batch_ = nullptr;
  }
}

template <typename TaskType>
size_t ASBSQueue<TaskType>::NumEnqueuedTasks() const {
  mutex_lock l(mu_);
  return num_enqueued_tasks_;
}

template <typename TaskType>
size_t ASBSQueue<TaskType>::SchedulingCapacity() const {
  mutex_lock l(mu_);
  return SchedulingCapacityLocked();
}

template <typename TaskType>
size_t ASBSQueue<TaskType>::SchedulingCapacityLocked() const {
  const int current_batch_capacity =
      current_batch_ ? options_.max_batch_size - current_batch_->size() : 0;
  const int spare_batches =
      options_.max_enqueued_batches - num_enqueued_batches_;
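  // Illustrative walkthrough (a comment only, no additional logic): with
  // max_enqueued_batches = 10, max_batch_size = 1000, 4 enqueued batches, and
  // a current batch of size 300, capacity = (10 - 4) * 1000 + 700 = 6700.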
  return spare_batches * options_.max_batch_size + current_batch_capacity;
}

template <typename TaskType>
// static
uint64 ASBSQueue<TaskType>::NewTraceMeContextIdForBatch() {
  static std::atomic<uint64> traceme_context_id(0);
  return traceme_context_id.fetch_add(1, std::memory_order_relaxed);
}
}  // namespace internal
}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_