xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/batching_util/batch_input_task.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_INPUT_TASK_H_
17 #define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_INPUT_TASK_H_
18 
19 #include <algorithm>
20 #include <atomic>
21 #include <functional>
22 #include <memory>
23 #include <utility>
24 
25 #include "absl/base/call_once.h"
26 #include "absl/container/fixed_array.h"
27 #include "absl/synchronization/mutex.h"
28 #include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
29 #include "tensorflow/core/kernels/batching_util/concat_split_util.h"
30 #include "tensorflow/core/kernels/batching_util/input_split_metadata.h"
31 #include "tensorflow/core/platform/errors.h"
32 #include "tensorflow/core/platform/thread_annotations.h"
33 #include "tensorflow/core/util/incremental_barrier.h"
34 
35 namespace tensorflow {
36 namespace serving {
37 
38 namespace internal {
39 template <typename TaskType>
40 class BatchInputTaskHandleTestAccess;
41 
42 template <typename TaskType>
43 class BatchInputTaskTestAccess;
44 
45 template <typename TaskType>
46 class BatchInputTask;
47 
// A RAII-style object that holds a ref-counted batch-input-task, and
// represents a slice of batch-input-task.
//
// To be handed out to callers of `BatchInputTask::ToTaskHandles` quickly
// (i.e. not necessarily waiting for input split)
//
// `BatchInputTaskHandle::GetSplitTask` evaluates to the slice of task.
template <typename TaskType>
class BatchInputTaskHandle : public BatchTask {
 public:
  BatchInputTaskHandle(
      std::shared_ptr<BatchInputTask<TaskType>> batch_input_task, int split_id,
      size_t task_size);

  // Should be called once. Returns nullptr on subsequent calls.
  std::unique_ptr<TaskType> GetSplitTask();

  // Returns the size of this task.
  size_t size() const override { return task_size_; }

 private:
  template <typename T>
  friend class internal::BatchInputTaskHandleTestAccess;

  int split_id() const { return split_id_; }

  // Keeps the owning BatchInputTask alive until every outstanding handle
  // for it is destroyed.
  std::shared_ptr<BatchInputTask<TaskType>> batch_input_task_;

  // The handle evaluates to the N-th slice of original task, and
  // N is `split_id_`.
  const int split_id_;

  // Size (in the batching sense, e.g. number of examples) of this slice.
  const size_t task_size_;

  // Flipped to true on the first `GetSplitTask` call, so later calls
  // return nullptr instead of handing out the slice twice.
  std::atomic<bool> once_{false};
};
84 
// BatchInputTask encapsulates a input (`input_task`) to be batched and the
// information to get task splits after it's enqueued, so as to support lazy
// split of a task.
//
// Input split could reduce excessive padding for efficiency; lazy split
// moves task-split out of the critical path of enqueue and dequeue and reduces
// contention.
//
// BatchInputTask is thread safe.
//
// Usage
//
// ... a deque with frequent enqueue and dequeue operations ...
// ... Note, a deque of Batch of BatchInputTaskHandle is used to form batches
//     at enqueue time (split is lazy at deque time);
// ... For use cases to form batches at dequeue time, we can use a deque of
//     BatchInputTaskHandle directly, and "peek" metadata to form a batch by
//     then.
// std::deque<std::unique_ptr<Batch<BatchInputTaskHandle<TaskType>>>> deque_
//     TF_GUARDED_BY(mu_);
//
// std::unique_ptr<TaskType> input_task;
//
// ... Enqueue path ...
//
// {
//   mutex_lock l(mu_);
//   std::shared_ptr<BatchInputTask<TaskType>> batch_input_task =
//       ConstructLazyBatchWithoutSplit(input_task);
//
//   std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>> task_handles;
//   input_batch->ToTaskHandles(&task_handles);
//   for (int i = 0; i < task_handles.size(); ++i) {
//     EnqueueTaskHandleIntoDeque(deque_);
//   }
//
// ... Dequeue path ...
// std::unique_ptr<Batch<BatchInputTaskHandle<TaskType>>> handles_to_schedule;
// {
//    mutex_lock l(mu_);
//    ... HasBatchToSchedule could be customized or specialized
//    ... (e.g., readiness depending on enqueue time)
//    if (HasBatchToSchedule(deque_)) {
//      handles_to_schedule = std::move(deque_.front());
//      deque_.pop_front();
//    }
// }
// ...... `mu_` is released ......
//
// std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>> tasks_in_batch =
//     RemoveAllTasksFromBatch(handles_to_schedule);
//
// std::unique_ptr<Batch<TaskType>> batch_to_schedule;
// for (int i = 0; i < tasks_in_batch.size(); i++) {
//   batch_to_schedule->AddTask(std::move(tasks_in_batch[i]->GetSplitTask()));
// }
// batch_to_schedule->Close();
//
// `batch_to_schedule` is ready for schedule.
template <typename TaskType>
class BatchInputTask
    : public std::enable_shared_from_this<BatchInputTask<TaskType>> {
 public:
  // Signature of the user-provided function that splits `input_task` into
  // `output_tasks`; the first output slice is sized to fill the currently
  // open batch (`first_output_task_size`).
  using SplitInputFunc = std::function<Status(
      std::unique_ptr<TaskType>* input_task, int first_output_task_size,
      int input_batch_size_limit,
      std::vector<std::unique_ptr<TaskType>>* output_tasks)>;

  BatchInputTask(std::unique_ptr<TaskType> input_task,
                 int open_batch_remaining_slot, int batch_size_limit,
                 SplitInputFunc split_input_func);

  // Outputs the task handles for the input task.
  // Each task handle represents a slice of task after input task is split, and
  // could evaluate to that slice.
  //
  // NOTE:
  // Each task handle in `output_task_handles` takes ownership of a reference of
  // this BatchInputTask.
  void ToTaskHandles(
      std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>>*
          output_task_handles);

 private:
  friend class BatchInputTaskHandle<TaskType>;
  template <typename T>
  friend class internal::BatchInputTaskTestAccess;

  // Returns the `split_id`-th slice, performing the (one-time) split lazily.
  std::unique_ptr<TaskType> GetSplitTask(int split_id);

  // Invokes `split_func_` on `input_task_`, filling `output_tasks`.
  Status SplitBatches(std::vector<std::unique_ptr<TaskType>>* output_tasks);

  std::unique_ptr<TaskType> input_task_;

  // Cached `input_task_->size()`; members below are initialized in
  // declaration order, so `input_task_` must stay declared first.
  const int input_task_size_ = 0;
  const int open_batch_remaining_slot_;

  const int batch_size_limit_;
  const SplitInputFunc split_func_;

  // Precomputed sizes of each slice; available before the actual split runs.
  const InputSplitMetadata input_split_metadata_;

  // Guards the one-time lazy execution of `SplitBatches`.
  mutable absl::once_flag once_;

  std::vector<std::unique_ptr<TaskType>> task_splits_;
  Status split_status_;
};
192 
//
// Implementation details. API readers may skip.
//
197 template <typename TaskType>
BatchInputTaskHandle(std::shared_ptr<BatchInputTask<TaskType>> batch_input_task,int split_id,size_t task_size)198 BatchInputTaskHandle<TaskType>::BatchInputTaskHandle(
199     std::shared_ptr<BatchInputTask<TaskType>> batch_input_task, int split_id,
200     size_t task_size)
201     : batch_input_task_(batch_input_task),
202       split_id_(split_id),
203       task_size_(task_size) {}
204 
205 template <typename TaskType>
GetSplitTask()206 std::unique_ptr<TaskType> BatchInputTaskHandle<TaskType>::GetSplitTask() {
207   if (once_.load(std::memory_order_acquire)) {
208     return nullptr;
209   }
210   once_.store(true, std::memory_order_release);
211   return batch_input_task_->GetSplitTask(split_id_);
212 }
213 
214 template <typename TaskType>
BatchInputTask(std::unique_ptr<TaskType> input_task,int open_batch_remaining_slot,int batch_size_limit,SplitInputFunc split_input_func)215 BatchInputTask<TaskType>::BatchInputTask(std::unique_ptr<TaskType> input_task,
216                                          int open_batch_remaining_slot,
217                                          int batch_size_limit,
218                                          SplitInputFunc split_input_func)
219     : input_task_(std::move(input_task)),
220       input_task_size_(input_task_->size()),
221       open_batch_remaining_slot_(open_batch_remaining_slot),
222       batch_size_limit_(batch_size_limit),
223       split_func_(split_input_func),
224       input_split_metadata_(input_task_size_, open_batch_remaining_slot,
225                             batch_size_limit) {}
226 
227 template <typename TaskType>
ToTaskHandles(std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>> * task_handles)228 void BatchInputTask<TaskType>::ToTaskHandles(
229     std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>>*
230         task_handles) {
231   const absl::FixedArray<int>& task_sizes = input_split_metadata_.task_sizes();
232   task_handles->resize(task_sizes.size());
233   for (int i = 0; i < task_handles->size(); i++) {
234     (*task_handles)[i] = std::make_unique<BatchInputTaskHandle<TaskType>>(
235         this->shared_from_this(), i, task_sizes[i]);
236   }
237 }
238 
239 template <typename TaskType>
GetSplitTask(int split_id)240 std::unique_ptr<TaskType> BatchInputTask<TaskType>::GetSplitTask(int split_id) {
241   absl::call_once(once_,
242                   [this]() { split_status_ = SplitBatches(&task_splits_); });
243   if (!split_status_.ok()) {
244     LOG_EVERY_N_SEC(WARNING, 60 /* seconds */)
245         << "Split task with error: " << split_status_ << " split metadata is "
246         << input_split_metadata_.DebugString();
247     return nullptr;
248   }
249   if (split_id >= 0 && split_id < task_splits_.size()) {
250     return std::move(task_splits_[split_id]);
251   }
252   return nullptr;
253 }
254 
255 template <typename TaskType>
SplitBatches(std::vector<std::unique_ptr<TaskType>> * output_tasks)256 Status BatchInputTask<TaskType>::SplitBatches(
257     std::vector<std::unique_ptr<TaskType>>* output_tasks) {
258   return split_func_(&input_task_, open_batch_remaining_slot_,
259                      batch_size_limit_, output_tasks);
260 }
261 
262 }  // namespace internal
263 }  // namespace serving
264 }  // namespace tensorflow
265 
266 #endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_INPUT_TASK_H_
267