1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/batching_util/input_split_metadata.h"
17
18 #include <algorithm>
19
20 #include "absl/container/fixed_array.h"
21 #include "absl/strings/str_join.h"
22
23 namespace tensorflow {
24 namespace serving {
25 namespace internal {
26 namespace {
// Maps the input task onto a span measured from the start of the open batch.
// When the open batch still has free slots, the span also covers the portion
// of that batch which is already occupied; otherwise the input maps onto
// itself unchanged.
int compute_task_size_from_open_batch(int input_task_size,
                                      int open_batch_remaining_slot,
                                      int batch_size_limit) {
  if (open_batch_remaining_slot <= 0) {
    return input_task_size;
  }
  const int occupied_in_open_batch =
      batch_size_limit - open_batch_remaining_slot;
  return input_task_size + occupied_in_open_batch;
}
34
// Size of the first split task: it fills the open batch's remaining slots
// when there are any, and otherwise is capped at one full batch; in either
// case it never exceeds the input itself.
int compute_head_task_size(int input_task_size, int open_batch_remaining_slot,
                           int batch_size_limit) {
  const int head_capacity = (open_batch_remaining_slot == 0)
                                ? batch_size_limit
                                : open_batch_remaining_slot;
  return std::min(head_capacity, input_task_size);
}
42
// Size of the last split task.
int compute_tail_task_size(int task_size_from_open_batch, int input_task_size,
                           int open_batch_remaining_slot,
                           int batch_size_limit) {
  // The whole input fits into the open batch: a single task the size of the
  // input.
  if (input_task_size <= open_batch_remaining_slot) {
    return input_task_size;
  }
  // Otherwise the tail gets whatever is left after carving off full batches;
  // a zero remainder means the final batch is exactly full.
  const int remainder = task_size_from_open_batch % batch_size_limit;
  return (remainder == 0) ? batch_size_limit : remainder;
}
57
// Number of batches needed to cover `task_size_from_open_batch` units at
// `batch_size_limit` units per batch (ceiling division; sizes here are
// nonnegative).
int compute_num_batches(int task_size_from_open_batch, int batch_size_limit) {
  int num_batches = task_size_from_open_batch / batch_size_limit;
  if (task_size_from_open_batch % batch_size_limit != 0) {
    ++num_batches;
  }
  return num_batches;
}
61 } // namespace
62
// Splits an input task of `input_task_size` units into per-batch task sizes:
// the head task fills the open batch's `open_batch_remaining_slot` remaining
// slots (when nonzero), and the remainder is chunked into batches of at most
// `batch_size_limit` units each.
InputSplitMetadata::InputSplitMetadata(int input_task_size,
                                       int open_batch_remaining_slot,
                                       int batch_size_limit)
    : task_sizes_(generate_task_sizes(
          input_task_size, open_batch_remaining_slot, batch_size_limit)) {}
68
// Returns the computed split sizes; element i is the size of the i-th task.
const absl::FixedArray<int>& InputSplitMetadata::task_sizes() const {
  return task_sizes_;
}
72
DebugString() const73 std::string InputSplitMetadata::DebugString() const {
74 return absl::StrJoin(task_sizes_, ", ");
75 }
76
generate_task_sizes(int input_task_size,int open_batch_remaining_slot,int batch_size_limit) const77 absl::FixedArray<int> InputSplitMetadata::generate_task_sizes(
78 int input_task_size, int open_batch_remaining_slot,
79 int batch_size_limit) const {
80 const int task_size_from_open_batch = compute_task_size_from_open_batch(
81 input_task_size, open_batch_remaining_slot, batch_size_limit);
82
83 const int num_batches =
84 compute_num_batches(task_size_from_open_batch, batch_size_limit);
85
86 absl::FixedArray<int> task_sizes(num_batches, batch_size_limit);
87
88 task_sizes.front() = compute_head_task_size(
89 input_task_size, open_batch_remaining_slot, batch_size_limit);
90
91 task_sizes.back() =
92 compute_tail_task_size(task_size_from_open_batch, input_task_size,
93 open_batch_remaining_slot, batch_size_limit);
94
95 return task_sizes;
96 }
97 } // namespace internal
98 } // namespace serving
99 } // namespace tensorflow
100