/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_INPUT_TASK_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_INPUT_TASK_H_

#include <algorithm>
#include <atomic>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "absl/base/call_once.h"
#include "absl/container/fixed_array.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/concat_split_util.h"
#include "tensorflow/core/kernels/batching_util/input_split_metadata.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/incremental_barrier.h"

namespace tensorflow {
namespace serving {

namespace internal {
template <typename TaskType>
class BatchInputTaskHandleTestAccess;

template <typename TaskType>
class BatchInputTaskTestAccess;

template <typename TaskType>
class BatchInputTask;

// A RAII-style object that holds a ref-counted batch-input-task and
// represents a slice of that batch-input-task.
//
// Handles are handed out to callers of `BatchInputTask::ToTaskHandles`
// quickly (i.e., without necessarily waiting for the input split).
//
// `BatchInputTaskHandle::GetSplitTask` evaluates to the slice of the task.
template <typename TaskType>
class BatchInputTaskHandle : public BatchTask {
 public:
  BatchInputTaskHandle(
      std::shared_ptr<BatchInputTask<TaskType>> batch_input_task, int split_id,
      size_t task_size);

  // Should be called at most once. Returns nullptr on subsequent calls.
  std::unique_ptr<TaskType> GetSplitTask();

  // Returns the size of this task.
  size_t size() const override { return task_size_; }

 private:
  template <typename T>
  friend class internal::BatchInputTaskHandleTestAccess;

  int split_id() const { return split_id_; }

  std::shared_ptr<BatchInputTask<TaskType>> batch_input_task_;

  // This handle evaluates to the N-th slice of the original task, where
  // N is `split_id_`.
  const int split_id_;

  const size_t task_size_;

  std::atomic<bool> once_{false};
};
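
// For example, a handle is consumed like this (a minimal sketch; `handle` is
// assumed to be one element of the output of `BatchInputTask::ToTaskHandles`):
//
//   std::unique_ptr<TaskType> slice = handle->GetSplitTask();
//   // `slice` holds the sub-task; any further call returns nullptr.
//   DCHECK(handle->GetSplitTask() == nullptr);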

// BatchInputTask encapsulates an input (`input_task`) to be batched and the
// information needed to get task splits after it's enqueued, so as to support
// lazy split of a task.
//
// Input split can reduce excessive padding for efficiency; lazy split moves
// task-split out of the critical path of enqueue and dequeue and reduces
// contention.
//
// BatchInputTask is thread safe.
//
// Usage
//
// ... a deque with frequent enqueue and dequeue operations ...
// ... Note, a deque of Batch of BatchInputTaskHandle is used to form batches
//     at enqueue time (the split is lazy, deferred to dequeue time);
// ... For use cases that form batches at dequeue time, a deque of
//     BatchInputTaskHandle can be used directly, "peeking" metadata to form
//     a batch by then.
//
// std::deque<std::unique_ptr<Batch<BatchInputTaskHandle<TaskType>>>> deque_
//     TF_GUARDED_BY(mu_);
//
// std::unique_ptr<TaskType> input_task;
//
// ... Enqueue path ...
//
// {
//   mutex_lock l(mu_);
//   std::shared_ptr<BatchInputTask<TaskType>> batch_input_task =
//       ConstructLazyBatchWithoutSplit(input_task);
//
//   std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>> task_handles;
//   batch_input_task->ToTaskHandles(&task_handles);
//   for (int i = 0; i < task_handles.size(); ++i) {
//     EnqueueTaskHandleIntoDeque(deque_, std::move(task_handles[i]));
//   }
// }
//
// ... Dequeue path ...
// std::unique_ptr<Batch<BatchInputTaskHandle<TaskType>>> handles_to_schedule;
// {
//   mutex_lock l(mu_);
//   ... HasBatchToSchedule could be customized or specialized
//   ... (e.g., readiness depending on enqueue time)
//   if (HasBatchToSchedule(deque_)) {
//     handles_to_schedule = std::move(deque_.front());
//     deque_.pop_front();
//   }
// }
// ...... `mu_` is released ......
//
// std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>> tasks_in_batch =
//     RemoveAllTasksFromBatch(handles_to_schedule);
//
// auto batch_to_schedule = std::make_unique<Batch<TaskType>>();
// for (int i = 0; i < tasks_in_batch.size(); i++) {
//   batch_to_schedule->AddTask(tasks_in_batch[i]->GetSplitTask());
// }
// batch_to_schedule->Close();
//
// `batch_to_schedule` is ready for scheduling.
template <typename TaskType>
class BatchInputTask
    : public std::enable_shared_from_this<BatchInputTask<TaskType>> {
 public:
  using SplitInputFunc = std::function<Status(
      std::unique_ptr<TaskType>* input_task, int first_output_task_size,
      int input_batch_size_limit,
      std::vector<std::unique_ptr<TaskType>>* output_tasks)>;

  BatchInputTask(std::unique_ptr<TaskType> input_task,
                 int open_batch_remaining_slot, int batch_size_limit,
                 SplitInputFunc split_input_func);

  // Outputs the task handles for the input task.
  // Each task handle represents a slice of the task after the input task is
  // split, and evaluates to that slice.
  //
  // NOTE:
  // Each task handle in `output_task_handles` takes ownership of a reference
  // to this BatchInputTask.
  void ToTaskHandles(
      std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>>*
          output_task_handles);

 private:
  friend class BatchInputTaskHandle<TaskType>;
  template <typename T>
  friend class internal::BatchInputTaskTestAccess;

  std::unique_ptr<TaskType> GetSplitTask(int split_id);

  Status SplitBatches(std::vector<std::unique_ptr<TaskType>>* output_tasks);

  std::unique_ptr<TaskType> input_task_;

  const int input_task_size_ = 0;
  const int open_batch_remaining_slot_;

  const int batch_size_limit_;
  const SplitInputFunc split_func_;

  const InputSplitMetadata input_split_metadata_;

  mutable absl::once_flag once_;

  std::vector<std::unique_ptr<TaskType>> task_splits_;
  Status split_status_;
};
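
// For illustration, a `SplitInputFunc` could be implemented as below for a
// hypothetical `CountTask` whose payload is just its size (`CountTask` and
// `SplitCountTask` are illustrative names, not part of this header):
//
//   Status SplitCountTask(
//       std::unique_ptr<CountTask>* input_task, int first_output_task_size,
//       int input_batch_size_limit,
//       std::vector<std::unique_ptr<CountTask>>* output_tasks) {
//     int remaining = (*input_task)->size();
//     // The head slice fills the open batch (when it has room); later
//     // slices are capped by the batch size limit.
//     int chunk = first_output_task_size > 0
//                     ? std::min(remaining, first_output_task_size)
//                     : std::min(remaining, input_batch_size_limit);
//     while (remaining > 0) {
//       output_tasks->push_back(std::make_unique<CountTask>(chunk));
//       remaining -= chunk;
//       chunk = std::min(remaining, input_batch_size_limit);
//     }
//     return Status();
//   }
//
// In practice the slice sizes produced by the split function should agree
// with `InputSplitMetadata::task_sizes()`, since those are the sizes reported
// by the corresponding `BatchInputTaskHandle`s.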

//
// Implementation details. API readers may skip.
//

template <typename TaskType>
BatchInputTaskHandle<TaskType>::BatchInputTaskHandle(
    std::shared_ptr<BatchInputTask<TaskType>> batch_input_task, int split_id,
    size_t task_size)
    : batch_input_task_(std::move(batch_input_task)),
      split_id_(split_id),
      task_size_(task_size) {}
template <typename TaskType>
std::unique_ptr<TaskType> BatchInputTaskHandle<TaskType>::GetSplitTask() {
  // Atomically flip `once_` so that only the first call consumes the split;
  // subsequent (or concurrent) calls observe `true` and return nullptr.
  if (once_.exchange(true, std::memory_order_acq_rel)) {
    return nullptr;
  }
  return batch_input_task_->GetSplitTask(split_id_);
}

template <typename TaskType>
BatchInputTask<TaskType>::BatchInputTask(std::unique_ptr<TaskType> input_task,
                                         int open_batch_remaining_slot,
                                         int batch_size_limit,
                                         SplitInputFunc split_input_func)
    : input_task_(std::move(input_task)),
      input_task_size_(input_task_->size()),
      open_batch_remaining_slot_(open_batch_remaining_slot),
      batch_size_limit_(batch_size_limit),
      split_func_(std::move(split_input_func)),
      input_split_metadata_(input_task_size_, open_batch_remaining_slot,
                            batch_size_limit) {}
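
// For instance (illustrative numbers): with `input_task_size_` = 10,
// `open_batch_remaining_slot` = 2 and `batch_size_limit` = 4, the head slice
// fills the open batch with 2 units and the remaining 8 units form two full
// batches, giving task sizes [2, 4, 4]; see `InputSplitMetadata` for the
// exact partitioning policy.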

template <typename TaskType>
void BatchInputTask<TaskType>::ToTaskHandles(
    std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>>*
        task_handles) {
  const absl::FixedArray<int>& task_sizes = input_split_metadata_.task_sizes();
  task_handles->resize(task_sizes.size());
  for (int i = 0; i < task_handles->size(); i++) {
    (*task_handles)[i] = std::make_unique<BatchInputTaskHandle<TaskType>>(
        this->shared_from_this(), i, task_sizes[i]);
  }
}
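
// Note that `shared_from_this()` above requires the `BatchInputTask` to
// already be owned by a `std::shared_ptr` when `ToTaskHandles` is called,
// e.g. (a sketch; `input_task`, `split_func` and the numeric arguments are
// caller-provided):
//
//   auto batch_input_task = std::make_shared<BatchInputTask<TaskType>>(
//       std::move(input_task), /*open_batch_remaining_slot=*/2,
//       /*batch_size_limit=*/4, split_func);
//   std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>> handles;
//   batch_input_task->ToTaskHandles(&handles);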

template <typename TaskType>
std::unique_ptr<TaskType> BatchInputTask<TaskType>::GetSplitTask(int split_id) {
  absl::call_once(once_,
                  [this]() { split_status_ = SplitBatches(&task_splits_); });
  if (!split_status_.ok()) {
    LOG_EVERY_N_SEC(WARNING, 60 /* seconds */)
        << "Split task with error: " << split_status_ << " split metadata is "
        << input_split_metadata_.DebugString();
    return nullptr;
  }
  if (split_id >= 0 && split_id < task_splits_.size()) {
    return std::move(task_splits_[split_id]);
  }
  return nullptr;
}
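
// The `absl::call_once` above is what makes the split lazy and idempotent:
// `SplitBatches` runs at most once, triggered by the first
// `BatchInputTaskHandle::GetSplitTask` call across all handles, and every
// later call reuses the memoized `task_splits_` and `split_status_`.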

template <typename TaskType>
Status BatchInputTask<TaskType>::SplitBatches(
    std::vector<std::unique_ptr<TaskType>>* output_tasks) {
  return split_func_(&input_task_, open_batch_remaining_slot_,
                     batch_size_limit_, output_tasks);
}

}  // namespace internal
}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_INPUT_TASK_H_