/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_

#include <algorithm>
#include <atomic>
#include <deque>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <unordered_map>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"

namespace tensorflow {
namespace serving {
namespace internal {
template <typename TaskType>
class ASBSBatch;

template <typename TaskType>
class ASBSQueue;
}  // namespace internal

// Shared batch scheduler designed to minimize latency. The scheduler keeps
// track of a number of queues (one per model or model version) which are
// continuously enqueuing requests. The scheduler groups the requests into
// batches which it periodically sends off for processing (see
// shared_batch_scheduler.h for more details). AdaptiveSharedBatchScheduler
// (ASBS) prioritizes batches primarily by age (i.e. the age of the batch's
// oldest request) along with a configurable preference for scheduling larger
// batches first.
//
// ASBS tries to keep the system busy by maintaining an adjustable number of
// concurrently processed batches. If a new batch is created, and the number of
// in-flight batches is below the target, the next (i.e. oldest) batch is
// immediately scheduled. Similarly, when a batch finishes processing, the
// target is rechecked, and another batch may be scheduled. To avoid the need
// to carefully tune the target for each workload, model type, platform, etc.,
// it is dynamically adjusted in order to provide the lowest average latency.
//
// Some potential use cases:
// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing
//   involves serial processing by a device, from a latency perspective it is
//   desirable to keep the device evenly loaded, avoiding the need to wait for
//   the device to process prior batches.
// CPU utilization - If batch processing is CPU-dominated, you can reap
//   latency gains when underutilized by increasing the processing rate, and
//   back the rate off when load increases to avoid overload.

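// Example usage (a minimal sketch, not part of the API): `MyTask` stands for
// any BatchTask subclass with a size() method, and the callback body is a
// placeholder for real batch processing.
//
//   AdaptiveSharedBatchScheduler<MyTask>::Options options;
//   options.num_batch_threads = 4;
//   std::shared_ptr<AdaptiveSharedBatchScheduler<MyTask>> scheduler;
//   TF_CHECK_OK(
//       AdaptiveSharedBatchScheduler<MyTask>::Create(options, &scheduler));
//
//   AdaptiveSharedBatchScheduler<MyTask>::QueueOptions queue_options;
//   queue_options.max_batch_size = 64;
//   std::unique_ptr<BatchScheduler<MyTask>> queue;
//   TF_CHECK_OK(scheduler->AddQueue(
//       queue_options,
//       [](std::unique_ptr<Batch<MyTask>> batch) {
//         // The batch is closed and will not grow further; process it here.
//       },
//       &queue));
//
//   std::unique_ptr<MyTask> task = ...;
//   TF_CHECK_OK(queue->Schedule(&task));
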
template <typename TaskType>
class AdaptiveSharedBatchScheduler
    : public std::enable_shared_from_this<
          AdaptiveSharedBatchScheduler<TaskType>> {
 public:
  ~AdaptiveSharedBatchScheduler() {
    // Finish processing batches before destroying other class members.
    if (owned_batch_thread_pool_) {
      delete batch_thread_pool_;
    }
  }

  struct Options {
    // The name to use for the pool of batch threads.
    string thread_pool_name = {"batch_threads"};
    // Number of batch processing threads - the maximum value of
    // in_flight_batches_limit_. It is recommended that this value be set by
    // running the system under load, observing the learned value for
    // in_flight_batches_limit_, and setting this maximum to ~2x that value.
    // Under low load, in_flight_batches_limit_ has no substantial effect on
    // latency and therefore undergoes a random walk. Unreasonably large values
    // for num_batch_threads allow for a large in_flight_batches_limit_, which
    // will harm latency for some time once load increases again.
    int64_t num_batch_threads = port::MaxParallelism();
    // You can pass a ThreadPool directly rather than setting the above two
    // parameters. If given, the above two parameters are ignored. Ownership
    // of the thread pool is not transferred.
    thread::ThreadPool* thread_pool = nullptr;

    // Lower bound for in_flight_batches_limit_. As discussed above, can be
    // used to minimize the damage caused by the random walk under low load.
    int64_t min_in_flight_batches_limit = 1;
    // Although batch selection is primarily based on age, this parameter
    // specifies a preference for larger batches. A full batch will be
    // scheduled before an older, nearly empty batch as long as the age gap is
    // less than full_batch_scheduling_boost_micros. The optimal value for this
    // parameter should be on the order of the batch processing latency, but
    // must be chosen carefully, as too large a value will harm tail latency.
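    // For example, with full_batch_scheduling_boost_micros = 10000, a full
    // batch is preferred over a batch that is only 25% full unless the
    // latter is more than 7500 (= 10000 * (1 - 0.25)) microseconds older.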
    int64_t full_batch_scheduling_boost_micros = 0;
    // The environment to use (typically only overridden by test code).
    Env* env = Env::Default();
    // Initial limit for number of batches being concurrently processed.
    // Non-integer values correspond to probabilistic limits - i.e. a value of
    // 3.2 results in an actual cap of 3 for 80% of the time, and 4 for 20% of
    // the time.
    double initial_in_flight_batches_limit = 3;
    // Number of batches between adjustments of in_flight_batches_limit. Larger
    // numbers will give less noisy latency measurements, but will be less
    // responsive to changes in workload.
    int64_t batches_to_average_over = 1000;

    // If true, schedule batches using FIFO policy.
    // Requires that `full_batch_scheduling_boost_micros` is zero.
    // NOTE:
    // A separate parameter is introduced (rather than reusing
    // full_batch_scheduling_boost_micros == 0) to keep the API backward
    // compatible.
    bool fifo_scheduling = false;
  };

  // Ownership is shared between the caller of Create() and any queues created
  // via AddQueue().
  static Status Create(
      const Options& options,
      std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler);

  struct QueueOptions {
    // Maximum size of a batch that's formed within
    // `ASBSQueue<TaskType>::Schedule`.
    int max_batch_size = 1000;
    // Maximum size of an input task, which is submitted to the queue by
    // calling `ASBSQueue<TaskType>::Schedule` and used to form batches.
    //
    // If specified, it should be larger than or equal to 'max_batch_size'.
    absl::optional<int> max_input_task_size = absl::nullopt;
    // Maximum number of tasks to add to a specific batch.
    absl::optional<int> max_tasks_per_batch = absl::nullopt;
    // Maximum number of enqueued (i.e. non-scheduled) batches.
    int max_enqueued_batches = 10;
    // Amount of time non-full batches must wait before becoming schedulable.
    // A non-zero value can improve performance by limiting the scheduling of
    // nearly empty batches.
    int64_t batch_timeout_micros = 0;
    // If not nullptr, split_input_task_func should split input_task into
    // multiple tasks, the first of which has size first_size and the remaining
    // not exceeding max_batch_size. This function may acquire ownership of
    // input_task and should return a status indicating if the split was
    // successful. Upon success, the caller can assume that all output_tasks
    // will be scheduled. Including this option allows the scheduler to pack
    // batches better and should usually improve overall throughput.
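    //
    // A minimal sketch of such a function (illustrative only: `MyTask` and
    // its Slice(start, end) helper are hypothetical, not part of this API):
    //
    //   queue_options.split_input_task_func =
    //       [](std::unique_ptr<MyTask>* input_task, int first_size,
    //          int max_batch_size,
    //          std::vector<std::unique_ptr<MyTask>>* output_tasks) -> Status {
    //         const int total = static_cast<int>((*input_task)->size());
    //         int chunk = first_size > 0 ? first_size : max_batch_size;
    //         int start = 0;
    //         while (start < total) {
    //           chunk = std::min(chunk, total - start);
    //           output_tasks->push_back(
    //               (*input_task)->Slice(start, start + chunk));
    //           start += chunk;
    //           chunk = max_batch_size;
    //         }
    //         return OkStatus();
    //       };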
    std::function<Status(std::unique_ptr<TaskType>* input_task, int first_size,
                         int max_batch_size,
                         std::vector<std::unique_ptr<TaskType>>* output_tasks)>
        split_input_task_func;
  };

  using BatchProcessor = std::function<void(std::unique_ptr<Batch<TaskType>>)>;

  // Adds a queue (and its callback) to be managed by this scheduler.
  Status AddQueue(const QueueOptions& options,
                  BatchProcessor process_batch_callback,
                  std::unique_ptr<BatchScheduler<TaskType>>* queue);

  double in_flight_batches_limit() {
    mutex_lock l(mu_);
    return in_flight_batches_limit_;
  }

 private:
  // Access to AddBatch, MaybeScheduleClosedBatches, RemoveQueue, GetEnv.
  friend class internal::ASBSQueue<TaskType>;

  explicit AdaptiveSharedBatchScheduler(const Options& options);

  // Tracks processing latency and adjusts in_flight_batches_limit_ to
  // minimize it.
  void CallbackWrapper(const internal::ASBSBatch<TaskType>* batch,
                       BatchProcessor callback, bool is_express);

  // Schedules batch if in_flight_batches_limit_ is not met.
  void MaybeScheduleNextBatch() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Schedules batch using FIFO policy if in_flight_batches_limit_ is not met.
  void MaybeScheduleNextBatchFIFO() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Schedules all closed batches in batches_ for which an idle thread is
  // available in batch_thread_pool_.
  // Batches scheduled this way are called express batches.
  // Express batches are not limited by in_flight_batches_limit_, and
  // their latencies will not affect in_flight_batches_limit_.
  void MaybeScheduleClosedBatches();

  void MaybeScheduleClosedBatchesLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  void MaybeScheduleClosedBatchesLockedFIFO() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  void MaybeAdjustInflightLimit() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Notifies scheduler of a non-empty batch which is eligible for processing.
  void AddBatch(const internal::ASBSBatch<TaskType>* batch);

  // Removes queue from scheduler.
  void RemoveQueue(const internal::ASBSQueue<TaskType>* queue);

  Env* GetEnv() const { return options_.env; }

  const Options options_;

  // Collection of batches added by AddBatch, ordered by age. Owned by
  // scheduler until they are released for processing.
  std::vector<const internal::ASBSBatch<TaskType>*> batches_ TF_GUARDED_BY(mu_);

  // Collection of batches added by AddBatch when FIFO scheduling is enabled,
  // ordered by age. Owned by scheduler until they are released for processing.
  std::deque<const internal::ASBSBatch<TaskType>*> fifo_batches_
      TF_GUARDED_BY(mu_);

  // Unowned queues and callbacks added by AddQueue.
  std::unordered_map<const internal::ASBSQueue<TaskType>*, BatchProcessor>
      queues_and_callbacks_ TF_GUARDED_BY(mu_);

  mutex mu_;

  // Responsible for running the batch processing callbacks.
  thread::ThreadPool* batch_thread_pool_;

  bool owned_batch_thread_pool_ = false;

  // Limit on number of batches which can be concurrently processed.
  // Non-integer values correspond to probabilistic limits - i.e. a value of
  // 3.2 results in an actual cap of 3 for 80% of the time, and 4 for 20% of
  // the time.
  double in_flight_batches_limit_ TF_GUARDED_BY(mu_);

  // Number of regular batches currently being processed.
  int64_t in_flight_batches_ TF_GUARDED_BY(mu_) = 0;
  // Number of express batches currently being processed.
  int64_t in_flight_express_batches_ TF_GUARDED_BY(mu_) = 0;

  // RNG engine and distribution.
  std::default_random_engine rand_engine_;
  std::uniform_real_distribution<double> rand_double_;

  // Fields controlling the dynamic adjustment of in_flight_batches_limit_.
  // Number of batches since the last in_flight_batches_limit_ adjustment.
  int64_t batch_count_ TF_GUARDED_BY(mu_) = 0;

  struct DelayStats {
    // Sum of processing latency for batches counted by batch_count_.
    int64_t batch_latency_sum = 0;
    // Average batch latency for previous value of in_flight_batches_limit_.
    double last_avg_latency_ms = 0;
    // Did last_avg_latency_ms decrease from the previous last_avg_latency_ms?
    bool last_latency_decreased = false;
    // Current direction (+-) to adjust in_flight_batches_limit_.
    int step_direction = 1;
  };

  // Delay stats between the creation of a batch and the completion of a
  // batch.
  DelayStats batch_delay_stats_ TF_GUARDED_BY(mu_);

  // Max adjustment size (as a fraction of in_flight_batches_limit_).
  constexpr static double kMaxStepSizeMultiplier = 0.125;  // 1/8
  // Min adjustment size (as a fraction of in_flight_batches_limit_).
  constexpr static double kMinStepSizeMultiplier = 0.0078125;  // 1/128
  // Current adjustment size (as a fraction of in_flight_batches_limit_).
  double step_size_multiplier_ TF_GUARDED_BY(mu_) = kMaxStepSizeMultiplier;

  TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler);
};

//////////////////////////////////////////////////////////
// Implementation details follow. API users need not read.

namespace internal {
// Consolidates tasks into batches, passing them off to the
// AdaptiveSharedBatchScheduler for processing.
template <typename TaskType>
class ASBSQueue : public BatchScheduler<TaskType> {
 public:
  using QueueOptions =
      typename AdaptiveSharedBatchScheduler<TaskType>::QueueOptions;

  ASBSQueue(std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
            const QueueOptions& options);

  ~ASBSQueue() override;

  // Adds task to current batch. Fails if the task size is larger than the
  // batch size or if the current batch is full and this queue's number of
  // outstanding batches is at its maximum.
  Status Schedule(std::unique_ptr<TaskType>* task) override;

  // Number of tasks waiting to be scheduled.
  size_t NumEnqueuedTasks() const override;

  // Number of size-1 tasks which could currently be scheduled without failing.
  size_t SchedulingCapacity() const override;

  // Notifies queue that a batch is about to be scheduled; the queue should not
  // place any more tasks in this batch.
  void ReleaseBatch(const ASBSBatch<TaskType>* batch);

  size_t max_task_size() const override { return options_.max_batch_size; }

 private:
  // Number of size-1 tasks which could currently be scheduled without failing.
  size_t SchedulingCapacityLocked() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns a uint64 one greater than the value returned by the previous call.
  // Context ids are reused after std::numeric_limits<uint64>::max() is
  // exhausted.
  static uint64 NewTraceMeContextIdForBatch();

  std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_;
  const QueueOptions options_;
  // Owned by scheduler_.
  ASBSBatch<TaskType>* current_batch_ TF_GUARDED_BY(mu_) = nullptr;
  int64_t num_enqueued_batches_ TF_GUARDED_BY(mu_) = 0;
  int64_t num_enqueued_tasks_ TF_GUARDED_BY(mu_) = 0;
  mutable mutex mu_;
  TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue);
};

// Batch which remembers when and by whom it was created.
template <typename TaskType>
class ASBSBatch : public Batch<TaskType> {
 public:
  ASBSBatch(ASBSQueue<TaskType>* queue, int64_t creation_time_micros,
            int64_t batch_timeout_micros, uint64 traceme_context_id)
      : queue_(queue),
        creation_time_micros_(creation_time_micros),
        schedulable_time_micros_(creation_time_micros + batch_timeout_micros),
        traceme_context_id_(traceme_context_id) {}

  ~ASBSBatch() override {}

  ASBSQueue<TaskType>* queue() const { return queue_; }

  int64_t creation_time_micros() const { return creation_time_micros_; }

  int64_t schedulable_time_micros() const { return schedulable_time_micros_; }

  uint64 traceme_context_id() const { return traceme_context_id_; }

 private:
  ASBSQueue<TaskType>* queue_;
  const int64_t creation_time_micros_;
  const int64_t schedulable_time_micros_;
  const uint64 traceme_context_id_;
  TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch);
};
}  // namespace internal

// ---------------- AdaptiveSharedBatchScheduler ----------------

template <typename TaskType>
constexpr double AdaptiveSharedBatchScheduler<TaskType>::kMaxStepSizeMultiplier;

template <typename TaskType>
constexpr double AdaptiveSharedBatchScheduler<TaskType>::kMinStepSizeMultiplier;

template <typename TaskType>
Status AdaptiveSharedBatchScheduler<TaskType>::Create(
    const Options& options,
    std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler) {
  if (options.num_batch_threads < 1) {
    return errors::InvalidArgument("num_batch_threads must be positive; was ",
                                   options.num_batch_threads);
  }
  if (options.min_in_flight_batches_limit < 1) {
    return errors::InvalidArgument(
        "min_in_flight_batches_limit must be >= 1; was ",
        options.min_in_flight_batches_limit);
  }
  if (options.min_in_flight_batches_limit > options.num_batch_threads) {
    return errors::InvalidArgument(
        "min_in_flight_batches_limit (", options.min_in_flight_batches_limit,
        ") must be <= num_batch_threads (", options.num_batch_threads, ")");
  }
  if (options.full_batch_scheduling_boost_micros < 0) {
    return errors::InvalidArgument(
        "full_batch_scheduling_boost_micros can't be negative; was ",
        options.full_batch_scheduling_boost_micros);
  }
  if (options.initial_in_flight_batches_limit > options.num_batch_threads) {
    return errors::InvalidArgument(
        "initial_in_flight_batches_limit (",
        options.initial_in_flight_batches_limit,
        ") should not be larger than num_batch_threads (",
        options.num_batch_threads, ")");
  }
  if (options.initial_in_flight_batches_limit <
      options.min_in_flight_batches_limit) {
    return errors::InvalidArgument(
        "initial_in_flight_batches_limit (",
        options.initial_in_flight_batches_limit,
        ") must be >= min_in_flight_batches_limit (",
        options.min_in_flight_batches_limit, ")");
  }
  if (options.batches_to_average_over < 1) {
    return errors::InvalidArgument(
        "batches_to_average_over should be "
        "greater than or equal to 1; was ",
        options.batches_to_average_over);
  }
  scheduler->reset(new AdaptiveSharedBatchScheduler<TaskType>(options));
  return OkStatus();
}

template <typename TaskType>
AdaptiveSharedBatchScheduler<TaskType>::AdaptiveSharedBatchScheduler(
    const Options& options)
    : options_(options),
      in_flight_batches_limit_(options.initial_in_flight_batches_limit),
      rand_double_(0.0, 1.0) {
  std::random_device device;
  rand_engine_.seed(device());
  if (options.thread_pool == nullptr) {
    owned_batch_thread_pool_ = true;
    batch_thread_pool_ = new thread::ThreadPool(
        GetEnv(), options.thread_pool_name, options.num_batch_threads);
  } else {
    owned_batch_thread_pool_ = false;
    batch_thread_pool_ = options.thread_pool;
  }
}

template <typename TaskType>
Status AdaptiveSharedBatchScheduler<TaskType>::AddQueue(
    const QueueOptions& options, BatchProcessor process_batch_callback,
    std::unique_ptr<BatchScheduler<TaskType>>* queue) {
  if (options.max_batch_size <= 0) {
    return errors::InvalidArgument("max_batch_size must be positive; was ",
                                   options.max_batch_size);
  }
  if (options.max_enqueued_batches <= 0) {
    return errors::InvalidArgument(
        "max_enqueued_batches must be positive; was ",
        options.max_enqueued_batches);
  }
  if (options.max_input_task_size.has_value()) {
    if (options.max_input_task_size.value() < options.max_batch_size) {
      return errors::InvalidArgument(
          "max_input_task_size must be larger than or equal to "
          "max_batch_size; got max_input_task_size as ",
          options.max_input_task_size.value(), " and max_batch_size as ",
          options.max_batch_size);
    }
  }
  internal::ASBSQueue<TaskType>* asbs_queue_raw;
  queue->reset(asbs_queue_raw = new internal::ASBSQueue<TaskType>(
                   this->shared_from_this(), options));
  mutex_lock l(mu_);
  queues_and_callbacks_[asbs_queue_raw] = process_batch_callback;
  return OkStatus();
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::AddBatch(
    const internal::ASBSBatch<TaskType>* batch) {
  mutex_lock l(mu_);
  if (options_.fifo_scheduling) {
    fifo_batches_.push_back(batch);
  } else {
    batches_.push_back(batch);
  }
  int64_t delay_micros =
      batch->schedulable_time_micros() - GetEnv()->NowMicros();
  if (delay_micros <= 0) {
    MaybeScheduleNextBatch();
    return;
  }
  // Try to schedule the batch once it becomes schedulable. Although the
  // scheduler waits for all batches to finish processing before allowing
  // itself to be deleted, MaybeScheduleNextBatch() is called in other places,
  // and therefore it's possible the scheduler could be deleted by the time
  // this closure runs. Grab a shared_ptr reference to prevent this from
  // happening.
  GetEnv()->SchedClosureAfter(
      delay_micros, [this, lifetime_preserver = this->shared_from_this()] {
        mutex_lock l(mu_);
        MaybeScheduleNextBatch();
      });
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::RemoveQueue(
    const internal::ASBSQueue<TaskType>* queue) {
  mutex_lock l(mu_);
  queues_and_callbacks_.erase(queue);
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleNextBatchFIFO() {
  const internal::ASBSBatch<TaskType>* batch = *fifo_batches_.begin();
  fifo_batches_.pop_front();
  // Queue may destroy itself after ReleaseBatch is called.
  batch->queue()->ReleaseBatch(batch);
  batch_thread_pool_->Schedule(std::bind(
      &AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper, this, batch,
      queues_and_callbacks_[batch->queue()], false /* is express */));
  in_flight_batches_++;
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<
    TaskType>::MaybeScheduleClosedBatchesLockedFIFO() {
  // Only schedule closed batches if we have spare capacity.
  int available_threads =
      static_cast<int>(options_.num_batch_threads - in_flight_batches_ -
                       in_flight_express_batches_);
  for (auto it = fifo_batches_.begin();
       it != fifo_batches_.end() && available_threads > 0;
       it = fifo_batches_.begin()) {
    if ((*it)->IsClosed()) {
      const internal::ASBSBatch<TaskType>* batch = *it;
      fifo_batches_.pop_front();
      batch->queue()->ReleaseBatch(batch);
      batch_thread_pool_->Schedule(
          std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper,
                    this, batch, queues_and_callbacks_[batch->queue()], true));
      in_flight_express_batches_++;
      available_threads--;
    } else {
      // Batches are FIFO, so stop iterating after finding the first
      // non-closed batch.
      break;
    }
  }
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleNextBatch() {
  bool batch_empty =
      options_.fifo_scheduling ? fifo_batches_.empty() : batches_.empty();
  if (batch_empty || in_flight_batches_ >= in_flight_batches_limit_) return;
  // Non-integer limit handled probabilistically.
  if (in_flight_batches_limit_ - in_flight_batches_ < 1 &&
      rand_double_(rand_engine_) >
          in_flight_batches_limit_ - in_flight_batches_) {
    return;
  }

  if (options_.fifo_scheduling) {
    MaybeScheduleNextBatchFIFO();
    return;
  }

  auto best_it = batches_.end();
  double best_score = (std::numeric_limits<double>::max)();
  int64_t now_micros = GetEnv()->NowMicros();
  for (auto it = batches_.begin(); it != batches_.end(); it++) {
    if ((*it)->schedulable_time_micros() > now_micros) continue;
    const double score =
        (*it)->creation_time_micros() -
        options_.full_batch_scheduling_boost_micros * (*it)->size() /
            static_cast<double>((*it)->queue()->max_task_size());
    if (best_it == batches_.end() || score < best_score) {
      best_score = score;
      best_it = it;
    }
  }
  // No schedulable batches.
  if (best_it == batches_.end()) return;
  const internal::ASBSBatch<TaskType>* batch = *best_it;
  batches_.erase(best_it);
  // Queue may destroy itself after ReleaseBatch is called.
  batch->queue()->ReleaseBatch(batch);
  batch_thread_pool_->Schedule(
      std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper, this,
                batch, queues_and_callbacks_[batch->queue()], false));
  in_flight_batches_++;
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleClosedBatches() {
  mutex_lock l(mu_);
  MaybeScheduleClosedBatchesLocked();
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<
    TaskType>::MaybeScheduleClosedBatchesLocked() {
  if (options_.fifo_scheduling) {
    MaybeScheduleClosedBatchesLockedFIFO();
    return;
  }
  // Only schedule closed batches if we have spare capacity.
  int available_threads =
      static_cast<int>(options_.num_batch_threads - in_flight_batches_ -
                       in_flight_express_batches_);
  for (auto it = batches_.begin();
       it != batches_.end() && available_threads > 0;) {
    if ((*it)->IsClosed()) {
      const internal::ASBSBatch<TaskType>* batch = *it;
      it = batches_.erase(it);
      batch->queue()->ReleaseBatch(batch);
      batch_thread_pool_->Schedule(
          std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper,
                    this, batch, queues_and_callbacks_[batch->queue()], true));
      in_flight_express_batches_++;
      available_threads--;
    } else {
      ++it;
    }
  }
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper(
    const internal::ASBSBatch<TaskType>* batch,
    AdaptiveSharedBatchScheduler<TaskType>::BatchProcessor callback,
    bool is_express) {
  profiler::TraceMeConsumer trace_me(
      [&] {
        return profiler::TraceMeEncode(
            "ProcessBatch", {{"batch_size_before_padding", batch->size()},
                             {"_r", 2} /*root_event*/});
      },
      profiler::ContextType::kAdaptiveSharedBatchScheduler,
      batch->traceme_context_id());
  const int64_t start_time = batch->creation_time_micros();
  callback(std::unique_ptr<Batch<TaskType>>(
      const_cast<internal::ASBSBatch<TaskType>*>(batch)));
  int64_t end_time = GetEnv()->NowMicros();
  mutex_lock l(mu_);
  if (is_express) {
    in_flight_express_batches_--;
    MaybeScheduleClosedBatchesLocked();
    return;
  }
  in_flight_batches_--;
  batch_count_++;
  batch_delay_stats_.batch_latency_sum += end_time - start_time;

  MaybeAdjustInflightLimit();

  MaybeScheduleNextBatch();
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::MaybeAdjustInflightLimit() {
  // Occasionally adjust in_flight_batches_limit_ to minimize average latency.
  // Although the optimal value may depend on the workload, the latency should
  // be a simple convex function of in_flight_batches_limit_, allowing us to
  // locate the global minimum relatively quickly.
  if (batch_count_ == options_.batches_to_average_over) {
    double current_avg_latency_ms =
        (batch_delay_stats_.batch_latency_sum / 1000.) / batch_count_;
    bool current_latency_decreased =
        current_avg_latency_ms < batch_delay_stats_.last_avg_latency_ms;
    if (current_latency_decreased) {
      // If latency improvement was because we're moving in the correct
      // direction, increase step_size so that we can get to the minimum
      // faster. If latency improvement was due to backtracking from a
      // previous failure, decrease step_size in order to refine our location.
      step_size_multiplier_ *=
          (batch_delay_stats_.last_latency_decreased ? 2 : 0.5);
      step_size_multiplier_ =
          std::min(step_size_multiplier_, kMaxStepSizeMultiplier);
      step_size_multiplier_ =
          std::max(step_size_multiplier_, kMinStepSizeMultiplier);
    } else {
      // Return (nearly) to previous position and confirm that latency is
      // better there before decreasing step size.
      batch_delay_stats_.step_direction = -batch_delay_stats_.step_direction;
    }
    in_flight_batches_limit_ += batch_delay_stats_.step_direction *
                                in_flight_batches_limit_ *
                                step_size_multiplier_;
    in_flight_batches_limit_ =
        std::min(in_flight_batches_limit_,
                 static_cast<double>(options_.num_batch_threads));
    in_flight_batches_limit_ =
        std::max(in_flight_batches_limit_,
                 static_cast<double>(options_.min_in_flight_batches_limit));
    batch_delay_stats_.last_avg_latency_ms = current_avg_latency_ms;
    batch_delay_stats_.last_latency_decreased = current_latency_decreased;
    batch_count_ = 0;
    batch_delay_stats_.batch_latency_sum = 0;
  }
}

// ---------------- ASBSQueue ----------------

namespace internal {
template <typename TaskType>
ASBSQueue<TaskType>::ASBSQueue(
    std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
    const QueueOptions& options)
    : scheduler_(scheduler), options_(options) {}

template <typename TaskType>
ASBSQueue<TaskType>::~ASBSQueue() {
  // Wait until last batch has been scheduled.
  const int kSleepMicros = 1000;
  for (;;) {
    {
      mutex_lock l(mu_);
      if (num_enqueued_batches_ == 0) {
        break;
      }
    }
    scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros);
  }
  scheduler_->RemoveQueue(this);
}

template <typename TaskType>
Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
  size_t size = (*task)->size();
  if (options_.split_input_task_func == nullptr &&
      size > options_.max_batch_size) {
    return errors::InvalidArgument("Task size ", size,
                                   " is larger than maximum batch size ",
                                   options_.max_batch_size);
  }
  if (options_.max_input_task_size.has_value() &&
      (size > options_.max_input_task_size.value())) {
    return errors::InvalidArgument("Task size ", size,
                                   " is larger than max input task size ",
                                   options_.max_input_task_size.value());
  }

  std::vector<std::unique_ptr<TaskType>> tasks_to_schedule;
  std::vector<ASBSBatch<TaskType>*> new_batches;
  bool closed_batch = false;
  {
    mutex_lock l(mu_);
    if (size > SchedulingCapacityLocked()) {
      return errors::Unavailable("The batch scheduling queue is full");
    }

    int remaining_batch_size =
        current_batch_ == nullptr
            ? options_.max_batch_size
            : options_.max_batch_size - current_batch_->size();
    if (options_.split_input_task_func == nullptr ||
        size <= remaining_batch_size) {
      // Either we don't allow task splitting or the task fits within the
      // current batch.
      tasks_to_schedule.push_back(std::move(*task));
    } else {
      // Split the task in order to completely fill the current batch.
      // Beyond this point Schedule should not fail, as the caller has been
      // promised that all of the split tasks will be scheduled.
      TF_RETURN_IF_ERROR(options_.split_input_task_func(
          task, remaining_batch_size, options_.max_batch_size,
          &tasks_to_schedule));
    }
    for (auto& task : tasks_to_schedule) {
      // Can't fit within the current batch; close it off and try to create
      // another.
      if (current_batch_ &&
          current_batch_->size() + task->size() > options_.max_batch_size) {
        current_batch_->Close();
        closed_batch = true;
        current_batch_ = nullptr;
      }
      if (!current_batch_) {
        num_enqueued_batches_++;
        // batch.traceme_context_id connects TraceMeProducer and
        // TraceMeConsumer.
        // When multiple calls to ASBSQueue::Schedule accumulate into one
        // batch, they are processed as the same batch and should share the
        // same traceme_context_id.
        current_batch_ = new ASBSBatch<TaskType>(
            this, scheduler_->GetEnv()->NowMicros(),
            options_.batch_timeout_micros, NewTraceMeContextIdForBatch());
        new_batches.push_back(current_batch_);
      }

      // Annotate each task (corresponding to one call of Schedule) with a
      // TraceMeProducer.
      profiler::TraceMeProducer trace_me(
          [task_size = task->size()] {
            return profiler::TraceMeEncode(
                "ASBSQueue::Schedule",
                {{"batching_input_task_size", task_size}});
          },
          profiler::ContextType::kAdaptiveSharedBatchScheduler,
          this->current_batch_->traceme_context_id());
      current_batch_->AddTask(std::move(task));
      num_enqueued_tasks_++;
      // If current_batch_ is now full, allow it to be processed immediately.
      bool reached_max_tasks =
          (options_.max_tasks_per_batch.has_value() &&
           current_batch_->num_tasks() >= options_.max_tasks_per_batch.value());
      if (current_batch_->size() == options_.max_batch_size ||
          reached_max_tasks) {
        current_batch_->Close();
        closed_batch = true;
        current_batch_ = nullptr;
      }
    }
  }
  // Scheduler functions must be called outside of the lock, since they may
  // call ReleaseBatch.
  for (auto* batch : new_batches) {
    scheduler_->AddBatch(batch);
  }
  if (closed_batch) {
    scheduler_->MaybeScheduleClosedBatches();
  }
  return OkStatus();
}

template <typename TaskType>
void ASBSQueue<TaskType>::ReleaseBatch(const ASBSBatch<TaskType>* batch) {
  mutex_lock l(mu_);
  num_enqueued_batches_--;
  num_enqueued_tasks_ -= batch->num_tasks();
  if (batch == current_batch_) {
    current_batch_->Close();
    current_batch_ = nullptr;
  }
}

template <typename TaskType>
size_t ASBSQueue<TaskType>::NumEnqueuedTasks() const {
  mutex_lock l(mu_);
  return num_enqueued_tasks_;
}

template <typename TaskType>
size_t ASBSQueue<TaskType>::SchedulingCapacity() const {
  mutex_lock l(mu_);
  return SchedulingCapacityLocked();
}

template <typename TaskType>
size_t ASBSQueue<TaskType>::SchedulingCapacityLocked() const {
  const int current_batch_capacity =
      current_batch_ ? options_.max_batch_size - current_batch_->size() : 0;
  const int spare_batches =
      options_.max_enqueued_batches - num_enqueued_batches_;
  return spare_batches * options_.max_batch_size + current_batch_capacity;
}

template <typename TaskType>
// static
uint64 ASBSQueue<TaskType>::NewTraceMeContextIdForBatch() {
  static std::atomic<uint64> traceme_context_id(0);
  return traceme_context_id.fetch_add(1, std::memory_order_relaxed);
}
}  // namespace internal
}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_