/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/cost_constants.h"
#include "tensorflow/core/common_runtime/cost_measurement.h"
#include "tensorflow/core/common_runtime/cost_measurement_registry.h"
#include "tensorflow/core/common_runtime/cost_util.h"
#include "tensorflow/core/common_runtime/request_cost_accessor.h"
#include "tensorflow/core/common_runtime/request_cost_accessor_registry.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/batching_util/concat_split_util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/percentile_sampler.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/util/incremental_barrier.h"

namespace tensorflow {
namespace serving {
namespace {

// TODO(b/181883417): Replace with RecordPaddingSizeV2.
void RecordPaddingSize(int32_t padding_size, const string& model_name,
                       int32_t execution_batch_size, const string& op_name) {
  static auto* cell = tensorflow::monitoring::PercentileSampler<3>::New(
      {"/tensorflow/serving/batching/padding_size",
       "Tracks the padding size distribution on batches by model_name (if "
       "available).",
       "model_name", "execution_batch_size", "op_name"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
  cell->GetCell(model_name, absl::StrCat(execution_batch_size), op_name)
      ->Add(static_cast<double>(padding_size));
}

void RecordPaddingSizeV2(int32_t padding_size, const string& model_name,
                         int32_t execution_batch_size, const string& op_name) {
  static auto* cell = tensorflow::monitoring::Sampler<3>::New(
      {"/tensorflow/serving/batching/padding_size_v2",
       "Tracks the padding size distribution on batches by model_name (if "
       "available).",
       "model_name", "execution_batch_size", "op_name"},
      // It's 14 buckets with the last bucket being 2^13 to DBL_MAX;
      // so the limits are [1, 2, 4, 8, ..., 8 * 1024, DBL_MAX].
      monitoring::Buckets::Exponential(1, 2, 14));
  cell->GetCell(model_name, absl::StrCat(execution_batch_size), op_name)
      ->Add(static_cast<double>(padding_size));
}

// TODO(b/181883417): Replace with RecordInputBatchSizeV2.
void RecordInputBatchSize(int32_t batch_size, const string& model_name,
                          const string& op_name) {
  static auto* cell = tensorflow::monitoring::PercentileSampler<2>::New(
      {"/tensorflow/serving/batching/input_batch_size",
       "Tracks the batch size distribution on the inputs by model_name (if "
       "available).",
       "model_name", "op_name"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_size));
}

void RecordInputBatchSizeV2(int32_t batch_size, const string& model_name,
                            const string& op_name) {
  static auto* cell = tensorflow::monitoring::Sampler<2>::New(
      {"/tensorflow/serving/batching/input_batch_size_v2",
       "Tracks the batch size distribution on the inputs by model_name (if "
       "available).",
       "model_name", "op_name"},
      // It's 14 buckets with the last bucket being 2^13 to DBL_MAX;
      // so the limits are [1, 2, 4, 8, ..., 8 * 1024, DBL_MAX].
      monitoring::Buckets::Exponential(1, 2, 14));
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_size));
}

// Record the actual batch size without padding.
void RecordBatchSize(int32_t batch_size, const string& model_name,
                     const string& op_name) {
  static auto* cell = tensorflow::monitoring::Sampler<2>::New(
      {"/tensorflow/serving/batching/batch_size",
       "Tracks the batch size distribution on the batch result by model_name "
       "(if available).",
       "model_name", "op_name"},
      monitoring::Buckets::Exponential(1, 1.5, 20));
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_size));
}

void RecordProcessedBatchSize(int32_t batch_size, const string& model_name,
                              const string& op_name) {
  static auto* cell = tensorflow::monitoring::PercentileSampler<2>::New(
      {"/tensorflow/serving/batching/processed_batch_size",
       "Tracks the batch size distribution on processing by model_name (if "
       "available).",
       "model_name", "op_name"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_size));
}

// Export the exact number instead of the distribution of processed batch size.
void RecordProcessedBatchSizeV2(int32_t batch_size, const string& model_name,
                                const string& op_name) {
  static auto* cell = monitoring::Counter<3>::New(
      "/tensorflow/serving/batching/processed_batch_size_v2",
      "Tracks the batch size on processing by model_name and op name (if "
      "available).",
      "model_name", "op_name", "batch_size");
  cell->GetCell(model_name, op_name, std::to_string(batch_size))
      ->IncrementBy(1);
}

// TODO(b/181883417): Replace with RecordBatchDelayUsV2.
void RecordBatchDelayUs(int64_t batch_delay_us, const string& model_name,
                        const string& op_name, int32_t batch_size) {
  static auto* cell = monitoring::PercentileSampler<3>::New(
      {"/tensorflow/serving/batching/batch_delay_us",
       "Tracks the batching delay (in microseconds) for inputs by model_name "
       "(if available).",
       "model_name", "op_name", "processed_batch_size"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, monitoring::UnitOfMeasure::kTime);
  cell->GetCell(model_name, op_name, std::to_string(batch_size))
      ->Add(static_cast<double>(batch_delay_us));
}

void RecordBatchDelayUsV2(int64_t batch_delay_us, const string& model_name,
                          const string& op_name, int32_t batch_size) {
  static auto* cell = tensorflow::monitoring::Sampler<3>::New(
      {"/tensorflow/serving/batching/batch_delay_us_v2",
       "Tracks the batching delay (in microseconds) for inputs by model_name "
       "(if available).",
       "model_name", "op_name", "processed_batch_size"},
      // It's 27 buckets with the last bucket being 2^26 to DBL_MAX;
      // so the limits are [1, 2, 4, 8, ..., 64 * 1024 * 1024, DBL_MAX].
      monitoring::Buckets::Exponential(1, 2, 27));
  cell->GetCell(model_name, op_name, std::to_string(batch_size))
      ->Add(static_cast<double>(batch_delay_us));
}

void RecordBatchParamBatchTimeoutMicros(int64_t batch_timeout_micros,
                                        const string& model_name,
                                        const string& op_name) {
  static auto* cell = monitoring::Gauge<int64_t, 2>::New(
      "/tensorflow/serving/batching/batch_timeout_micros",
      "Tracks how long a request can wait before being processed by a batch.",
      "model_name", "op_name");
  cell->GetCell(model_name, op_name)->Set(batch_timeout_micros);
}

void RecordBatchParamMaxBatchSize(int64_t max_batch_size,
                                  const string& model_name,
                                  const string& op_name) {
  static auto* cell = monitoring::Gauge<int64_t, 2>::New(
      "/tensorflow/serving/batching/max_batch_size",
      "Tracks the maximum size of a batch.", "model_name", "op_name");
  cell->GetCell(model_name, op_name)->Set(max_batch_size);
}

void RecordBatchParamMaxEnqueuedBatches(int64_t max_enqueued_batches,
                                        const string& model_name,
                                        const string& op_name) {
  static auto* cell = monitoring::Gauge<int64_t, 2>::New(
      "/tensorflow/serving/batching/max_enqueued_batches",
      "Tracks the maximum number of enqueued batches.", "model_name",
      "op_name");
  cell->GetCell(model_name, op_name)->Set(max_enqueued_batches);
}

void RecordBatchParamAllowedBatchSizes(const string& allowed_batch_sizes,
                                       const string& model_name,
                                       const string& op_name) {
  static auto* cell = monitoring::Gauge<string, 2>::New(
      "/tensorflow/serving/batching/allowed_batch_sizes",
      "Tracks the sizes that are allowed to form a batch.", "model_name",
      "op_name");
  cell->GetCell(model_name, op_name)->Set(allowed_batch_sizes);
}

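// Returns the model name recorded in the session metadata of `ctx`, or a
// static placeholder string when the metadata is absent or the name is empty.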
const string& GetModelName(OpKernelContext* ctx) {
  static string* kModelNameUnset = new string("model_name_unset");
  if (!ctx->session_metadata()) return *kModelNameUnset;
  if (ctx->session_metadata()->name().empty()) return *kModelNameUnset;
  return ctx->session_metadata()->name();
}

}  // namespace

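// Creates a partial task for `split_index` that shares this task's guid,
// captured inputs, op context, output matrix, and status; the split input
// tensors themselves are filled in later by SplitInputTask.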
std::unique_ptr<BatchResourceBase::BatchTask>
BatchResourceBase::BatchTask::CreateSplitTask(
    int split_index, AsyncOpKernel::DoneCallback done_callback) {
  std::unique_ptr<BatchTask> task = CreateDerivedTask();

  task->guid = this->guid;
  task->propagated_context = Context(ContextKind::kThread);
  task->inputs.reserve(this->inputs.size());
  task->captured_inputs = this->captured_inputs;
  task->context = this->context;
  task->done_callback = done_callback;
  task->split_index = split_index;
  task->output = this->output;
  task->status = this->status;
  task->is_partial = true;
  task->start_time = this->start_time;
  task->request_cost = this->request_cost;

  return task;
}

using ::tensorflow::concat_split_util::Concat;
using ::tensorflow::concat_split_util::Split;
using TensorMatrix = std::vector<std::vector<Tensor>>;

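// Validates the op's input tensors, records batching metrics, and schedules a
// new BatchTask on the batcher queue named `batcher_queue_name`. An empty
// input batch short-circuits by emitting empty outputs and calling
// `done_callback` immediately.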
Status BatchResourceBase::RegisterInput(
    int64_t guid, OpKernelContext* context, const string& batcher_queue_name,
    AsyncOpKernel::DoneCallback done_callback) {
  std::unique_ptr<BatchTask> batch_components;
  TF_RETURN_IF_ERROR(CreateBatchTask(context, &batch_components));
  batch_components->start_time = EnvTime::NowNanos();
  batch_components->guid = guid;
  batch_components->propagated_context = Context(ContextKind::kThread);
  OpInputList tensors;
  TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors));
  batch_components->inputs.reserve(tensors.size());
  for (const Tensor& tensor : tensors) {
    if (tensor.shape().dims() == 0) {
      return errors::InvalidArgument(
          "Batching input tensors must have at least one dimension");
    }
    if (tensors.size() >= 2 &&
        tensor.shape().dim_size(0) != tensors[0].shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Batching input tensors supplied in a given op invocation must "
          "have equal 0th-dimension size");
    }
    batch_components->inputs.push_back(tensor);
  }
  RecordInputBatchSize(tensors[0].shape().dim_size(0), GetModelName(context),
                       context->op_kernel().name());
  RecordInputBatchSizeV2(tensors[0].shape().dim_size(0), GetModelName(context),
                         context->op_kernel().name());
  if (batcher_) {
    RecordBatchParamBatchTimeoutMicros(
        batcher_queue_options_.batch_timeout_micros, GetModelName(context),
        context->op_kernel().name());
    RecordBatchParamMaxBatchSize(
        batcher_queue_options_.max_execution_batch_size, GetModelName(context),
        context->op_kernel().name());
    RecordBatchParamMaxEnqueuedBatches(
        batcher_queue_options_.max_enqueued_batches, GetModelName(context),
        context->op_kernel().name());
  } else if (adaptive_batcher_) {
    RecordBatchParamBatchTimeoutMicros(
        adaptive_batcher_queue_options_.batch_timeout_micros,
        GetModelName(context), context->op_kernel().name());
    RecordBatchParamMaxBatchSize(adaptive_batcher_queue_options_.max_batch_size,
                                 GetModelName(context),
                                 context->op_kernel().name());
    RecordBatchParamMaxEnqueuedBatches(
        adaptive_batcher_queue_options_.max_enqueued_batches,
        GetModelName(context), context->op_kernel().name());
  } else {
    return errors::Internal("No batcher defined.");
  }
  RecordBatchParamAllowedBatchSizes(allowed_batch_sizes_str_,
                                    GetModelName(context),
                                    context->op_kernel().name());

  // Degenerate case where the input is empty. Just return an empty tensor.
  if (tensors[0].shape().dim_size(0) == 0) {
    for (int i = 0; i < context->num_outputs(); i++) {
      Tensor* empty_output;
      AllocatorAttributes cpu_alloc;
      cpu_alloc.set_on_host(true);
      TF_RETURN_IF_ERROR(context->allocate_output(i, TensorShape({0}),
                                                  &empty_output, cpu_alloc));
    }
    done_callback();
    return OkStatus();
  }
  OpInputList captured_tensors;
  const auto captured_status =
      context->input_list("captured_tensors", &captured_tensors);
  if (captured_status.ok()) {
    batch_components->captured_inputs.reserve(captured_tensors.size());
    for (const Tensor& captured_tensor : captured_tensors) {
      batch_components->captured_inputs.push_back(captured_tensor);
    }
  }
  batch_components->context = context;
  batch_components->done_callback = std::move(done_callback);
  batch_components->split_index = 0;
  batch_components->output = std::make_shared<TensorMatrix>();
  batch_components->status = std::make_shared<ThreadSafeStatus>();

  std::unique_ptr<RequestCostAccessor> request_cost_accessor =
      CreateRequestCostAccessor();
  if (request_cost_accessor) {
    batch_components->request_cost = request_cost_accessor->GetRequestCost();
  }

  BatcherQueueT* batcher_queue;
  TF_RETURN_IF_ERROR(
      LookupOrCreateBatcherQueue(batcher_queue_name, &batcher_queue));
  return batcher_queue->Schedule(&batch_components);
}

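// Builds the queue options for the shared batcher from the op's attributes.
// When large-batch splitting is enabled, SplitInputTask is installed as the
// split function and the maximum execution batch size is set to the largest
// allowed batch size (or `max_batch_size` when no allowed sizes are given).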
/*static*/ BatchResourceBase::BatcherT::QueueOptions
BatchResourceBase::GetBatcherQueueOptions(
    int32_t num_batch_threads, int32_t max_batch_size,
    int32_t batch_timeout_micros, int32_t max_enqueued_batches,
    const std::vector<int32>& allowed_batch_sizes,
    bool enable_large_batch_splitting) {
  BatcherT::QueueOptions batcher_queue_options;
  batcher_queue_options.input_batch_size_limit = max_batch_size;
  batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
  batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
  batcher_queue_options.enable_large_batch_splitting =
      enable_large_batch_splitting;
  if (enable_large_batch_splitting) {
    batcher_queue_options.split_input_task_func =
        [](std::unique_ptr<BatchTask>* input_task,
           int open_batch_remaining_slot, int max_batch_size,
           std::vector<std::unique_ptr<BatchTask>>* output_tasks) -> Status {
      return SplitInputTask(input_task, open_batch_remaining_slot,
                            max_batch_size, output_tasks);
    };

    if (allowed_batch_sizes.empty()) {
      batcher_queue_options.max_execution_batch_size = max_batch_size;
    } else {
      batcher_queue_options.max_execution_batch_size =
          *allowed_batch_sizes.rbegin();
    }
  }

  return batcher_queue_options;
}

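// Builds the queue options for the adaptive shared batcher. The maximum batch
// size comes from the largest allowed batch size when any are specified, and
// SplitInputTask is installed when large-batch splitting is enabled.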
/*static*/ BatchResourceBase::AdaptiveBatcherT::QueueOptions
BatchResourceBase::GetAdaptiveBatcherQueueOptions(
    int32_t max_batch_size, int32_t batch_timeout_micros,
    int32_t max_enqueued_batches, bool enable_large_batch_splitting,
    const std::vector<int32>& allowed_batch_sizes) {
  AdaptiveBatcherT::QueueOptions batcher_queue_options;
  batcher_queue_options.max_input_task_size =
      absl::make_optional(max_batch_size);
  batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
  batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
  if (allowed_batch_sizes.empty()) {
    batcher_queue_options.max_batch_size = max_batch_size;
  } else {
    batcher_queue_options.max_batch_size = *allowed_batch_sizes.rbegin();
  }

  if (enable_large_batch_splitting) {
    batcher_queue_options.split_input_task_func =
        [](std::unique_ptr<BatchTask>* input_task,
           int open_batch_remaining_slot, int max_batch_size,
           std::vector<std::unique_ptr<BatchTask>>* output_tasks) -> Status {
      return SplitInputTask(input_task, open_batch_remaining_slot,
                            max_batch_size, output_tasks);
    };
  }

  return batcher_queue_options;
}

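// Checks that every task in `batch` supplies the same number of input edges.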
/*static*/ Status BatchResourceBase::ValidateBatch(const BatchT& batch) {
  for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
    const BatchResourceBase::BatchTask& task = batch.task(task_idx);

    if (task.inputs.size() != batch.task(0).inputs.size()) {
      return errors::InvalidArgument(
          "Batching inputs must have equal number of edges");
    }
  }

  return OkStatus();
}

// Returns the smallest entry in 'allowed_batch_sizes_' that is greater than
// or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply
// returns 'batch_size'.
int BatchResourceBase::RoundToLowestAllowedBatchSize(int batch_size) const {
  if (allowed_batch_sizes_.empty()) {
    return batch_size;
  }
  for (int allowed_size : allowed_batch_sizes_) {
    if (allowed_size >= batch_size) {
      return allowed_size;
    }
  }
  LOG(ERROR) << "Batch size " << batch_size
             << " is greater than largest allowed size; "
                "ignoring allowed sizes constraint.";
  return batch_size;
}

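// Concatenates the corresponding input tensors of all tasks in `batch` along
// dimension 0, padding up to the next allowed batch size by repeating the
// first row of the first task's input, and records padding and batch-size
// metrics along the way.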
Status BatchResourceBase::ConcatInputTensors(
    const BatchT& batch, OpKernelContext* context,
    std::vector<Tensor>* concatenated_tensors) const {
  if (batch.num_tasks() == 0) {
    return errors::InvalidArgument("Empty batch.");
  }

  const int padded_batch_size = RoundToLowestAllowedBatchSize(batch.size());
  const int padding_amount = padded_batch_size - batch.size();
  profiler::TraceMe trace_me([padded_batch_size, padding_amount]() {
    return profiler::TraceMeEncode(
        "ConcatInputTensors", {{"batch_size_after_padding", padded_batch_size},
                               {"padding_amount", padding_amount}});
  });
  RecordPaddingSize(padding_amount, GetModelName(context), padded_batch_size,
                    context->op_kernel().name());
  RecordPaddingSizeV2(padding_amount, GetModelName(context), padded_batch_size,
                      context->op_kernel().name());
  RecordProcessedBatchSize(padded_batch_size, GetModelName(context),
                           context->op_kernel().name());
  RecordProcessedBatchSizeV2(padded_batch_size, GetModelName(context),
                             context->op_kernel().name());
  RecordBatchSize(batch.size(), GetModelName(context),
                  context->op_kernel().name());

  // All tasks should have the same number of input edges.
  const int num_inputs = batch.task(0).inputs.size();
  concatenated_tensors->reserve(num_inputs);

  // Process each input one at a time (the typical case has just one).
  for (int i = 0; i < num_inputs; ++i) {
    // Concatenate the tasks' ith input tensors into a big output tensor.
    std::vector<Tensor> to_concatenate;
    to_concatenate.reserve(batch.num_tasks());
    for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
      to_concatenate.push_back(batch.task(task_idx).inputs.at(i));
    }

    // Add padding as needed. Use the first row of the first task's tensor as
    // the data for padding.
    if (padding_amount > 0) {
      const Tensor& padding_source = batch.task(0).inputs.at(i);
      Tensor padding;
      if (padding_source.shape().dim_size(0) == 0) {
        return errors::InvalidArgument(
            "Cannot use an empty tensor with zero rows as padding when "
            "batching. (Input ",
            i, " got shape ", padding_source.shape().DebugString(), ".)");
      }
      if (padding_source.shape().dim_size(0) == 1) {
        padding = padding_source;
      } else {
        padding = padding_source.Slice(0, 1);
      }
      for (int i = 0; i < padding_amount; ++i) {
        to_concatenate.push_back(padding);
      }
    }

    Tensor concatenated_tensor;
    Status concat_status =
        Concat(context, to_concatenate, &concatenated_tensor);
    TF_RETURN_IF_ERROR(concat_status);
    concatenated_tensors->push_back(concatenated_tensor);
  }
  return OkStatus();
}

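// Splits `input_task` into smaller tasks whose sizes are determined by
// `open_batch_remaining_slot` and `max_batch_size`. Each split task writes its
// results into a shared output matrix; once all splits complete, the outputs
// are concatenated and delivered to the original task's op context.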
/*static*/ Status BatchResourceBase::SplitInputTask(
    std::unique_ptr<BatchTask>* input_task_ptr, int open_batch_remaining_slot,
    int max_batch_size, std::vector<std::unique_ptr<BatchTask>>* output_tasks) {
  BatchTask& input_task = *(*input_task_ptr);
  const int64_t input_task_size = input_task.size();

  DCHECK_GT(input_task_size, 0);

  std::shared_ptr<ThreadSafeStatus> shared_status = input_task.status;

  // `split_task_done_callback` runs only after all split tasks are complete.
  std::function<void()> split_task_done_callback =
      [done_callback = input_task.done_callback, output = input_task.output,
       op_kernel_context = input_task.context, status = shared_status]() {
        const int num_output = op_kernel_context->num_outputs();
        for (int i = 0; i < num_output; ++i) {
          Tensor output_tensor;

          // Concat would memcpy each input tensor to one output tensor.
          // In this context, Concat can be further optimized to get rid of
          // some (probably all) memcpy when input tensors are slices of
          // another copy.
          std::vector<Tensor> to_concatenate;
          to_concatenate.reserve(output->size());
          for (int j = 0; j < output->size(); ++j) {
            to_concatenate.push_back(std::move((*output)[j][i]));
          }
          const auto concat_status =
              Concat(op_kernel_context, to_concatenate, &output_tensor);
          if (!concat_status.ok()) {
            status->Update(concat_status);
          }

          op_kernel_context->set_output(i, std::move(output_tensor));
        }
        op_kernel_context->SetStatus(status->status());
        done_callback();
      };
  IncrementalBarrier barrier(split_task_done_callback);

  const internal::InputSplitMetadata input_split_metadata(
      input_task_size, open_batch_remaining_slot, max_batch_size);

  const absl::FixedArray<int>& task_sizes = input_split_metadata.task_sizes();
  const int num_batches = task_sizes.size();
  std::vector<int64_t> output_task_sizes;
  output_task_sizes.resize(num_batches);
  for (int i = 0; i < num_batches; i++) {
    output_task_sizes[i] = task_sizes[i];
  }

  input_task.output->resize(num_batches);
  for (int i = 0; i < num_batches; ++i) {
    (*input_task.output)[i].resize(input_task.context->num_outputs());
  }

  output_tasks->reserve(num_batches);
  for (int i = 0; i < num_batches; i++) {
    output_tasks->push_back(input_task.CreateSplitTask(i, barrier.Inc()));
  }

  const int num_input_tensors = input_task.inputs.size();

  // Splits each input tensor according to `output_task_sizes`, and
  // initializes input of `output_tasks` with split results.
  for (int i = 0; i < num_input_tensors; ++i) {
    std::vector<Tensor> split_tensors;
    const Tensor& input_tensor = input_task.inputs[i];
    // TODO(b/154140947):
    // Figure out the optimal implementation of Split, by using
    // 'Tensor::Slice' and eliminating unnecessary memcpy as much as possible.
    const Status split_status = Split(input_task.context, input_tensor,
                                      output_task_sizes, &split_tensors);
    if (!split_status.ok()) {
      return errors::Internal(
          "When splitting input, Tensor split operation failed: ",
          split_status.error_message());
    }
    if (split_tensors.size() != output_task_sizes.size()) {
      return errors::Internal(
          "When splitting input, tensor split operation did not work as "
          "expected; got ",
          split_tensors.size(), " splits; expected ", output_task_sizes.size());
    }
    for (int j = 0; j < output_tasks->size(); ++j) {
      BatchTask& output_task = *((*output_tasks)[j]);
      auto moved_tensor_iter = std::next(split_tensors.begin(), j);
      std::move(moved_tensor_iter, moved_tensor_iter + 1,
                std::back_inserter(output_task.inputs));
    }
  }
  return OkStatus();
}

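// Splits the batched output tensors back into per-task slices (dropping any
// padding rows) and delivers them either to each task's op context or, for
// partial tasks, into the shared output matrix.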
Status BatchResourceBase::SplitOutputTensors(
    const std::vector<Tensor>& combined_outputs, BatchT* batch) const {
  DCHECK_GE(batch->num_tasks(), 1);
  if (batch->num_tasks() < 1) {
    return errors::Internal("Batch size expected to be positive; was ",
                            batch->num_tasks());
  }

  std::vector<int64_t> task_sizes_plus_optional_padding;
  task_sizes_plus_optional_padding.reserve(batch->num_tasks());
  for (int i = 0; i < batch->num_tasks(); ++i) {
    task_sizes_plus_optional_padding.push_back(batch->task(i).size());
  }
  const int padding_size =
      RoundToLowestAllowedBatchSize(batch->size()) - batch->size();
  if (padding_size > 0) {
    task_sizes_plus_optional_padding.push_back(padding_size);
  }

  // For each output tensor name, a divided-up tensor with one entry per task.
  std::map<string, std::vector<Tensor>> split_tensors;

  DCHECK_EQ(batch->task(0).context->num_outputs(), combined_outputs.size());
  int combined_outputs_size = combined_outputs.size();
  if (combined_outputs_size != batch->task(0).context->num_outputs()) {
    return errors::Internal("Wrong number of batched output tensors");
  }

  // Generate 'split_tensors' and populate the context outputs.
  for (int i = 0, iter_limit = combined_outputs.size(); i < iter_limit; ++i) {
    const Tensor& output_tensor = combined_outputs[i];
    if (output_tensor.shape().dims() == 0) {
      return errors::FailedPrecondition(
          "Batched output tensor has 0 dimensions");
    }
    if (output_tensor.shape().dim_size(0) !=
        static_cast<int64_t>(batch->size() + padding_size)) {
      return errors::FailedPrecondition(
          "Batched output tensor's 0th dimension does not equal the sum of "
          "the 0th dimension sizes of the input tensors");
    }

    std::vector<Tensor> split_tensor;
    const Status split_status = tensor::Split(
        output_tensor, task_sizes_plus_optional_padding, &split_tensor);
    DCHECK(split_status.ok()) << split_status.ToString();
    if (!split_status.ok()) {
      return errors::Internal("Tensor split operation failed: ",
                              split_status.error_message());
    }
    DCHECK_EQ(split_tensor.size(), task_sizes_plus_optional_padding.size());
    if (split_tensor.size() != task_sizes_plus_optional_padding.size()) {
      return errors::Internal(
          "Tensor split operation did not work as expected; got ",
          split_tensor.size(), " splits; expected ",
          task_sizes_plus_optional_padding.size());
    }

    // Ignore a possible final split_tensors entry containing the padding.
    for (int j = 0; j < batch->num_tasks(); ++j) {
      BatchTask& task = *(batch->mutable_task(j));
      if (task.is_partial) {
        std::vector<Tensor>& tensor_vector = (*task.output)[task.split_index];
        tensor_vector[i] = std::move(split_tensor[j]);
      } else {
        task.context->set_output(i, split_tensor[j]);
      }
    }
  }

  return OkStatus();
}

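// Processes a function-based batch: concatenates the task inputs, runs the
// wrapped function via ProcessFuncBatchImpl, splits the combined outputs back
// into per-task outputs, and propagates status and per-task costs through the
// cleanup callback.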
void BatchResourceBase::ProcessFuncBatch(std::unique_ptr<BatchT> batch) const {
  if (batch->empty()) {
    return;
  }

  // We use the 'propagated_context' from one of the threads which set up one
  // of the tasks. This will propagate any common context over all the threads
  // which are running this Session, of which this BatchOp is a part.
  WithContext wc(batch->task(batch->num_tasks() - 1).propagated_context);

  // TODO(b/185852990): Add a unit test to check the context is correctly set.
  // Creates the CostMeasurements within the same context that runs the Session.
  const CostMeasurement::Context batching_context{/*is_per_query=*/false};
  std::vector<std::unique_ptr<CostMeasurement>> batch_cost_measurements =
      CreateCostMeasurements(batching_context);

  auto& last_task = batch->task(batch->num_tasks() - 1);
  OpKernelContext* last_task_context = last_task.context;

  // Regardless of the outcome, we need to propagate the status to the
  // individual tasks and signal that they are done. We use MakeCleanup() to
  // ensure that this happens no matter how we exit the method below.
  Status status;
  bool cleanup_done = false;
  int64_t processed_size = batch->size();
  auto cleanup_fn = [&cleanup_done, &batch, &processed_size,
                     &batch_cost_measurements](const Status& status) {
    if (cleanup_done) {
      return;
    }
    SplitBatchCosts(batch_cost_measurements, processed_size, *batch);
    // Clear the measurements before unblocking the batch task, as measurements
    // are associated with the task's thread context.
    batch_cost_measurements.clear();
    for (int i = 0; i < batch->num_tasks(); ++i) {
      WithContext wc(batch->task(i).propagated_context);
      if (batch->task(i).is_partial) {
        batch->mutable_task(i)->status->Update(status);
      } else {
        batch->mutable_task(i)->context->SetStatus(status);
      }
      batch->mutable_task(i)->done_callback();
    }
    cleanup_done = true;
  };

  auto finally =
      gtl::MakeCleanup([&cleanup_fn, &status] { cleanup_fn(status); });

  status = ValidateBatch(*batch);
  if (!status.ok()) {
    return;
  }

  std::vector<Tensor> concatenated_tensors;
  status = ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
  processed_size = RoundToLowestAllowedBatchSize(batch->size());
  if (!status.ok()) {
    return;
  }

  std::vector<Tensor> combined_outputs;
  std::vector<Tensor> args(concatenated_tensors.begin(),
                           concatenated_tensors.end());
  const auto& captured_inputs =
      batch->task(batch->num_tasks() - 1).captured_inputs;
  args.insert(args.end(), captured_inputs.begin(), captured_inputs.end());

  uint64 current_time = EnvTime::NowNanos();
  const string& model_name = GetModelName(last_task_context);
  for (int i = 0; i < batch->num_tasks(); ++i) {
    RecordBatchDelayUs((current_time - batch->task(i).start_time) * 1e-3,
                       model_name, last_task_context->op_kernel().name(),
                       processed_size);
    RecordBatchDelayUsV2((current_time - batch->task(i).start_time) * 1e-3,
                         model_name, last_task_context->op_kernel().name(),
                         processed_size);
  }
  // Releases the cleanup method here, because the callback of the function
  // library runtime will handle it now.
  finally.release();
  ProcessFuncBatchImpl(
      last_task, args, &combined_outputs, [&](const Status& run_status) {
        Status final_status;
        auto run_finally = gtl::MakeCleanup([&]() {
          // We do the cleanup here as an optimization, so that
          // it runs in the underlying TF inter-op threadpool.
          // Running it in the threadpool lets the ensuing
          // ops be scheduled faster, because the executor will
          // add them to the front of the threadpool's task
          // queue rather than the end.
          cleanup_fn(final_status);
        });
        final_status = run_status;
        if (!final_status.ok()) {
          return;
        }
        final_status = SplitOutputTensors(combined_outputs, batch.get());
      });
}

// Processes a batch of one or more BatchTask entries.
void BatchResourceBase::ProcessBatch(std::unique_ptr<BatchT> batch) const {
  if (batch->empty()) {
    return;
  }

  // We use the 'propagated_context' from one of the threads which set up one
  // of the tasks. This will propagate any common context over all the threads
  // which are running this Session, of which this BatchOp is a part.
  WithContext wc(batch->task(batch->num_tasks() - 1).propagated_context);

  // TODO(b/185852990): Add a unit test to check the context is correctly set.
  // Creates the CostMeasurement within the same context that runs the Session.
  const CostMeasurement::Context batching_context{/*is_per_query=*/false};
  std::vector<std::unique_ptr<CostMeasurement>> batch_cost_measurements =
      CreateCostMeasurements(batching_context);

  int64_t processed_size = batch->size();
  auto batch_cost_split_cleanup = gtl::MakeCleanup([&] {
    SplitBatchCosts(batch_cost_measurements, processed_size, *batch);
  });

  OpKernelContext* last_task_context =
      batch->task(batch->num_tasks() - 1).context;
  AsyncOpKernel::DoneCallback last_task_callback =
      batch->task(batch->num_tasks() - 1).done_callback;

  OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch),
                       last_task_callback);

  // All tasks should have the same number of input edges.
  const int num_input_edges = batch->task(0).inputs.size();
  std::vector<Tensor> concatenated_tensors;
  const Status concat_status =
      ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
  processed_size = RoundToLowestAllowedBatchSize(batch->size());
  OP_REQUIRES_OK_ASYNC(last_task_context, concat_status, last_task_callback);

  // Process each input edge one at a time (the typical case has just one).
  for (int i = 0; i < num_input_edges; ++i) {
    last_task_context->set_output(i, concatenated_tensors[i]);

    // Emit batch->num_tasks() - 1 empty output tensors.
    for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
      const BatchTask& task = batch->task(task_idx);
      TensorShape output_shape(task.inputs[i].shape());
      output_shape.set_dim(0, 0);
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          task.context, task.context->allocate_output(i, output_shape, &output),
          task.done_callback);
    }
  }
  // Emit batch->num_tasks() - 1 empty index tensors.
  for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
    const BatchTask& task = batch->task(task_idx);
    TensorShape index_shape({0, 3});
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        task.context,
        task.context->allocate_output(num_input_edges, index_shape, &output),
        task.done_callback);
  }
  // Emit all ID tensors.
  for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
    const BatchTask& task = batch->task(task_idx);
    Tensor* id;
    OP_REQUIRES_OK_ASYNC(task.context,
                         task.context->allocate_output(num_input_edges + 1,
                                                       TensorShape({}), &id),
                         task.done_callback);
    id->scalar<int64_t>()() = task.guid;
  }
  OP_REQUIRES_OK_ASYNC(
      last_task_context,
      EmitIndexTensor(last_task_context, *batch, num_input_edges),
      last_task_callback);

  // Signal done for each element of the batch. (At this point, the contexts
  // are no longer guaranteed to remain live.)
  for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
    batch->mutable_task(task_idx)->done_callback();
  }
}

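// Emits a [num_tasks, 3] int64 index tensor whose rows are
// (task guid, start offset, end offset), describing where each task's rows
// live within the concatenated batch output.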
/*static*/ Status BatchResourceBase::EmitIndexTensor(OpKernelContext* context,
                                                     const BatchT& batch,
                                                     int output_index) {
  const TensorShape index_shape({batch.num_tasks(), 3});
  Tensor* index = nullptr;
  TF_RETURN_IF_ERROR(
      context->allocate_output(output_index, index_shape, &index));
  auto index_flat = index->shaped<int64_t, 2>({batch.num_tasks(), 3});
  size_t offset = 0;
  for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
    const BatchTask& task = batch.task(task_idx);
    index_flat(task_idx, 0) = task.guid;
    index_flat(task_idx, 1) = offset;
    index_flat(task_idx, 2) = offset + task.size();
    offset += task.size();
  }
  return OkStatus();
}

// Looks up the batcher queue for 'queue_name'. If it didn't previously exist,
// creates it.
Status BatchResourceBase::LookupOrCreateBatcherQueue(const string& queue_name,
                                                     BatcherQueueT** queue) {
  mutex_lock l(batcher_queues_mu_);

  auto it = batcher_queues_.find(queue_name);
  if (it != batcher_queues_.end()) {
    *queue = it->second.get();
    return OkStatus();
  }

  std::unique_ptr<BatcherQueueT> new_queue;
  auto process_batch_callback = [this](std::unique_ptr<BatchT> batch) {
    if (!has_process_batch_function_) {
      ProcessBatch(std::move(batch));
    } else {
      ProcessFuncBatch(std::move(batch));
    }
  };
  if (batcher_) {
    TF_RETURN_IF_ERROR(batcher_->AddQueue(batcher_queue_options_,
                                          process_batch_callback, &new_queue));
  } else if (adaptive_batcher_) {
    TF_RETURN_IF_ERROR(adaptive_batcher_->AddQueue(
        adaptive_batcher_queue_options_, process_batch_callback, &new_queue));
  } else {
    return errors::Internal("No batcher defined.");
  }
  *queue = new_queue.get();
  batcher_queues_[queue_name] = std::move(new_queue);
  return OkStatus();
}

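// Creates the BatchTask that RegisterInput enqueues for `context`; this
// default version allocates a plain BatchTask.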
Status BatchResourceBase::CreateBatchTask(
    OpKernelContext* context,
    std::unique_ptr<BatchResourceBase::BatchTask>* output) const {
  *output = absl::make_unique<BatchResourceBase::BatchTask>();
  return OkStatus();
}

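// Splits each collected batch cost across the tasks in `batch`, recording in
// each task's RequestCost both a smeared cost (padding charged proportionally
// to every task) and a non-smeared cost (padding excluded).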
void BatchResourceBase::SplitBatchCosts(
    std::vector<std::unique_ptr<CostMeasurement>>& batch_cost_measurements,
    const int64_t processed_size, BatchT& batch) {
  for (auto& batch_cost_measurement : batch_cost_measurements) {
    if (batch_cost_measurement->GetTotalCost() <= absl::ZeroDuration()) {
      return;
    }
    if (batch.size() == 0) {  // NOLINT: empty() checks the batch contains 0
                              // tasks. size() gets the sum of task sizes.
      LOG_EVERY_N_SEC(ERROR, 60)
          << "Non-zero cost collected but the batch size is 0.";
      return;
    }
    if (processed_size == 0) {
      LOG_EVERY_N_SEC(ERROR, 60)
          << "Non-zero cost collected but the processed size is 0.";
      return;
    }
    const absl::string_view cost_type = batch_cost_measurement->GetCostType();
    const absl::Duration total_cost = batch_cost_measurement->GetTotalCost();

    for (int i = 0; i < batch.num_tasks(); i++) {
      RequestCost* request_cost = batch.task(i).request_cost;
      // Skip recording the cost if the request_cost is null.
      if (!request_cost) continue;

      // Smeared cost: the cost of padding is shared among the tasks.
      const auto cost_with_smear =
          total_cost / batch.size() * batch.task(i).size();

      // Non-smeared cost: the cost of padding is not assigned to any task.
      const auto cost_no_smear =
          total_cost / processed_size * batch.task(i).size();

      request_cost->RecordCost(
          {{absl::StrCat(cost_type, kWithSmearSuffix), cost_with_smear},
           {absl::StrCat(cost_type, kNoSmearSuffix), cost_no_smear}});
    }
  }
}

}  // namespace serving
}  // namespace tensorflow