/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_

#include <map>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/cost_measurement_registry.h"
#include "tensorflow/core/common_runtime/request_cost.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/threadsafe_status.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {
namespace serving {
// Base class for a resource that encapsulates the state and logic for batching
// tensors.
class BatchResourceBase : public ResourceBase {
 public:
  // Given a BatchTask (from one op invocation) with 'num_outputs' == M that is
  // split into N subtasks, TensorMatrix is an N x M matrix.
  // Namely, TensorMatrix[i][j] is the i-th split tensor of the j-th output;
  // concatenating the tensors in a column along the 0th dimension yields one
  // output tensor.
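  //
  // Merge sketch (illustrative, not part of this API): output j can be rebuilt
  // from a TensorMatrix with tensorflow::tensor::Concat (tensor_util.h);
  // `matrix`, `num_splits`, and `j` are hypothetical names.
  //
  //   std::vector<Tensor> column;
  //   for (int i = 0; i < num_splits; ++i) column.push_back(matrix[i][j]);
  //   Tensor output_j;
  //   TF_RETURN_IF_ERROR(tensor::Concat(column, &output_j));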
  typedef std::vector<std::vector<Tensor>> TensorMatrix;

  // Ingests data from one invocation of the batch op. The data is enqueued to
  // be combined with others into a batch, asynchronously.
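  //
  // Call-site sketch (hypothetical; `resource`, `guid`, and `batcher_queue_`
  // are assumed to be supplied by a batch kernel). Typically invoked from an
  // AsyncOpKernel::ComputeAsync, e.g.:
  //
  //   OP_REQUIRES_OK_ASYNC(
  //       context,
  //       resource->RegisterInput(guid, context, batcher_queue_, done),
  //       done);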
  Status RegisterInput(int64_t guid, OpKernelContext* context,
                       const string& batcher_queue_name,
                       AsyncOpKernel::DoneCallback done_callback);

 public:
  // One task to be batched; it corresponds to a `slice` of input from one
  // batch-op invocation.
  //
  // Given the input from one batch-op invocation, a `slice` of this input is
  // obtained as follows:
  // 1) each Tensor in `BatchTask::inputs` is split along the 0th dimension;
  // 2) `split_index` identifies this slice's position along the 0th dimension.
  //
  // Note that the full input from one batch-op invocation is itself a valid
  // `slice` (the degenerate case with a single split).
  struct BatchTask : public tensorflow::serving::BatchTask {
    // A unique ID to identify this invocation of Batch.
    int64_t guid;

    Context propagated_context;

    std::vector<Tensor> inputs;
    std::vector<Tensor> captured_inputs;
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done_callback;

    // The index of this split, along the 0th dimension of the input from the
    // op invocation.
    int split_index = 0;

    // Two-dimensional tensor matrix, with ownership shared by:
    // 1) each split of the task (each fills one row of the matrix), and
    // 2) the callback that merges the outputs of the individual splits for an
    //    op invocation, after all splits complete.
    std::shared_ptr<TensorMatrix> output;

    // 'status' records an error (possibly from any split) if at least one
    // split returns an error, and OK otherwise.
    // Ownership is shared by the individual splits and the callback.
    std::shared_ptr<ThreadSafeStatus> status;

    bool is_partial = false;

    uint64 start_time;

    size_t size() const override { return inputs[0].shape().dim_size(0); }

    // Creates a split task from this one. The caller is responsible for
    // setting up the inputs of the new task.
    std::unique_ptr<BatchTask> CreateSplitTask(
        int split_index, AsyncOpKernel::DoneCallback done_callback);

    // RequestCost is for collecting the cost and must outlive the batch
    // processing.
    //
    // For example, to collect the cost of RPC processing, `request_cost` is
    // owned by the RPC handler and points to the RequestCost of the RPC that
    // provides the inputs to this BatchTask.
    //
    // After the batch is processed, the request cost will be incremented with
    // this task's processing costs.
    RequestCost* request_cost = nullptr;

   protected:
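    // Factory for the task instances produced by CreateSplitTask; subclasses
    // override this so that splits carry their derived BatchTask type.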
    virtual std::unique_ptr<BatchTask> CreateDerivedTask() {
      return std::make_unique<BatchTask>();
    }
  };

  // A `T` suffix is appended to make these type aliases distinct from those in
  // the tensorflow::serving namespace, because some compiler versions complain
  // about the meaning of the symbols changing.
  using BatcherT = SharedBatchScheduler<BatchResourceBase::BatchTask>;
  using AdaptiveBatcherT =
      AdaptiveSharedBatchScheduler<BatchResourceBase::BatchTask>;
  using BatcherQueueT = BatchScheduler<BatchResourceBase::BatchTask>;
  using BatchT = Batch<BatchResourceBase::BatchTask>;

  BatchResourceBase(bool has_process_batch_function,
                    std::shared_ptr<BatcherT> batcher,
                    const BatcherT::QueueOptions& batcher_queue_options,
                    std::vector<int32> allowed_batch_sizes)
      : has_process_batch_function_(has_process_batch_function),
        batcher_(std::move(batcher)),
        batcher_queue_options_(batcher_queue_options),
        allowed_batch_sizes_(std::move(allowed_batch_sizes)) {
    allowed_batch_sizes_str_ = absl::StrJoin(allowed_batch_sizes_, ",");
  }
139 
BatchResourceBase(bool has_process_batch_function,std::shared_ptr<AdaptiveBatcherT> batcher,const AdaptiveBatcherT::QueueOptions & batcher_queue_options,std::vector<int32> allowed_batch_sizes)140   BatchResourceBase(bool has_process_batch_function,
141                     std::shared_ptr<AdaptiveBatcherT> batcher,
142                     const AdaptiveBatcherT::QueueOptions& batcher_queue_options,
143                     std::vector<int32> allowed_batch_sizes)
144       : has_process_batch_function_(has_process_batch_function),
145         adaptive_batcher_(std::move(batcher)),
146         adaptive_batcher_queue_options_(batcher_queue_options),
147         allowed_batch_sizes_(std::move(allowed_batch_sizes)) {}

  static BatcherT::QueueOptions GetBatcherQueueOptions(
      int32_t num_batch_threads, int32_t max_batch_size,
      int32_t batch_timeout_micros, int32_t max_enqueued_batches,
      const std::vector<int32>& allowed_batch_sizes,
      bool enable_large_batch_splitting);

  static AdaptiveBatcherT::QueueOptions GetAdaptiveBatcherQueueOptions(
      int32_t max_batch_size, int32_t batch_timeout_micros,
      int32_t max_enqueued_batches, bool enable_large_batch_splitting,
      const std::vector<int32>& allowed_batch_sizes);
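
  // Construction sketch (illustrative; `MyBatchResource` is a hypothetical
  // subclass of this class). A shared scheduler is created with
  // SharedBatchScheduler::Create and paired with options from
  // GetBatcherQueueOptions:
  //
  //   std::shared_ptr<BatchResourceBase::BatcherT> batcher;
  //   TF_RETURN_IF_ERROR(BatchResourceBase::BatcherT::Create(
  //       BatchResourceBase::BatcherT::Options(), &batcher));
  //   auto queue_options = BatchResourceBase::GetBatcherQueueOptions(
  //       /*num_batch_threads=*/4, /*max_batch_size=*/32,
  //       /*batch_timeout_micros=*/1000, /*max_enqueued_batches=*/100,
  //       /*allowed_batch_sizes=*/{8, 16, 32},
  //       /*enable_large_batch_splitting=*/true);
  //   auto resource = std::make_unique<MyBatchResource>(
  //       /*has_process_batch_function=*/true, std::move(batcher),
  //       queue_options, std::vector<int32>{8, 16, 32});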

  // Splits the 'input' of 'input_task_ptr' along the 0th dimension into a list
  // of 'output_tasks'.
  // Task sizes are determined by
  // 1) open_batch_remaining_slot,
  // 2) max_batch_size, and
  // 3) the size of the input task,
  // such that
  // 1) the task sizes add up to the size of the input task, and
  // 2) from left to right the task sizes are
  //    [open_batch_remaining_slot, max_batch_size, max_batch_size, ...,
  //     size-of-input-task - sum-of-previous-elements].
  //
  // REQUIRES:
  // The caller must ensure that the size of the input task is greater than
  // open_batch_remaining_slot.
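  //
  // For example (illustrative numbers): an input task of size 10 with
  // open_batch_remaining_slot == 2 and max_batch_size == 4 is split into
  // tasks of sizes [2, 4, 4].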
  static Status SplitInputTask(
      std::unique_ptr<BatchTask>* input_task_ptr, int open_batch_remaining_slot,
      int max_batch_size,
      std::vector<std::unique_ptr<BatchTask>>* output_tasks);

  // Splits the batch costs among the tasks.
  //
  // Inputs:
  // 1) batch_cost_measurements, which provides the total cost of each type;
  // 2) processed_size, which is the batch size plus the padding amount;
  // 3) batch, which provides the batch size.
  //
  // Outputs:
  // The request_cost of each batch task will be updated. This function splits
  // the batch cost (if it is non-zero) in two ways, so two costs are recorded:
  // 1) smeared cost: the batch cost is split proportionally to each task's
  //    size, and the padding does not share any of the cost;
  // 2) non-smeared cost: the batch cost is split proportionally to the size of
  //    each task or padding; the padding's share is not assigned to any task.
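  //
  // Worked example (illustrative numbers): with a total batch cost of 100, a
  // batch size of 8, and processed_size == 10 (i.e., 2 units of padding), a
  // task of size 4 receives a smeared cost of 100 * 4 / 8 = 50 and a
  // non-smeared cost of 100 * 4 / 10 = 40.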
  static void SplitBatchCosts(
      std::vector<std::unique_ptr<CostMeasurement>>& batch_cost_measurements,
      const int64_t processed_size, BatchT& batch);

 private:
  // Implementation of calling the process batch function.
  virtual void ProcessFuncBatchImpl(
      const BatchResourceBase::BatchTask& last_task,
      absl::Span<const Tensor> inputs, std::vector<Tensor>* combined_outputs,
      std::function<void(const Status&)> done) const = 0;

  // Factory method for creating a BatchTask, overridable by subclasses.
  virtual Status CreateBatchTask(
      OpKernelContext* context,
      std::unique_ptr<BatchResourceBase::BatchTask>* output) const;

  // Validates that it's legal to combine the tasks in 'batch' into a batch.
  // Assumes the batch is non-empty.
  static Status ValidateBatch(const BatchT& batch);

  // Returns the smallest entry in 'allowed_batch_sizes_' that is greater than
  // or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply
  // returns 'batch_size'.
  int RoundToLowestAllowedBatchSize(int batch_size) const;

  Status ConcatInputTensors(const BatchT& batch, OpKernelContext* context,
                            std::vector<Tensor>* concatenated_tensors) const;

  Status SplitOutputTensors(const std::vector<Tensor>& combined_outputs,
                            BatchT* batch) const;

  void ProcessFuncBatch(std::unique_ptr<BatchT> batch) const;

  // Processes a batch of one or more BatchTask entries.
  void ProcessBatch(std::unique_ptr<BatchT> batch) const;

  // Emits an index tensor, which the Unbatch op will use to un-concatenate
  // the tensor and attribute the pieces to the right batch keys. The index
  // tensor contains, for each input: [batch_key, start_offset, end_offset],
  // where start_offset and end_offset represent the range of entries in the
  // concatenated tensors that belong to that input.
  //
  // Emits the result to the output at 'output_index' using 'context'.
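  //
  // For example (illustrative values): for a batch built from two inputs of
  // sizes 3 and 2, the emitted index tensor would be
  //   [[batch_key_0, 0, 3],
  //    [batch_key_1, 3, 5]].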
  static Status EmitIndexTensor(OpKernelContext* context, const BatchT& batch,
                                int output_index);

  // Looks up the batcher queue for 'queue_name'. If it didn't previously
  // exist, creates it.
  Status LookupOrCreateBatcherQueue(const string& queue_name,
                                    BatcherQueueT** queue);

  // True if the user specified a batch processing function for this resource.
  const bool has_process_batch_function_;
  // A batch scheduler, and options for creating queues.
  std::shared_ptr<BatcherT> batcher_;
  BatcherT::QueueOptions batcher_queue_options_;

  // An adaptive batch scheduler, and options for creating queues.
  std::shared_ptr<AdaptiveBatcherT> adaptive_batcher_;
  AdaptiveBatcherT::QueueOptions adaptive_batcher_queue_options_;

  // A collection of batcher queues, keyed on queue name.
  // TODO(olston): Garbage-collect unused queues (perhaps simply remove empty
  // ones, possibly with a time delay; it's okay if they get recreated later).
  mutable mutex batcher_queues_mu_;
  std::map<string, std::unique_ptr<BatcherQueueT>> batcher_queues_
      TF_GUARDED_BY(batcher_queues_mu_);

  std::vector<int32> allowed_batch_sizes_;
  // A concatenated string of <allowed_batch_sizes_>, separated by ",". This is
  // used to record the batching parameters.
  string allowed_batch_sizes_str_;
};
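
// Minimal subclass sketch (hypothetical, for illustration only): a concrete
// resource overrides ProcessFuncBatchImpl, which receives the concatenated
// inputs of a formed batch and signals completion through `done`.
//
//   class MyBatchResource : public BatchResourceBase {
//    public:
//     using BatchResourceBase::BatchResourceBase;
//     std::string DebugString() const override { return "MyBatchResource"; }
//
//    private:
//     void ProcessFuncBatchImpl(
//         const BatchTask& last_task, absl::Span<const Tensor> inputs,
//         std::vector<Tensor>* combined_outputs,
//         std::function<void(const Status&)> done) const override {
//       // Run the underlying computation on `inputs`, fill
//       // `combined_outputs`, then report the final status.
//       done(OkStatus());
//     }
//   };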

}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_