/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/cost_constants.h"
#include "tensorflow/core/common_runtime/cost_measurement.h"
#include "tensorflow/core/common_runtime/cost_measurement_registry.h"
#include "tensorflow/core/common_runtime/cost_util.h"
#include "tensorflow/core/common_runtime/request_cost_accessor.h"
#include "tensorflow/core/common_runtime/request_cost_accessor_registry.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/batching_util/concat_split_util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/percentile_sampler.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/util/incremental_barrier.h"

namespace tensorflow {
namespace serving {
namespace {

// TODO(b/181883417): Replace with RecordPaddingSizeV2.
void RecordPaddingSize(int32_t padding_size, const string& model_name,
                       int32_t execution_batch_size, const string& op_name) {
  static auto* cell = tensorflow::monitoring::PercentileSampler<3>::New(
      {"/tensorflow/serving/batching/padding_size",
       "Tracks the padding size distribution on batches by model_name (if "
       "available).",
       "model_name", "execution_batch_size", "op_name"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
  cell->GetCell(model_name, absl::StrCat(execution_batch_size), op_name)
      ->Add(static_cast<double>(padding_size));
}

void RecordPaddingSizeV2(int32_t padding_size, const string& model_name,
                         int32_t execution_batch_size, const string& op_name) {
  static auto* cell = tensorflow::monitoring::Sampler<3>::New(
      {"/tensorflow/serving/batching/padding_size_v2",
       "Tracks the padding size distribution on batches by model_name (if "
       "available).",
       "model_name", "execution_batch_size", "op_name"},
      // It's 14 buckets with the last bucket being 2^13 to DBL_MAX;
      // so the limits are [1, 2, 4, 8, ..., 8 * 1024, DBL_MAX].
      monitoring::Buckets::Exponential(1, 2, 14));
  cell->GetCell(model_name, absl::StrCat(execution_batch_size), op_name)
      ->Add(static_cast<double>(padding_size));
}

// TODO(b/181883417): Replace with RecordInputBatchSizeV2.
void RecordInputBatchSize(int32_t batch_size, const string& model_name,
                          const string& op_name) {
  static auto* cell = tensorflow::monitoring::PercentileSampler<2>::New(
      {"/tensorflow/serving/batching/input_batch_size",
       "Tracks the batch size distribution on the inputs by model_name (if "
       "available).",
       "model_name", "op_name"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_size));
}

void RecordInputBatchSizeV2(int32_t batch_size, const string& model_name,
                            const string& op_name) {
  static auto* cell = tensorflow::monitoring::Sampler<2>::New(
      {"/tensorflow/serving/batching/input_batch_size_v2",
       "Tracks the batch size distribution on the inputs by model_name (if "
       "available).",
       "model_name", "op_name"},
      // It's 14 buckets with the last bucket being 2^13 to DBL_MAX;
      // so the limits are [1, 2, 4, 8, ..., 8 * 1024, DBL_MAX].
      monitoring::Buckets::Exponential(1, 2, 14));
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_size));
}

// Record the actual batch size without padding.
void RecordBatchSize(int32_t batch_size, const string& model_name,
                     const string& op_name) {
  static auto* cell = tensorflow::monitoring::Sampler<2>::New(
      {"/tensorflow/serving/batching/batch_size",
       "Tracks the batch size distribution on the batch result by model_name "
       "(if available).",
       "model_name", "op_name"},
      monitoring::Buckets::Exponential(1, 1.5, 20));
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_size));
}

void RecordProcessedBatchSize(int32_t batch_size, const string& model_name,
                              const string& op_name) {
  static auto* cell = tensorflow::monitoring::PercentileSampler<2>::New(
      {"/tensorflow/serving/batching/processed_batch_size",
       "Tracks the batch size distribution on processing by model_name (if "
       "available).",
       "model_name", "op_name"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_size));
}

// Export the exact number instead of the distribution of processed batch size.
void RecordProcessedBatchSizeV2(int32_t batch_size, const string& model_name,
                                const string& op_name) {
  static auto* cell = monitoring::Counter<3>::New(
      "/tensorflow/serving/batching/processed_batch_size_v2",
      "Tracks the batch size on processing by model_name and op name (if "
      "available).",
      "model_name", "op_name", "batch_size");
  cell->GetCell(model_name, op_name, std::to_string(batch_size))
      ->IncrementBy(1);
}

// TODO(b/181883417): Replace with RecordBatchDelayUsV2.
void RecordBatchDelayUs(int64_t batch_delay_us, const string& model_name,
                        const string& op_name, int32_t batch_size) {
  static auto* cell = monitoring::PercentileSampler<3>::New(
      {"/tensorflow/serving/batching/batch_delay_us",
       "Tracks the batching delay (in microseconds) for inputs by model_name "
       "(if available).",
       "model_name", "op_name", "processed_batch_size"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, monitoring::UnitOfMeasure::kTime);
  cell->GetCell(model_name, op_name, std::to_string(batch_size))
      ->Add(static_cast<double>(batch_delay_us));
}

void RecordBatchDelayUsV2(int64_t batch_delay_us, const string& model_name,
                          const string& op_name, int32_t batch_size) {
  static auto* cell = tensorflow::monitoring::Sampler<3>::New(
      {"/tensorflow/serving/batching/batch_delay_us_v2",
       "Tracks the batching delay (in microseconds) for inputs by model_name "
       "(if available).",
       "model_name", "op_name", "processed_batch_size"},
      // It's 27 buckets with the last bucket being 2^26 to DBL_MAX;
      // so the limits are [1, 2, 4, 8, ..., 64 * 1024 * 1024, DBL_MAX].
      monitoring::Buckets::Exponential(1, 2, 27));
  cell->GetCell(model_name, op_name, std::to_string(batch_size))
      ->Add(static_cast<double>(batch_delay_us));
}

void RecordBatchParamBatchTimeoutMicros(int64_t batch_timeout_micros,
                                        const string& model_name,
                                        const string& op_name) {
  static auto* cell = monitoring::Gauge<int64_t, 2>::New(
      "/tensorflow/serving/batching/batch_timeout_micros",
      "Tracks how long a request can wait before being processed by a batch.",
      "model_name", "op_name");
  cell->GetCell(model_name, op_name)->Set(batch_timeout_micros);
}

void RecordBatchParamMaxBatchSize(int64_t max_batch_size,
                                  const string& model_name,
                                  const string& op_name) {
  static auto* cell = monitoring::Gauge<int64_t, 2>::New(
      "/tensorflow/serving/batching/max_batch_size",
      "Tracks the maximum size of a batch.", "model_name", "op_name");
  cell->GetCell(model_name, op_name)->Set(max_batch_size);
}

void RecordBatchParamMaxEnqueuedBatches(int64_t max_enqueued_batches,
                                        const string& model_name,
                                        const string& op_name) {
  static auto* cell = monitoring::Gauge<int64_t, 2>::New(
      "/tensorflow/serving/batching/max_enqueued_batches",
      "Tracks the maximum number of enqueued batches.", "model_name",
      "op_name");
  cell->GetCell(model_name, op_name)->Set(max_enqueued_batches);
}

void RecordBatchParamAllowedBatchSizes(const string& allowed_batch_sizes,
                                       const string& model_name,
                                       const string& op_name) {
  static auto* cell = monitoring::Gauge<string, 2>::New(
      "/tensorflow/serving/batching/allowed_batch_sizes",
      "Tracks the sizes that are allowed to form a batch.", "model_name",
      "op_name");
  cell->GetCell(model_name, op_name)->Set(allowed_batch_sizes);
}
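
// Returns the model name recorded in the session metadata, or a placeholder
// string if the metadata or its name is unset.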
const string& GetModelName(OpKernelContext* ctx) {
  static string* kModelNameUnset = new string("model_name_unset");
  if (!ctx->session_metadata()) return *kModelNameUnset;
  if (ctx->session_metadata()->name().empty()) return *kModelNameUnset;
  return ctx->session_metadata()->name();
}

}  // namespace
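
// Creates a derived task for one split of this task. The split shares this
// task's context, captured inputs, output matrix and status, and is marked
// partial so that results are merged only after every split completes.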
std::unique_ptr<BatchResourceBase::BatchTask>
BatchResourceBase::BatchTask::CreateSplitTask(
    int split_index, AsyncOpKernel::DoneCallback done_callback) {
  std::unique_ptr<BatchTask> task = CreateDerivedTask();

  task->guid = this->guid;
  task->propagated_context = Context(ContextKind::kThread);
  task->inputs.reserve(this->inputs.size());
  task->captured_inputs = this->captured_inputs;
  task->context = this->context;
  task->done_callback = done_callback;
  task->split_index = split_index;
  task->output = this->output;
  task->status = this->status;
  task->is_partial = true;
  task->start_time = this->start_time;
  task->request_cost = this->request_cost;

  return task;
}

using ::tensorflow::concat_split_util::Concat;
using ::tensorflow::concat_split_util::Split;
using TensorMatrix = std::vector<std::vector<Tensor>>;
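
// Ingests one op invocation: validates the input tensors, exports the
// batching parameters as metrics, and schedules a new BatchTask on the
// batcher queue named `batcher_queue_name`. Empty inputs are answered
// immediately with empty outputs.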
Status BatchResourceBase::RegisterInput(
    int64_t guid, OpKernelContext* context, const string& batcher_queue_name,
    AsyncOpKernel::DoneCallback done_callback) {
  std::unique_ptr<BatchTask> batch_components;
  TF_RETURN_IF_ERROR(CreateBatchTask(context, &batch_components));
  batch_components->start_time = EnvTime::NowNanos();
  batch_components->guid = guid;
  batch_components->propagated_context = Context(ContextKind::kThread);
  OpInputList tensors;
  TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors));
  batch_components->inputs.reserve(tensors.size());
  for (const Tensor& tensor : tensors) {
    if (tensor.shape().dims() == 0) {
      return errors::InvalidArgument(
          "Batching input tensors must have at least one dimension");
    }
    if (tensors.size() >= 2 &&
        tensor.shape().dim_size(0) != tensors[0].shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Batching input tensors supplied in a given op invocation must "
          "have equal 0th-dimension size");
    }
    batch_components->inputs.push_back(tensor);
  }
  RecordInputBatchSize(tensors[0].shape().dim_size(0), GetModelName(context),
                       context->op_kernel().name());
  RecordInputBatchSizeV2(tensors[0].shape().dim_size(0), GetModelName(context),
                         context->op_kernel().name());
  if (batcher_) {
    RecordBatchParamBatchTimeoutMicros(
        batcher_queue_options_.batch_timeout_micros, GetModelName(context),
        context->op_kernel().name());
    RecordBatchParamMaxBatchSize(
        batcher_queue_options_.max_execution_batch_size, GetModelName(context),
        context->op_kernel().name());
    RecordBatchParamMaxEnqueuedBatches(
        batcher_queue_options_.max_enqueued_batches, GetModelName(context),
        context->op_kernel().name());
  } else if (adaptive_batcher_) {
    RecordBatchParamBatchTimeoutMicros(
        adaptive_batcher_queue_options_.batch_timeout_micros,
        GetModelName(context), context->op_kernel().name());
    RecordBatchParamMaxBatchSize(adaptive_batcher_queue_options_.max_batch_size,
                                 GetModelName(context),
                                 context->op_kernel().name());
    RecordBatchParamMaxEnqueuedBatches(
        adaptive_batcher_queue_options_.max_enqueued_batches,
        GetModelName(context), context->op_kernel().name());
  } else {
    return errors::Internal("No batcher defined.");
  }
  RecordBatchParamAllowedBatchSizes(allowed_batch_sizes_str_,
                                    GetModelName(context),
                                    context->op_kernel().name());

  // Degenerate case where the input is empty. Just return an empty tensor.
  if (tensors[0].shape().dim_size(0) == 0) {
    for (int i = 0; i < context->num_outputs(); i++) {
      Tensor* empty_output;
      AllocatorAttributes cpu_alloc;
      cpu_alloc.set_on_host(true);
      TF_RETURN_IF_ERROR(context->allocate_output(i, TensorShape({0}),
                                                  &empty_output, cpu_alloc));
    }
    done_callback();
    return OkStatus();
  }
  OpInputList captured_tensors;
  const auto captured_status =
      context->input_list("captured_tensors", &captured_tensors);
  if (captured_status.ok()) {
    batch_components->captured_inputs.reserve(captured_tensors.size());
    for (const Tensor& captured_tensor : captured_tensors) {
      batch_components->captured_inputs.push_back(captured_tensor);
    }
  }
  batch_components->context = context;
  batch_components->done_callback = std::move(done_callback);
  batch_components->split_index = 0;
  batch_components->output = std::make_shared<TensorMatrix>();
  batch_components->status = std::make_shared<ThreadSafeStatus>();

  std::unique_ptr<RequestCostAccessor> request_cost_accessor =
      CreateRequestCostAccessor();
  if (request_cost_accessor) {
    batch_components->request_cost = request_cost_accessor->GetRequestCost();
  }

  BatcherQueueT* batcher_queue;
  TF_RETURN_IF_ERROR(
      LookupOrCreateBatcherQueue(batcher_queue_name, &batcher_queue));
  return batcher_queue->Schedule(&batch_components);
}
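
// Translates the op's batching attributes into shared-batcher queue options.
// With large batch splitting enabled, the execution batch size is capped at
// the largest allowed batch size (or at `max_batch_size` if no allowed sizes
// are given).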
/*static*/ BatchResourceBase::BatcherT::QueueOptions
BatchResourceBase::GetBatcherQueueOptions(
    int32_t num_batch_threads, int32_t max_batch_size,
    int32_t batch_timeout_micros, int32_t max_enqueued_batches,
    const std::vector<int32>& allowed_batch_sizes,
    bool enable_large_batch_splitting) {
  BatcherT::QueueOptions batcher_queue_options;
  batcher_queue_options.input_batch_size_limit = max_batch_size;
  batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
  batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
  batcher_queue_options.enable_large_batch_splitting =
      enable_large_batch_splitting;
  if (enable_large_batch_splitting) {
    batcher_queue_options.split_input_task_func =
        [](std::unique_ptr<BatchTask>* input_task,
           int open_batch_remaining_slot, int max_batch_size,
           std::vector<std::unique_ptr<BatchTask>>* output_tasks) -> Status {
      return SplitInputTask(input_task, open_batch_remaining_slot,
                            max_batch_size, output_tasks);
    };

    if (allowed_batch_sizes.empty()) {
      batcher_queue_options.max_execution_batch_size = max_batch_size;
    } else {
      batcher_queue_options.max_execution_batch_size =
          *allowed_batch_sizes.rbegin();
    }
  }

  return batcher_queue_options;
}
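
// Translates the op's batching attributes into adaptive-batcher queue
// options.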
/*static*/ BatchResourceBase::AdaptiveBatcherT::QueueOptions
BatchResourceBase::GetAdaptiveBatcherQueueOptions(
    int32_t max_batch_size, int32_t batch_timeout_micros,
    int32_t max_enqueued_batches, bool enable_large_batch_splitting,
    const std::vector<int32>& allowed_batch_sizes) {
  AdaptiveBatcherT::QueueOptions batcher_queue_options;
  batcher_queue_options.max_input_task_size =
      absl::make_optional(max_batch_size);
  batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
  batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
  if (allowed_batch_sizes.empty()) {
    batcher_queue_options.max_batch_size = max_batch_size;
  } else {
    batcher_queue_options.max_batch_size = *allowed_batch_sizes.rbegin();
  }

  if (enable_large_batch_splitting) {
    batcher_queue_options.split_input_task_func =
        [](std::unique_ptr<BatchTask>* input_task,
           int open_batch_remaining_slot, int max_batch_size,
           std::vector<std::unique_ptr<BatchTask>>* output_tasks) -> Status {
      return SplitInputTask(input_task, open_batch_remaining_slot,
                            max_batch_size, output_tasks);
    };
  }

  return batcher_queue_options;
}
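
// Checks that every task in the batch has the same number of input edges.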
/*static*/ Status BatchResourceBase::ValidateBatch(const BatchT& batch) {
  for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
    const BatchResourceBase::BatchTask& task = batch.task(task_idx);

    if (task.inputs.size() != batch.task(0).inputs.size()) {
      return errors::InvalidArgument(
          "Batching inputs must have equal number of edges");
    }
  }

  return OkStatus();
}

// Returns the smallest entry in 'allowed_batch_sizes_' that is greater than
// or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply
// returns 'batch_size'.
int BatchResourceBase::RoundToLowestAllowedBatchSize(int batch_size) const {
  if (allowed_batch_sizes_.empty()) {
    return batch_size;
  }
  for (int allowed_size : allowed_batch_sizes_) {
    if (allowed_size >= batch_size) {
      return allowed_size;
    }
  }
  LOG(ERROR) << "Batch size " << batch_size
             << " is greater than largest allowed size; "
                "ignoring allowed sizes constraint.";
  return batch_size;
}
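
// Concatenates the tasks' input tensors (plus any padding rows needed to
// reach an allowed batch size) into one tensor per input edge, and records
// the padding and batch-size metrics.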
Status BatchResourceBase::ConcatInputTensors(
    const BatchT& batch, OpKernelContext* context,
    std::vector<Tensor>* concatenated_tensors) const {
  if (batch.num_tasks() == 0) {
    return errors::InvalidArgument("Empty batch.");
  }

  const int padded_batch_size = RoundToLowestAllowedBatchSize(batch.size());
  const int padding_amount = padded_batch_size - batch.size();
  profiler::TraceMe trace_me([padded_batch_size, padding_amount]() {
    return profiler::TraceMeEncode(
        "ConcatInputTensors", {{"batch_size_after_padding", padded_batch_size},
                               {"padding_amount", padding_amount}});
  });
  RecordPaddingSize(padding_amount, GetModelName(context), padded_batch_size,
                    context->op_kernel().name());
  RecordPaddingSizeV2(padding_amount, GetModelName(context), padded_batch_size,
                      context->op_kernel().name());
  RecordProcessedBatchSize(padded_batch_size, GetModelName(context),
                           context->op_kernel().name());
  RecordProcessedBatchSizeV2(padded_batch_size, GetModelName(context),
                             context->op_kernel().name());
  RecordBatchSize(batch.size(), GetModelName(context),
                  context->op_kernel().name());

  // All tasks should have the same number of input edges.
  const int num_inputs = batch.task(0).inputs.size();
  concatenated_tensors->reserve(num_inputs);

  // Process each input one at a time (the typical case has just one).
  for (int i = 0; i < num_inputs; ++i) {
    // Concatenate the tasks' ith input tensors into one big output tensor.
    std::vector<Tensor> to_concatenate;
    to_concatenate.reserve(batch.num_tasks());
    for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
      to_concatenate.push_back(batch.task(task_idx).inputs.at(i));
    }

    // Add padding as needed. Use the first row of the first task's tensor as
    // the data for padding.
    if (padding_amount > 0) {
      const Tensor& padding_source = batch.task(0).inputs.at(i);
      Tensor padding;
      if (padding_source.shape().dim_size(0) == 0) {
        return errors::InvalidArgument(
            "Cannot use an empty tensor with zero rows as padding when "
            "batching. (Input ",
            i, " got shape ", padding_source.shape().DebugString(), ".)");
      }
      if (padding_source.shape().dim_size(0) == 1) {
        padding = padding_source;
      } else {
        padding = padding_source.Slice(0, 1);
      }
      for (int i = 0; i < padding_amount; ++i) {
        to_concatenate.push_back(padding);
      }
    }

    Tensor concatenated_tensor;
    Status concat_status =
        Concat(context, to_concatenate, &concatenated_tensor);
    TF_RETURN_IF_ERROR(concat_status);
    concatenated_tensors->push_back(concatenated_tensor);
  }
  return OkStatus();
}
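
// Splits `input_task` into smaller tasks whose sizes are chosen by
// InputSplitMetadata, and wires them to an IncrementalBarrier so that the
// original task's outputs are concatenated and its done callback fired only
// after every split has completed.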
/*static*/ Status BatchResourceBase::SplitInputTask(
    std::unique_ptr<BatchTask>* input_task_ptr, int open_batch_remaining_slot,
    int max_batch_size, std::vector<std::unique_ptr<BatchTask>>* output_tasks) {
  BatchTask& input_task = *(*input_task_ptr);
  const int64_t input_task_size = input_task.size();

  DCHECK_GT(input_task_size, 0);

  std::shared_ptr<ThreadSafeStatus> shared_status = input_task.status;

  // `split_task_done_callback` runs only after all split tasks are
  // complete.
  std::function<void()> split_task_done_callback =
      [done_callback = input_task.done_callback, output = input_task.output,
       op_kernel_context = input_task.context, status = shared_status]() {
        const int num_output = op_kernel_context->num_outputs();
        for (int i = 0; i < num_output; ++i) {
          Tensor output_tensor;

          // Concat would memcpy each input tensor to one output tensor.
          // In this context, Concat can be further optimized to get rid of
          // some (probably all) memcpy when input tensors are slices of
          // another copy.
          std::vector<Tensor> to_concatenate;
          to_concatenate.reserve(output->size());
          for (int j = 0; j < output->size(); ++j) {
            to_concatenate.push_back(std::move((*output)[j][i]));
          }
          const auto concat_status =
              Concat(op_kernel_context, to_concatenate, &output_tensor);
          if (!concat_status.ok()) {
            status->Update(concat_status);
          }

          op_kernel_context->set_output(i, std::move(output_tensor));
        }
        op_kernel_context->SetStatus(status->status());
        done_callback();
      };
  IncrementalBarrier barrier(split_task_done_callback);

  const internal::InputSplitMetadata input_split_metadata(
      input_task_size, open_batch_remaining_slot, max_batch_size);

  const absl::FixedArray<int>& task_sizes = input_split_metadata.task_sizes();
  const int num_batches = task_sizes.size();
  std::vector<int64_t> output_task_sizes;
  output_task_sizes.resize(num_batches);
  for (int i = 0; i < num_batches; i++) {
    output_task_sizes[i] = task_sizes[i];
  }

  input_task.output->resize(num_batches);
  for (int i = 0; i < num_batches; ++i) {
    (*input_task.output)[i].resize(input_task.context->num_outputs());
  }

  output_tasks->reserve(num_batches);
  for (int i = 0; i < num_batches; i++) {
    output_tasks->push_back(input_task.CreateSplitTask(i, barrier.Inc()));
  }

  const int num_input_tensors = input_task.inputs.size();

  // Splits each input tensor according to `output_task_sizes`, and
  // initializes input of `output_tasks` with split results.
  for (int i = 0; i < num_input_tensors; ++i) {
    std::vector<Tensor> split_tensors;
    const Tensor& input_tensor = input_task.inputs[i];
    // TODO(b/154140947):
    // Figure out the optimal implementation of Split, by using
    // 'Tensor::Slice' and eliminating unnecessary memcpy as much as possible.
    const Status split_status = Split(input_task.context, input_tensor,
                                      output_task_sizes, &split_tensors);
    if (!split_status.ok()) {
      return errors::Internal(
          "When splitting input, Tensor split operation failed: ",
          split_status.error_message());
    }
    if (split_tensors.size() != output_task_sizes.size()) {
      return errors::Internal(
          "When splitting input, tensor split operation did not work as "
          "expected; got ",
          split_tensors.size(), " splits; expected ", output_task_sizes.size());
    }
    for (int j = 0; j < output_tasks->size(); ++j) {
      BatchTask& output_task = *((*output_tasks)[j]);
      auto moved_tensor_iter = std::next(split_tensors.begin(), j);
      std::move(moved_tensor_iter, moved_tensor_iter + 1,
                std::back_inserter(output_task.inputs));
    }
  }
  return OkStatus();
}
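
// Splits the batched output tensors back into per-task results, dropping any
// rows that were added as padding, and routes each slice either to the task's
// own context or, for partial tasks, into the shared output matrix.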
Status BatchResourceBase::SplitOutputTensors(
    const std::vector<Tensor>& combined_outputs, BatchT* batch) const {
  DCHECK_GE(batch->num_tasks(), 1);
  if (batch->num_tasks() < 1) {
    return errors::Internal("Batch size expected to be positive; was ",
                            batch->num_tasks());
  }

  std::vector<int64_t> task_sizes_plus_optional_padding;
  task_sizes_plus_optional_padding.reserve(batch->num_tasks());
  for (int i = 0; i < batch->num_tasks(); ++i) {
    task_sizes_plus_optional_padding.push_back(batch->task(i).size());
  }
  const int padding_size =
      RoundToLowestAllowedBatchSize(batch->size()) - batch->size();
  if (padding_size > 0) {
    task_sizes_plus_optional_padding.push_back(padding_size);
  }

  // For each output tensor name, a divided-up tensor with one entry per task.
  std::map<string, std::vector<Tensor>> split_tensors;

  DCHECK_EQ(batch->task(0).context->num_outputs(), combined_outputs.size());
  int combined_outputs_size = combined_outputs.size();
  if (combined_outputs_size != batch->task(0).context->num_outputs()) {
    return errors::Internal("Wrong number of batched output tensors");
  }

  // Generate 'split_tensors' and populate the context outputs.
  for (int i = 0, iter_limit = combined_outputs.size(); i < iter_limit; ++i) {
    const Tensor& output_tensor = combined_outputs[i];
    if (output_tensor.shape().dims() == 0) {
      return errors::FailedPrecondition(
          "Batched output tensor has 0 dimensions");
    }
    if (output_tensor.shape().dim_size(0) !=
        static_cast<int64_t>(batch->size() + padding_size)) {
      return errors::FailedPrecondition(
          "Batched output tensor's 0th dimension does not equal the sum of "
          "the 0th dimension sizes of the input tensors");
    }

    std::vector<Tensor> split_tensor;
    const Status split_status = tensor::Split(
        output_tensor, task_sizes_plus_optional_padding, &split_tensor);
    DCHECK(split_status.ok()) << split_status.ToString();
    if (!split_status.ok()) {
      return errors::Internal("Tensor split operation failed: ",
                              split_status.error_message());
    }
    DCHECK_EQ(split_tensor.size(), task_sizes_plus_optional_padding.size());
    if (split_tensor.size() != task_sizes_plus_optional_padding.size()) {
      return errors::Internal(
          "Tensor split operation did not work as expected; got ",
          split_tensor.size(), " splits; expected ",
          task_sizes_plus_optional_padding.size());
    }

    // Ignore a possible final split_tensors entry containing the padding.
    for (int j = 0; j < batch->num_tasks(); ++j) {
      BatchTask& task = *(batch->mutable_task(j));
      if (task.is_partial) {
        std::vector<Tensor>& tensor_vector = (*task.output)[task.split_index];
        tensor_vector[i] = std::move(split_tensor[j]);
      } else {
        task.context->set_output(i, split_tensor[j]);
      }
    }
  }

  return OkStatus();
}
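
// Processes a batch by concatenating the task inputs, invoking the batch
// function via ProcessFuncBatchImpl, and splitting the combined outputs back
// to the individual tasks. Any failure is propagated to every task before its
// done callback runs.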
void BatchResourceBase::ProcessFuncBatch(std::unique_ptr<BatchT> batch) const {
  if (batch->empty()) {
    return;
  }

  // We use the 'propagated_context' from one of the threads that set up one
  // of the tasks. This will propagate any common context over all the threads
  // which are running this Session, of which this BatchOp is a part.
  WithContext wc(batch->task(batch->num_tasks() - 1).propagated_context);

  // TODO(b/185852990): Add a unit test to check the context is correctly set.
  // Creates the CostMeasurements within the same context that runs the Session.
  const CostMeasurement::Context batching_context{/*is_per_query=*/false};
  std::vector<std::unique_ptr<CostMeasurement>> batch_cost_measurements =
      CreateCostMeasurements(batching_context);

  auto& last_task = batch->task(batch->num_tasks() - 1);
  OpKernelContext* last_task_context = last_task.context;

  // Regardless of the outcome, we need to propagate the status to the
  // individual tasks and signal that they are done. We use MakeCleanup() to
  // ensure that this happens no matter how we exit the method below.
  Status status;
  bool cleanup_done = false;
  int64_t processed_size = batch->size();
  auto cleanup_fn = [&cleanup_done, &batch, &processed_size,
                     &batch_cost_measurements](const Status& status) {
    if (cleanup_done) {
      return;
    }
    SplitBatchCosts(batch_cost_measurements, processed_size, *batch);
    // Clear the measurements before unblocking the batch task, as measurements
    // are associated with the task's thread context.
    batch_cost_measurements.clear();
    for (int i = 0; i < batch->num_tasks(); ++i) {
      WithContext wc(batch->task(i).propagated_context);
      if (batch->task(i).is_partial) {
        batch->mutable_task(i)->status->Update(status);
      } else {
        batch->mutable_task(i)->context->SetStatus(status);
      }
      batch->mutable_task(i)->done_callback();
    }
    cleanup_done = true;
  };

  auto finally =
      gtl::MakeCleanup([&cleanup_fn, &status] { cleanup_fn(status); });

  status = ValidateBatch(*batch);
  if (!status.ok()) {
    return;
  }

  std::vector<Tensor> concatenated_tensors;
  status = ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
  processed_size = RoundToLowestAllowedBatchSize(batch->size());
  if (!status.ok()) {
    return;
  }

  std::vector<Tensor> combined_outputs;
  std::vector<Tensor> args(concatenated_tensors.begin(),
                           concatenated_tensors.end());
  const auto& captured_inputs =
      batch->task(batch->num_tasks() - 1).captured_inputs;
  args.insert(args.end(), captured_inputs.begin(), captured_inputs.end());

  uint64 current_time = EnvTime::NowNanos();
  const string& model_name = GetModelName(last_task_context);
  for (int i = 0; i < batch->num_tasks(); ++i) {
    RecordBatchDelayUs((current_time - batch->task(i).start_time) * 1e-3,
                       model_name, last_task_context->op_kernel().name(),
                       processed_size);
    RecordBatchDelayUsV2((current_time - batch->task(i).start_time) * 1e-3,
                         model_name, last_task_context->op_kernel().name(),
                         processed_size);
  }
  // Releases the cleanup method here, because the callback of the function
  // library runtime will handle it now.
  finally.release();
  ProcessFuncBatchImpl(
      last_task, args, &combined_outputs, [&](const Status& run_status) {
        Status final_status;
        auto run_finally = gtl::MakeCleanup([&]() {
          // We do the cleanup here as an optimization, so that
          // it runs in the underlying TF inter-op threadpool.
          // Running it in the threadpool lets the ensuing
          // ops be scheduled faster, because the executor will
          // add them to the front of the threadpool's task
          // queue rather than the end.
          cleanup_fn(final_status);
        });
        final_status = run_status;
        if (!final_status.ok()) {
          return;
        }
        final_status = SplitOutputTensors(combined_outputs, batch.get());
      });
}
// Processes a batch of one or more BatchTask entries.
void BatchResourceBase::ProcessBatch(std::unique_ptr<BatchT> batch) const {
  if (batch->empty()) {
    return;
  }

  // We use the 'propagated_context' from one of the threads that set up one
  // of the tasks. This will propagate any common context over all the threads
  // which are running this Session, of which this BatchOp is a part.
  WithContext wc(batch->task(batch->num_tasks() - 1).propagated_context);

  // TODO(b/185852990): Add a unit test to check the context is correctly set.
  // Creates the CostMeasurement within the same context that runs the Session.
  const CostMeasurement::Context batching_context{/*is_per_query=*/false};
  std::vector<std::unique_ptr<CostMeasurement>> batch_cost_measurements =
      CreateCostMeasurements(batching_context);

  int64_t processed_size = batch->size();
  auto batch_cost_split_cleanup = gtl::MakeCleanup([&] {
    SplitBatchCosts(batch_cost_measurements, processed_size, *batch);
  });

  OpKernelContext* last_task_context =
      batch->task(batch->num_tasks() - 1).context;
  AsyncOpKernel::DoneCallback last_task_callback =
      batch->task(batch->num_tasks() - 1).done_callback;

  OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch),
                       last_task_callback);

  // All tasks should have the same number of input edges.
  const int num_input_edges = batch->task(0).inputs.size();
  std::vector<Tensor> concatenated_tensors;
  const Status concat_status =
      ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
  processed_size = RoundToLowestAllowedBatchSize(batch->size());
  OP_REQUIRES_OK_ASYNC(last_task_context, concat_status, last_task_callback);

  // Process each input edge one at a time (the typical case has just one).
  for (int i = 0; i < num_input_edges; ++i) {
    last_task_context->set_output(i, concatenated_tensors[i]);

    // Emit batch->num_tasks() - 1 empty output tensors.
    for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
      const BatchTask& task = batch->task(task_idx);
      TensorShape output_shape(task.inputs[i].shape());
      output_shape.set_dim(0, 0);
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          task.context, task.context->allocate_output(i, output_shape, &output),
          task.done_callback);
    }
  }
  // Emit batch->num_tasks() - 1 empty index tensors.
  for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
    const BatchTask& task = batch->task(task_idx);
    TensorShape index_shape({0, 3});
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        task.context,
        task.context->allocate_output(num_input_edges, index_shape, &output),
        task.done_callback);
  }
  // Emit all ID tensors.
  for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
    const BatchTask& task = batch->task(task_idx);
    Tensor* id;
    OP_REQUIRES_OK_ASYNC(task.context,
                         task.context->allocate_output(num_input_edges + 1,
                                                       TensorShape({}), &id),
                         task.done_callback);
    id->scalar<int64_t>()() = task.guid;
  }
  OP_REQUIRES_OK_ASYNC(
      last_task_context,
      EmitIndexTensor(last_task_context, *batch, num_input_edges),
      last_task_callback);

  // Signal done for each element of the batch. (At this point, the contexts
  // are no longer guaranteed to remain live.)
  for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
    batch->mutable_task(task_idx)->done_callback();
  }
}
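
// Emits an index tensor with one (guid, start offset, end offset) row per
// task, recording where each task's slice lives in the concatenated batch.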
/*static*/ Status BatchResourceBase::EmitIndexTensor(OpKernelContext* context,
                                                     const BatchT& batch,
                                                     int output_index) {
  const TensorShape index_shape({batch.num_tasks(), 3});
  Tensor* index = nullptr;
  TF_RETURN_IF_ERROR(
      context->allocate_output(output_index, index_shape, &index));
  auto index_flat = index->shaped<int64_t, 2>({batch.num_tasks(), 3});
  size_t offset = 0;
  for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
    const BatchTask& task = batch.task(task_idx);
    index_flat(task_idx, 0) = task.guid;
    index_flat(task_idx, 1) = offset;
    index_flat(task_idx, 2) = offset + task.size();
    offset += task.size();
  }
  return OkStatus();
}
// Looks up the batcher queue for 'queue_name'. If it didn't previously exist,
// creates it.
Status BatchResourceBase::LookupOrCreateBatcherQueue(const string& queue_name,
                                                     BatcherQueueT** queue) {
  mutex_lock l(batcher_queues_mu_);

  auto it = batcher_queues_.find(queue_name);
  if (it != batcher_queues_.end()) {
    *queue = it->second.get();
    return OkStatus();
  }

  std::unique_ptr<BatcherQueueT> new_queue;
  auto process_batch_callback = [this](std::unique_ptr<BatchT> batch) {
    if (!has_process_batch_function_) {
      ProcessBatch(std::move(batch));
    } else {
      ProcessFuncBatch(std::move(batch));
    }
  };
  if (batcher_) {
    TF_RETURN_IF_ERROR(batcher_->AddQueue(batcher_queue_options_,
                                          process_batch_callback, &new_queue));
  } else if (adaptive_batcher_) {
    TF_RETURN_IF_ERROR(adaptive_batcher_->AddQueue(
        adaptive_batcher_queue_options_, process_batch_callback, &new_queue));
  } else {
    return errors::Internal("No batcher defined.");
  }
  *queue = new_queue.get();
  batcher_queues_[queue_name] = std::move(new_queue);
  return OkStatus();
}
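
// Default factory for the BatchTask created in RegisterInput.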
Status BatchResourceBase::CreateBatchTask(
    OpKernelContext* context,
    std::unique_ptr<BatchResourceBase::BatchTask>* output) const {
  *output = absl::make_unique<BatchResourceBase::BatchTask>();
  return OkStatus();
}
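
// Splits the costs collected by each CostMeasurement across the batch's
// tasks. The smeared cost divides the total by the unpadded batch size, so
// the padding cost is absorbed by the tasks; the non-smeared cost divides by
// the padded (processed) size, leaving the padding share unassigned.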
void BatchResourceBase::SplitBatchCosts(
    std::vector<std::unique_ptr<CostMeasurement>>& batch_cost_measurements,
    const int64_t processed_size, BatchT& batch) {
  for (auto& batch_cost_measurement : batch_cost_measurements) {
    if (batch_cost_measurement->GetTotalCost() <= absl::ZeroDuration()) {
      return;
    }
    if (batch.size() == 0) {  // NOLINT: empty() checks the batch contains 0
                              // tasks. size() gets the sum of task sizes.
      LOG_EVERY_N_SEC(ERROR, 60)
          << "Non-zero cost collected but the batch size is 0.";
      return;
    }
    if (processed_size == 0) {
      LOG_EVERY_N_SEC(ERROR, 60)
          << "Non-zero cost collected but the processed size is 0.";
      return;
    }
    const absl::string_view cost_type = batch_cost_measurement->GetCostType();
    const absl::Duration total_cost = batch_cost_measurement->GetTotalCost();

    for (int i = 0; i < batch.num_tasks(); i++) {
      RequestCost* request_cost = batch.task(i).request_cost;
      // Skip recording the cost if the request_cost is null.
      if (!request_cost) continue;

      // Smeared cost: the cost of padding is assigned to each task.
      const auto cost_with_smear =
          total_cost / batch.size() * batch.task(i).size();

      // Non-smeared cost: the cost of padding is not assigned to any task.
      const auto cost_no_smear =
          total_cost / processed_size * batch.task(i).size();

      request_cost->RecordCost(
          {{absl::StrCat(cost_type, kWithSmearSuffix), cost_with_smear},
           {absl::StrCat(cost_type, kNoSmearSuffix), cost_no_smear}});
    }
  }
}

}  // namespace serving
}  // namespace tensorflow