/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/batch_kernels.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
#include "tensorflow/core/kernels/batching_util/bounded_executor.h"
#include "tensorflow/core/kernels/batching_util/concat_split_util.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/numbers.h"
#include "tensorflow/core/platform/threadpool.h"

namespace tensorflow {
namespace {
// Op attributes.
constexpr char kEnableAdaptiveSchedulerAttr[] = "_enable_adaptive_scheduler";
constexpr char kMinInflightBatchesAttr[] = "_min_inflight_batches";
constexpr char kInitialInflightBatchesAttr[] = "_initial_inflight_batches";
constexpr char kMaxInflightBatchesAttr[] = "_max_inflight_batches";
constexpr char kBatchesToAverageOverAttr[] = "_batches_to_average_over";

// Default thread count in the per-process batching thread pool.
constexpr int64_t kBatchThreadPoolSize = 128;
}  // namespace

// Per-model inflight batches parameters.
const int64_t kMinInflightBatches = 16;
const int64_t kInitialInflightBatches = 16;
const int64_t kBatchesToAverageOver = 10;
const int64_t kMaxInflightBatches = 64;

auto* batch_op_split_usage = monitoring::Gauge<string, 1>::New(
    "/tensorflow/serving/batching/enable_large_batch_splitting",
    "Tracks the usage of attribute `enable_large_batch_splitting` for "
    "BatchFunction kernel in a saved model.",
    "model_name");

void RecordBatchSplitUsage(
    absl::optional<bool> maybe_enable_large_batch_splitting,
    const string& model_name) {
  if (maybe_enable_large_batch_splitting.has_value()) {
    if (maybe_enable_large_batch_splitting.value()) {
      batch_op_split_usage->GetCell(model_name)->Set("true");
    } else {
      batch_op_split_usage->GetCell(model_name)->Set("false");
    }
  } else {
    batch_op_split_usage->GetCell(model_name)->Set("unset");
  }
}

void RecordBatchParamNumBatchThreads(int64_t num_batch_threads,
                                     const string& model_name) {
  static auto* cell = monitoring::Gauge<int64_t, 1>::New(
      "/tensorflow/serving/batching/num_batch_threads",
      "Tracks the number of batch threads of a model.", "model_name");
  cell->GetCell(model_name)->Set(num_batch_threads);
}

const string& GetModelName(OpKernelContext* ctx) {
  static string* kModelNameUnset = new string("model_name_unset");
  if (!ctx->session_metadata()) return *kModelNameUnset;
  if (ctx->session_metadata()->name().empty()) return *kModelNameUnset;
  return ctx->session_metadata()->name();
}

using ::tensorflow::concat_split_util::Concat;
using ::tensorflow::concat_split_util::Split;

int32 NumBatchThreadsFromEnvironmentWithDefault(int default_num_batch_threads) {
  int32_t num;
  const char* val = std::getenv("TF_NUM_BATCH_THREADS");

  return (val && strings::safe_strto32(val, &num)) ? num
                                                   : default_num_batch_threads;
}

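// Returns the process-wide thread pool shared by all adaptive batch
// schedulers. The pool size comes from the TF_NUM_BATCH_THREADS environment
// variable when it is set to a parsable integer (e.g. TF_NUM_BATCH_THREADS=32
// yields a 32-thread pool) and falls back to kBatchThreadPoolSize (128)
// otherwise. The pool is created at most once and reused for the lifetime of
// the process; nullptr is returned if creation failed.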
static thread::ThreadPool* GetOrCreateBatchThreadsPool() {
  static thread::ThreadPool* shared_thread_pool = [&]() -> thread::ThreadPool* {
    serving::BoundedExecutor::Options options;

    options.num_threads =
        NumBatchThreadsFromEnvironmentWithDefault(kBatchThreadPoolSize);

    options.thread_name = std::string("adaptive_batch_threads");

    auto status_or_executor = serving::BoundedExecutor::Create(options);
    if (!status_or_executor.ok()) {
      LOG(WARNING) << "Failed to create a batch threads pool with error "
                   << status_or_executor.status();
      return nullptr;
    }
    static serving::BoundedExecutor* executor =
        status_or_executor.ValueOrDie().release();
    return new thread::ThreadPool(executor);
  }();
  return shared_thread_pool;
}

// A class encapsulating the state and logic for batching tensors.
class BatchResource : public serving::BatchResourceBase {
 public:
  static Status Create(int32_t num_batch_threads,
                       int32_t max_execution_batch_size,
                       int32_t batch_timeout_micros,
                       int32_t max_enqueued_batches,
                       const std::vector<int32>& allowed_batch_sizes,
                       FunctionLibraryRuntime::Handle fhandle,
                       FunctionLibraryRuntime* flib,
                       bool enable_large_batch_splitting,
                       std::unique_ptr<BatchResource>* resource) {
    BatcherT::Options batcher_options;
    batcher_options.num_batch_threads = num_batch_threads;
    std::shared_ptr<BatcherT> batcher;
    TF_RETURN_IF_ERROR(BatcherT::Create(batcher_options, &batcher));

    resource->reset(new BatchResource(
        fhandle, flib, std::move(batcher),
        GetBatcherQueueOptions(num_batch_threads, max_execution_batch_size,
                               batch_timeout_micros, max_enqueued_batches,
                               allowed_batch_sizes,
                               enable_large_batch_splitting),
        allowed_batch_sizes));
    return OkStatus();
  }

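  // Variant used when adaptive batch scheduling is enabled: the caller
  // supplies AdaptiveSharedBatchScheduler options instead of a fixed thread
  // count, and large-batch splitting is always turned on for the queue (see
  // the GetAdaptiveBatcherQueueOptions call below).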
  static Status Create(
      AdaptiveBatcherT::Options adaptive_shared_batch_scheduler_options,
      int32_t max_batch_size, int32_t batch_timeout_micros,
      int32_t max_enqueued_batches,
      const std::vector<int32>& allowed_batch_sizes,
      FunctionLibraryRuntime::Handle fhandle, FunctionLibraryRuntime* flib,
      std::unique_ptr<BatchResource>* resource) {
    std::shared_ptr<AdaptiveBatcherT> batcher;
    TF_RETURN_IF_ERROR(AdaptiveBatcherT::Create(
        adaptive_shared_batch_scheduler_options, &batcher));

    resource->reset(new BatchResource(
        fhandle, flib, std::move(batcher),
        GetAdaptiveBatcherQueueOptions(
            max_batch_size, batch_timeout_micros, max_enqueued_batches,
            true /* enable large batch split */, allowed_batch_sizes),
        allowed_batch_sizes));
    return OkStatus();
  }

  string DebugString() const final { return "BatchResource"; }

 private:
  BatchResource(FunctionLibraryRuntime::Handle fhandle,
                FunctionLibraryRuntime* flib, std::shared_ptr<BatcherT> batcher,
                const BatcherT::QueueOptions& batcher_queue_options,
                std::vector<int32> allowed_batch_sizes)
      : BatchResourceBase(
            /*has_process_batch_function=*/fhandle != kInvalidHandle,
            std::move(batcher), batcher_queue_options,
            std::move(allowed_batch_sizes)),
        fhandle_(fhandle),
        flib_(flib) {}

  BatchResource(FunctionLibraryRuntime::Handle fhandle,
                FunctionLibraryRuntime* flib,
                std::shared_ptr<AdaptiveBatcherT> batcher,
                const AdaptiveBatcherT::QueueOptions& batcher_queue_options,
                std::vector<int32> allowed_batch_sizes)
      : BatchResourceBase(
            /*has_process_batch_function=*/fhandle != kInvalidHandle,
            std::move(batcher), batcher_queue_options,
            std::move(allowed_batch_sizes)),
        fhandle_(fhandle),
        flib_(flib) {}

  void ProcessFuncBatchImpl(
      const BatchTask& last_task, absl::Span<const Tensor> inputs,
      std::vector<Tensor>* combined_outputs,
      std::function<void(const Status&)> done) const override {
    auto* last_task_context = last_task.context;
    FunctionLibraryRuntime::Options opts;
    opts.step_container = last_task_context->step_container();
    opts.cancellation_manager = last_task_context->cancellation_manager();
    opts.collective_executor = last_task_context->collective_executor();
    opts.stats_collector = last_task_context->stats_collector();
    opts.runner = last_task_context->runner();
    opts.run_all_kernels_inline = last_task_context->run_all_kernels_inline();
    // We do not set 'opts.rendezvous', since if the function is run multiple
    // times in parallel with the same rendezvous, a _Send node from one run
    // might be matched with a _Recv node of a different run. Not setting the
    // rendezvous causes a new rendezvous to be used for each run.
    Notification done_notif;

    flib_->Run(opts, fhandle_, inputs, combined_outputs,
               [&](const Status& run_status) {
                 done(run_status);
                 done_notif.Notify();
               });
    // By waiting for the notification we are ensuring that this thread isn't
    // used for processing other batches, which gives the batches time to
    // coalesce upstream. So overall the number of batches going through the
    // devices goes down, improving latency and throughput in most cases.
    done_notif.WaitForNotification();
  }

  FunctionLibraryRuntime::Handle fhandle_;
  FunctionLibraryRuntime* flib_;
};

BatchFunctionKernel::BatchFunctionKernel(OpKernelConstruction* c)
    : AsyncOpKernel(c) {
  OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
  OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
  OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
  OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
  OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
  OP_REQUIRES_OK(c, c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
  OP_REQUIRES_OK(c, c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
  OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));

  OP_REQUIRES_OK(c, c->GetAttr("f", &func_));
  flib_ = c->function_library();

  if (c->HasAttr("enable_large_batch_splitting")) {
    OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting",
                                 &enable_large_batch_splitting_));
    has_attribute_enable_large_batch_splitting_ = true;
  } else {
    enable_large_batch_splitting_ = false;
    has_attribute_enable_large_batch_splitting_ = false;
  }

  // The helper function `SetAdaptiveBatchSchedulerOptions` calls
  // `OP_REQUIRES_OK`, which only exits the helper itself upon error, so
  // validate the status of the op-kernel construction explicitly here.
  SetAdaptiveBatchSchedulerOptions(c, num_batch_threads_);
  if (!c->status().ok()) {
    return;
  }

  if (enable_adaptive_batch_threads_) {
    // One scheduler instance contains several queue instances;
    // `batcher_queue_` is the key used to find the queue for this batch op in
    // the graph.
    // Use `shared_name_` and name() as the prefix for `batcher_queue_`.
    // Note that name() is unique per session (from session metadata).
    batcher_queue_ = name() + "/" + shared_name_ + batcher_queue_;
  }

  if (shared_name_.empty()) {
    // If shared_name is not supplied, use name instead (prevents collisions by
    // default).
    shared_name_ = name();
  }

  OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
}

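// The kernel itself only enqueues work onto the shared batch scheduler, so it
// is reported to the executor as inexpensive.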
bool BatchFunctionKernel::IsExpensive() { return false; }

void BatchFunctionKernel::ComputeAsync(OpKernelContext* c, DoneCallback done) {
  RecordBatchSplitUsage(has_attribute_enable_large_batch_splitting_
                            ? absl::make_optional(enable_large_batch_splitting_)
                            : absl::nullopt,
                        GetModelName(c));
  // TODO(b/173255290): Add num_batch_threads_ parameter to TFRT batch kernel.
  RecordBatchParamNumBatchThreads(num_batch_threads_, GetModelName(c));

  std::function<Status(BatchResource**)> creator;

  FunctionLibraryRuntime::Handle handle;
  OP_REQUIRES_OK_ASYNC(c, GetOrCreateFunctionHandle(c, &handle), done);

  if (adaptive_batch_scheduler_options_ != absl::nullopt) {
    creator = [this, handle](BatchResource** r) {
      serving::AdaptiveSharedBatchScheduler<
          serving::BatchResourceBase::BatchTask>::Options
          adaptive_shared_batch_scheduler_options;
      adaptive_shared_batch_scheduler_options.thread_pool_name =
          "adaptive_batch_threads";
      adaptive_shared_batch_scheduler_options.num_batch_threads =
          adaptive_batch_scheduler_options_->max_in_flight_batches_limit;
      adaptive_shared_batch_scheduler_options.thread_pool =
          GetOrCreateBatchThreadsPool();
      // adaptive_shared_batch_scheduler_options.full_batch_scheduling_boost_micros
      // is intentionally left at 0 (the default value), so tasks are scheduled
      // in a FIFO way.
      // Two rationales for using the default value (zero) for
      // `full_batch_scheduling_boost_micros`:
      // 1) This makes the task scheduling policy FIFO. Compared with round
      // robin (what the shared batch scheduler does), FIFO ensures that models
      // with low QPS (i.e., models that enqueue fewer tasks in the shared
      // queue) are processed in a timely manner.
      // 2) If set, `full_batch_scheduling_boost_micros` should be on the order
      // of the batch processing latency (which varies on a per-model basis).
      // A non-zero value that is not set properly harms tail latency.
      adaptive_shared_batch_scheduler_options.min_in_flight_batches_limit =
          adaptive_batch_scheduler_options_->min_in_flight_batches_limit;
      adaptive_shared_batch_scheduler_options.initial_in_flight_batches_limit =
          adaptive_batch_scheduler_options_->initial_in_flight_batches_limit;
      adaptive_shared_batch_scheduler_options.batches_to_average_over =
          adaptive_batch_scheduler_options_->batches_to_average_over;
      adaptive_shared_batch_scheduler_options.fifo_scheduling = true;
      std::unique_ptr<BatchResource> new_resource;
      TF_RETURN_IF_ERROR(BatchResource::Create(
          adaptive_shared_batch_scheduler_options, max_batch_size_,
          batch_timeout_micros_, max_enqueued_batches_, allowed_batch_sizes_,
          handle, flib_, &new_resource));
      *r = new_resource.release();
      return OkStatus();
    };
  } else {
    creator = [this, handle](BatchResource** r) {
      std::unique_ptr<BatchResource> new_resource;
      TF_RETURN_IF_ERROR(BatchResource::Create(
          num_batch_threads_, max_batch_size_, batch_timeout_micros_,
          max_enqueued_batches_, allowed_batch_sizes_, handle, flib_,
          enable_large_batch_splitting_, &new_resource));
      *r = new_resource.release();
      return OkStatus();
    };
  }

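  // LookupOrCreate only invokes `creator` the first time this
  // (container_, shared_name_) pair is seen; later executions of the op share
  // the same BatchResource and therefore the same batcher queues.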
  BatchResource* br;
  OP_REQUIRES_OK_ASYNC(c,
                       c->resource_manager()->LookupOrCreate(
                           container_, shared_name_, &br, creator),
                       done);
  const Status status =
      br->RegisterInput(random::New64(), c, batcher_queue_, done);
  br->Unref();
  OP_REQUIRES_OK_ASYNC(c, status, done);
  // Assume br calls done, so nothing to do here.
}

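// Instantiates `func_` as a multi-device function. Non-resource inputs and all
// outputs are pinned to the host CPU because batching concatenates inputs and
// splits outputs there; resource captures keep the device recorded in their
// handle.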
Status BatchFunctionKernel::InstantiateFunction(
    OpKernelContext* c, FunctionLibraryRuntime::Handle* handle) const {
  // TODO(b/173748062): Merge this instantiation logic with PartitionedCall.
  if (!flib_) {
    return errors::Internal("No function library");
  }

  FunctionLibraryRuntime::InstantiateOptions opts;
  opts.target = flib_->device() == nullptr ? "" : flib_->device()->name();
  opts.is_multi_device_function = true;
  const ConfigProto* config = flib_->config_proto();
  if (config) {
    opts.config_proto = *config;
  }

  Device* cpu_device;
  TF_RETURN_IF_ERROR(flib_->device_mgr()->LookupDevice("CPU:0", &cpu_device));

  const FunctionDef* fdef =
      flib_->GetFunctionLibraryDefinition()->Find(func_.name());
  if (!fdef) {
    return errors::NotFound("Failed to find definition for function \"",
                            func_.name(), "\"");
  }
  OpInputList in_tensors;
  TF_RETURN_IF_ERROR(c->input_list("in_tensors", &in_tensors));
  for (int i = 0; i < in_tensors.size(); i++) {
    if (in_tensors[i].dtype() == DT_RESOURCE) {
      return errors::InvalidArgument(
          "BatchFunction cannot take resource inputs but input ", i,
          " is a resource.");
    } else {
      // Currently, inputs are on CPU since they are concatenated on CPU.
      opts.input_devices.push_back(cpu_device->name());
    }
  }
  OpInputList captured_tensors;
  TF_RETURN_IF_ERROR(c->input_list("captured_tensors", &captured_tensors));
  for (const Tensor& t : captured_tensors) {
    if (t.dtype() == DT_RESOURCE) {
      const ResourceHandle& rhandle = t.flat<ResourceHandle>()(0);
      opts.input_devices.push_back(rhandle.device());
    } else {
      opts.input_devices.push_back(cpu_device->name());
    }
  }
  const OpDef& signature = fdef->signature();
  for (int i = 0; i < signature.output_arg_size(); i++) {
    // Currently, outputs must be on CPU since they are split on CPU.
    opts.output_devices.push_back(cpu_device->name());
  }
  if (opts.input_devices.size() != signature.input_arg_size()) {
    return errors::InvalidArgument(
        "Function takes ", signature.input_arg_size(), " argument(s) but ",
        opts.input_devices.size(), " argument(s) were passed");
  }
  return flib_->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts,
                            handle);
}

Status BatchFunctionKernel::GetOrCreateFunctionHandle(
    OpKernelContext* c, FunctionLibraryRuntime::Handle* handle) {
  mutex_lock ml(mu_);
  if (!fhandle_) {
    TF_RETURN_IF_ERROR(InstantiateFunction(c, handle));
    fhandle_ = *handle;
  } else {
    *handle = fhandle_.value();
  }
  return OkStatus();
}

// Validates 'allowed_batch_sizes_'. The entries must increase monotonically.
// If large batch split is not enabled, the last entry must equal
// `max_batch_size_`; otherwise the last entry must be smaller than or equal
// to `max_batch_size_`.
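// For example, with max_batch_size_ = 16 and large batch splitting disabled,
// allowed_batch_sizes_ = {2, 4, 8, 16} is accepted, while {2, 4, 8} is
// rejected because the final entry does not equal max_batch_size_, and
// {2, 8, 4, 16} is rejected because it is not monotonically increasing.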
Status BatchFunctionKernel::ValidateAllowedBatchSizes() const {
  if (allowed_batch_sizes_.empty()) {
    return OkStatus();
  }
  int32_t last_size = 0;
  for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
    const int32_t size = allowed_batch_sizes_.at(i);
    if (i > 0 && size <= last_size) {
      return errors::InvalidArgument(
          "allowed_batch_sizes entries must be monotonically increasing");
    }

    if ((!enable_large_batch_splitting_) &&
        (i == allowed_batch_sizes_.size() - 1) && (size != max_batch_size_)) {
      return errors::InvalidArgument(
          "final entry in allowed_batch_sizes must equal max_batch_size when "
          "enable_large_batch_splitting is False");
    }

    last_size = size;
  }
  return OkStatus();
}

// Initializes member variables from the op-kernel construction context.
// Variables set:
// - enable_adaptive_batch_threads_
//   True if the attribute `kEnableAdaptiveSchedulerAttr` is true, or if
//   `num_batch_threads` is not positive.
// - adaptive_batch_scheduler_options_
//   Read from the corresponding attributes as long as they are set.
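// For example, a BatchFunction node with num_batch_threads = 0 (or with
// `_enable_adaptive_scheduler` set to true) uses adaptive scheduling;
// attributes such as `_min_inflight_batches`, `_initial_inflight_batches`,
// `_max_inflight_batches` and `_batches_to_average_over`, when present,
// populate the corresponding fields of AdaptiveBatchSchedulerOptions.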
void BatchFunctionKernel::SetAdaptiveBatchSchedulerOptions(
    OpKernelConstruction* c, int32_t num_batch_threads) {
  if (c->HasAttr(kEnableAdaptiveSchedulerAttr)) {
    OP_REQUIRES_OK(c, c->GetAttr(kEnableAdaptiveSchedulerAttr,
                                 &enable_adaptive_batch_threads_));
  }

  if (num_batch_threads <= 0) {
    enable_adaptive_batch_threads_ = true;
  }

  if (!enable_adaptive_batch_threads_) {
    // adaptive_batch_scheduler_options_ is nullopt.
    return;
  }

  // adaptive_batch_scheduler_options_ is not nullopt.
  AdaptiveBatchSchedulerOptions options;

  if (c->HasAttr(kBatchesToAverageOverAttr)) {
    OP_REQUIRES_OK(c, c->GetAttr(kBatchesToAverageOverAttr,
                                 &options.batches_to_average_over));
  }

  if (c->HasAttr(kMinInflightBatchesAttr)) {
    OP_REQUIRES_OK(c, c->GetAttr(kMinInflightBatchesAttr,
                                 &options.min_in_flight_batches_limit));
  }

  if (c->HasAttr(kInitialInflightBatchesAttr)) {
    OP_REQUIRES_OK(c, c->GetAttr(kInitialInflightBatchesAttr,
                                 &options.initial_in_flight_batches_limit));
  }

  if (c->HasAttr(kMaxInflightBatchesAttr)) {
    OP_REQUIRES_OK(c, c->GetAttr(kMaxInflightBatchesAttr,
                                 &options.max_in_flight_batches_limit));
  }

  // At this point, the batch kernel is configured to use adaptive scheduling.
  // To validate (or report an error) at kernel construction time, invoke
  // `GetOrCreateBatchThreadsPool` and check that the returned `thread_pool`
  // is valid.
  // Note that `GetOrCreateBatchThreadsPool` creates the thread pool once and
  // reuses the thread-pool instance afterwards.
  thread::ThreadPool* thread_pool = GetOrCreateBatchThreadsPool();
  OP_REQUIRES(
      c, thread_pool != nullptr,
      errors::FailedPrecondition("Failed to create batch threads pool"));

  adaptive_batch_scheduler_options_ = options;
}
REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
                        BatchFunctionKernel);
// Currently all inputs and outputs are on the host.
// TODO(b/173748277): Accept inputs/outputs on the device.
REGISTER_KERNEL_BUILDER(Name("BatchFunction")
                            .Device(DEVICE_GPU)
                            .HostMemory("in_tensors")
                            .HostMemory("captured_tensors")
                            .HostMemory("out_tensors"),
                        BatchFunctionKernel);
REGISTER_KERNEL_BUILDER(Name("BatchFunction")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("in_tensors")
                            .HostMemory("captured_tensors")
                            .HostMemory("out_tensors"),
                        BatchFunctionKernel);

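// The plain `Batch` op: it batches its inputs through the same BatchResource
// machinery but without a captured function (kInvalidHandle and a null
// FunctionLibraryRuntime), so batched tensors are emitted directly as outputs
// together with a batch index, typically consumed later by an Unbatch op.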
class BatchKernel : public AsyncOpKernel {
 public:
  explicit BatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
    OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
    OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
    OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
    OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
    OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    BatchResource* br;
    std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
      std::unique_ptr<BatchResource> new_resource;
      TF_RETURN_IF_ERROR(BatchResource::Create(
          num_batch_threads_, max_batch_size_, batch_timeout_micros_,
          max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle,
          /*flib=*/nullptr, false, &new_resource));
      *r = new_resource.release();
      return OkStatus();
    };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &br, creator),
                         done);
    const Status status =
        br->RegisterInput(random::New64(), c, batcher_queue_, done);
    br->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume br calls done, so nothing to do here.
  }

  // Validates 'allowed_batch_sizes_'. The entries must increase
  // monotonically, and the last one must equal 'max_batch_size_'.
  Status ValidateAllowedBatchSizes() const {
    if (allowed_batch_sizes_.empty()) {
      return OkStatus();
    }
    int32_t last_size = 0;
    for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
      const int32_t size = allowed_batch_sizes_.at(i);
      if (i > 0 && size <= last_size) {
        return errors::InvalidArgument(
            "allowed_batch_sizes entries must be monotonically increasing");
      }
      if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) {
        return errors::InvalidArgument(
            "final entry in allowed_batch_sizes must equal max_batch_size");
      }
      last_size = size;
    }
    return OkStatus();
  }

 private:
  string container_;
  string shared_name_;
  string batcher_queue_;
  int32 num_batch_threads_;
  int32 max_batch_size_;
  int32 batch_timeout_micros_;
  int32 max_enqueued_batches_;
  std::vector<int32> allowed_batch_sizes_;
};

REGISTER_KERNEL_BUILDER(Name("Batch").Device(DEVICE_CPU), BatchKernel);

// A class encapsulating the state and logic for unbatching tensors.
//
// UnbatchResource keeps two data structures indexed by batch-key: one which has
// the continuations for all concurrent kernels which are waiting for tensors
// and another which has tensors which are waiting for their corresponding
// kernels to run. Whenever a kernel runs, we either grab its tensor if it's
// waiting already, or we insert it in the queue and then look at its tensor to
// see if it can be used to dispatch any stored continuations.
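// The index tensor produced by the batch op has shape [num_slices, 3]; each
// row holds (batch key, start offset, end offset), where the offsets index
// into dimension 0 of the batched tensor, so the slice size is end - start.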
class UnbatchResource : public ResourceBase {
 public:
  explicit UnbatchResource(int32_t timeout_micros)
      : timeout_micros_(timeout_micros),
        timeout_enforcer_(new serving::PeriodicFunction(
            [this] { EnforceTimeout(); }, 1000 /* 1 ms */)) {}

  ~UnbatchResource() override {
    // Tear down 'timeout_enforcer_' first, since it accesses other state in
    // this class.
    timeout_enforcer_ = nullptr;
  }

  string DebugString() const final { return "UnbatchResource"; }

  Status Compute(OpKernelContext* context, AsyncOpKernel::DoneCallback done) {
    const Tensor& data_t = context->input(0);
    const Tensor& batch_index_t = context->input(1);

    if (batch_index_t.shape().dim_size(0) > data_t.shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Wrong shape for index tensor. Expected 0th dimension size to be no "
          "greater than ",
          data_t.shape().dim_size(0),
          "; Got: ", batch_index_t.shape().dim_size(0), ".");
    }
    if (batch_index_t.shape().dim_size(1) != 3) {
      return errors::InvalidArgument(
          "Wrong shape for index tensor. Expected 1st dimension size to be 3 ; "
          "Got: ",
          batch_index_t.shape().dim_size(1), ".");
    }

    if (!TensorShapeUtils::IsScalar(context->input(2).shape())) {
      return errors::InvalidArgument(
          "Input id should be scalar; "
          "Got: ",
          context->input(2).DebugString(), ".");
    }
    const int64_t batch_key = context->input(2).scalar<int64_t>()();
    const bool nonempty_input = batch_index_t.dim_size(0) > 0;

    // If we have a non-empty tensor, slice it up.
    // (It is important to do this outside of the critical section below.)
    // The following variables are populated iff 'nonempty_input==true'.
    std::vector<int64_t> sizes;
    std::vector<int64_t> batch_keys;
    std::vector<Tensor> split_inputs;
    if (nonempty_input) {
      auto batch_indices =
          batch_index_t.shaped<int64_t, 2>({batch_index_t.dim_size(0), 3});
      for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
        sizes.push_back(batch_indices(i, 2) - batch_indices(i, 1));
        batch_keys.push_back(batch_indices(i, 0));
      }

      TF_RETURN_IF_ERROR(Split(context, data_t, sizes, &split_inputs));
    }

    // Critical section.
    std::vector<AsyncOpKernel::DoneCallback> done_callbacks_to_call;
    Status status = [&]() -> Status {
      mutex_lock ml(mu_);

      // Check to see whether the tensor we want is already ready.
      auto tensor_it = waiting_tensors_.find(batch_key);
      if (tensor_it != waiting_tensors_.end()) {
        context->set_output(0, tensor_it->second.tensor);
        waiting_tensors_.erase(tensor_it);
        done_callbacks_to_call.push_back(done);
        return OkStatus();
      }

      const uint64 deadline_micros =
          Env::Default()->NowMicros() + timeout_micros_;

      // Add ourselves to the waitlist for tensors.
      if (!waiting_callbacks_
               .emplace(batch_key,
                        WaitingCallback{deadline_micros, context, done})
               .second) {
        return errors::AlreadyExists(
            "Multiple session runs with the same batch key.");
      }

      // If we have a non-empty tensor, finish the waitlisted runs,
      // and store any remaining pieces.
      if (nonempty_input) {
        for (size_t i = 0; i < batch_keys.size(); ++i) {
          auto runs_it = waiting_callbacks_.find(batch_keys[i]);
          if (runs_it != waiting_callbacks_.end()) {
            runs_it->second.context->set_output(0, split_inputs[i]);
            done_callbacks_to_call.push_back(runs_it->second.done);
            waiting_callbacks_.erase(runs_it);
          } else {
            // Note: the deadline here is in case we are arriving late and the
            // kernel that should rendezvous with this tensor has already waited
            // and timed out.
            if (!waiting_tensors_
                     .emplace(batch_keys[i],
                              WaitingTensor{deadline_micros, split_inputs[i]})
                     .second) {
              return errors::AlreadyExists(
                  "Multiple tensors returned for same batch key.");
            }
          }
        }
      }

      return OkStatus();
    }();

    for (const AsyncOpKernel::DoneCallback& done_callback :
         done_callbacks_to_call) {
      done_callback();
    }

    return status;
  }

 private:
  // Evicts waiting tensors and callbacks that have exceeded their deadline.
  void EnforceTimeout() {
    const uint64 now = Env::Default()->NowMicros();
    std::vector<WaitingCallback> evicted_callbacks;

    {
      mutex_lock ml(mu_);

      for (auto it = waiting_tensors_.begin(); it != waiting_tensors_.end();) {
        const WaitingTensor& waiting_tensor = it->second;
        if (waiting_tensor.deadline_micros < now) {
          it = waiting_tensors_.erase(it);
        } else {
          ++it;
        }
      }

      for (auto it = waiting_callbacks_.begin();
           it != waiting_callbacks_.end();) {
        const WaitingCallback& waiting_callback = it->second;
        if (waiting_callback.deadline_micros < now) {
          evicted_callbacks.push_back(waiting_callback);
          it = waiting_callbacks_.erase(it);
        } else {
          ++it;
        }
      }
    }

    for (const WaitingCallback& evicted_callback : evicted_callbacks) {
      evicted_callback.context->CtxFailureWithWarning(errors::DeadlineExceeded(
          "Batched data did not arrive within timeout window."));
      evicted_callback.done();
    }
  }

  struct WaitingTensor {
    uint64 deadline_micros;
    Tensor tensor;
  };

  struct WaitingCallback {
    uint64 deadline_micros;
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done;
  };

  const int32 timeout_micros_;

  mutex mu_;

  // Maps keyed by BatchKey of tensors waiting for callbacks and callbacks
  // waiting for tensors.
  std::unordered_map<int64_t, WaitingTensor> waiting_tensors_
      TF_GUARDED_BY(mu_);
  std::unordered_map<int64_t, WaitingCallback> waiting_callbacks_
      TF_GUARDED_BY(mu_);

  // A thread that evicts waiting tensors and callbacks that have exceeded their
  // deadline.
  std::unique_ptr<serving::PeriodicFunction> timeout_enforcer_;
};

class UnbatchKernel : public AsyncOpKernel {
 public:
  explicit UnbatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
    OP_REQUIRES_OK(c, c->GetAttr("timeout_micros", &timeout_micros_));
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    UnbatchResource* ubr;
    std::function<Status(UnbatchResource**)> creator =
        [this](UnbatchResource** r) {
          *r = new UnbatchResource(timeout_micros_);
          return OkStatus();
        };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &ubr, creator),
                         done);
    auto status = ubr->Compute(c, done);
    ubr->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume ubr calls done, so nothing to do here.
  }

 private:
  string container_;
  string shared_name_;
  int32 timeout_micros_;
};
REGISTER_KERNEL_BUILDER(Name("Unbatch").Device(DEVICE_CPU), UnbatchKernel);

// A class encapsulating the state and logic for batching tensors
// deterministically for the gradient of unbatch.
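// Each Compute call records its incoming gradient slice under its batch key;
// once every key listed in the batch_index of the original batch has arrived,
// OutputBatch concatenates the slices in index order and completes the
// pending op kernel context.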
class UnbatchGradResource : public ResourceBase {
 public:
  UnbatchGradResource() {}

  string DebugString() const final { return "UnbatchGradResource"; }

  // Flushes the information for one batch, given its context and done
  // callback. Clears all information about it from the available_tensors_.
  Status OutputBatch(OpKernelContext* context,
                     const AsyncOpKernel::DoneCallback& done)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    const Tensor& batch_index_t = context->input(1);
    auto batch_index =
        batch_index_t.shaped<int64_t, 2>({batch_index_t.dim_size(0), 3});
    std::vector<Tensor> tensors;
    for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
      auto available_it = available_tensors_.find(batch_index(i, 0));
      if (available_it == available_tensors_.end()) {
        return errors::Internal("bad bookkeeping of available tensors.");
      }
      tensors.push_back(available_it->second);
      available_tensors_.erase(available_it);
    }

    const DataType type = tensors[0].dtype();
    Tensor concatenated_tensor;
    switch (type) {
#define CASE(type)                                                            \
  case DataTypeToEnum<type>::value:                                           \
    TF_RETURN_IF_ERROR(Concat<type>(context, tensors, &concatenated_tensor)); \
    context->set_output(0, concatenated_tensor);                              \
    break;
      TF_CALL_ALL_TYPES(CASE);
#undef CASE
      default:
        return errors::InvalidArgument("Unsupported data type: ", type);
    }
    done();
    return OkStatus();
  }

  // Ingests data from one invocation of the op.
  Status Compute(OpKernelContext* context,
                 const AsyncOpKernel::DoneCallback& done) {
    const Tensor& data_t = context->input(0);
    const Tensor& batch_index_t = context->input(1);
    const Tensor& grad_t = context->input(2);
    const Tensor& batch_key_t = context->input(3);

    mutex_lock ml(mu_);
    if (batch_key_t.NumElements() != 1) {
      return errors::InvalidArgument("Expected `id` to be scalar. Received ",
                                     batch_key_t.DebugString());
    }

    const int64_t batch_key = context->input(3).scalar<int64_t>()();
    // Mark our tensor as available.
    if (!available_tensors_.emplace(batch_key, grad_t).second) {
      return errors::InvalidArgument("Two runs with the same batch key.");
    }

    // Check whether we have a valid input tensor and, if so, create its
    // dispatch logic.
    if (data_t.NumElements() > 0) {
      if (batch_index_t.NumElements() == 0) {
        return errors::InvalidArgument(
            "batch_index is empty while the tensor isn't.");
      }
      std::unordered_set<int64_t> missing_tensors;
      if (batch_index_t.NumElements() != batch_index_t.dim_size(0) * 3) {
        return errors::InvalidArgument(
            "batch_index should contain ", batch_index_t.dim_size(0) * 3,
            " elements. Received ", batch_index_t.NumElements());
      }
      const auto batch_index =
          batch_index_t.shaped<int64_t, 2>({batch_index_t.dim_size(0), 3});
      for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
        const int64_t batch_key = batch_index(i, 0);
        if (available_tensors_.find(batch_key) == available_tensors_.end()) {
          missing_tensors.emplace(batch_key);
        }
      }
      if (missing_tensors.empty()) {
        return OutputBatch(context, done);
      }
      if (!available_batches_
               .emplace(batch_key, Batch{missing_tensors, context, done})
               .second) {
        return errors::InvalidArgument(
            "Batch key with valid batch used twice.");
      }
      for (const int64_t i : missing_tensors) {
        if (!desired_tensor_to_batch_map_.emplace(i, batch_key).second) {
          return errors::InvalidArgument(
              "Missing tensor wanted by more than one batch.");
        }
      }
    } else {
      // If we don't have a valid input tensor we can output an empty tensor and
      // call our done closure.
      TensorShape output_shape(grad_t.shape());
      output_shape.set_dim(0, 0);
      Tensor* output = nullptr;
      TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, &output));
      done();
    }

    // Search to see whether our tensor is desired by any existing batch.
    auto desire_it = desired_tensor_to_batch_map_.find(batch_key);
    if (desire_it != desired_tensor_to_batch_map_.end()) {
      // Mark our tensor as no longer missing.
      auto batch_it = available_batches_.find(desire_it->second);
      desired_tensor_to_batch_map_.erase(desire_it);
      if (batch_it == available_batches_.end()) {
        return errors::InvalidArgument("Batch no longer exists.");
      }
      batch_it->second.missing_tensors.erase(batch_key);
      // If all tensors are available we should concatenate them and dispatch
      // the batch.
      if (batch_it->second.missing_tensors.empty()) {
        TF_RETURN_IF_ERROR(
            OutputBatch(batch_it->second.context, batch_it->second.done));
        available_batches_.erase(batch_it);
      }
    }
    return OkStatus();
  }

 private:
  mutex mu_;

  // Represents a still-incomplete batch of tensors. When all tensors become
  // available they will be concatenated in the right order and sent through the
  // context.
  struct Batch {
    // Batch keys for tensors which are still missing from this batch. When this
    // is empty the Tensors can be concatenated and forwarded.
    std::unordered_set<int64_t> missing_tensors;

    // Context and callback for the session responsible for finishing this
    // batch.
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done;
  };

  // Map from batch key of the session which will output the batched gradients
  // to still-incomplete batches.
  std::unordered_map<int64_t, Batch> available_batches_;

  // Map from batch key to tensors which are waiting for their batches to be
  // available.
  std::unordered_map<int64_t, Tensor> available_tensors_;

  // Map from batch key of a tensor which is not yet available to the batch key
  // of the batch to which it belongs.
  std::unordered_map<int64_t, int64_t> desired_tensor_to_batch_map_;
};

class UnbatchGradKernel : public AsyncOpKernel {
 public:
  explicit UnbatchGradKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    UnbatchGradResource* ubr;
    std::function<Status(UnbatchGradResource**)> creator =
        [](UnbatchGradResource** r) {
          *r = new UnbatchGradResource();
          return OkStatus();
        };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &ubr, creator),
                         done);
    Status status = ubr->Compute(c, done);
    ubr->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume ubr calls done, so nothing to do here.
  }

 private:
  string container_;
  string shared_name_;
};
REGISTER_KERNEL_BUILDER(Name("UnbatchGrad").Device(DEVICE_CPU),
                        UnbatchGradKernel);

}  // namespace tensorflow