/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_

#include <map>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/cost_measurement_registry.h"
#include "tensorflow/core/common_runtime/request_cost.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/threadsafe_status.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {
namespace serving {

// Base class for a resource that encapsulates the state and logic for
// batching tensors.
class BatchResourceBase : public ResourceBase {
 public:
  // Given a BatchTask (from one op invocation) with 'num_outputs' == M that
  // is split into N subtasks, TensorMatrix is an N x M matrix.
  // Namely, TensorMatrix[i][j] is the i-th split of the j-th output;
  // concatenating the tensors in the j-th column across all N splits gives
  // the j-th output tensor.
  typedef std::vector<std::vector<Tensor>> TensorMatrix;
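
  // To make the layout concrete, a small illustrative example (example
  // values only, not part of the API): with N == 2 splits and M == 2
  // outputs,
  //   TensorMatrix[0] = {split 0 of output 0, split 0 of output 1}
  //   TensorMatrix[1] = {split 1 of output 0, split 1 of output 1}
  // and the j-th output tensor is recovered by concatenating
  // TensorMatrix[0][j] and TensorMatrix[1][j] along the 0th dimension.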

  // Ingests data from one invocation of the batch op. The data is enqueued
  // to be combined with others into a batch, asynchronously.
  Status RegisterInput(int64_t guid, OpKernelContext* context,
                       const string& batcher_queue_name,
                       AsyncOpKernel::DoneCallback done_callback);

 public:
  // One task to be batched; corresponds to a `slice` of input from one
  // batch-op invocation.
  //
  // Given input from one batch-op invocation, a `slice` of this input is
  // obtained by:
  // 1) splitting each Tensor in `BatchTask::inputs` along the 0th dimension,
  //    and
  // 2) computing 'split_index' along the 0th dimension.
  //
  // Note that the whole input from one batch-op invocation is valid and is
  // considered a specialized `slice`.
  struct BatchTask : public tensorflow::serving::BatchTask {
    // A unique ID to identify this invocation of Batch.
    int64_t guid;

    Context propagated_context;

    std::vector<Tensor> inputs;
    std::vector<Tensor> captured_inputs;
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done_callback;

    // The index of this split, along the 0th dimension of the input from the
    // op invocation.
    int split_index = 0;

    // Two-dimensional tensor matrix, with ownership shared by:
    // 1) each split of the task (to fill one row in this matrix), and
    // 2) the callback that runs to merge the outputs of individual splits
    //    for an op invocation, after all splits complete.
    std::shared_ptr<TensorMatrix> output;

    // 'status' records an error (which could be from any split) if at least
    // one split returns an error, and OK otherwise.
    // Ownership is shared by the individual splits and the callback.
    std::shared_ptr<ThreadSafeStatus> status;

    bool is_partial = false;

    uint64 start_time;

    size_t size() const override { return inputs[0].shape().dim_size(0); }

    // Creates a split task from this one. The caller needs to set up the
    // inputs of the new task.
    std::unique_ptr<BatchTask> CreateSplitTask(
        int split_index, AsyncOpKernel::DoneCallback done_callback);

    // RequestCost is for collecting the cost and must outlive the batch
    // processing.
    //
    // For example, to collect the cost of rpc processing, `request_cost` is
    // owned by the rpc handler and points to the RequestCost of the rpc that
    // provides the inputs to this BatchTask.
    //
    // After the batch processing, the request cost will be incremented with
    // this task's processing costs.
    RequestCost* request_cost = nullptr;

   protected:
    virtual std::unique_ptr<BatchTask> CreateDerivedTask() {
      return std::make_unique<BatchTask>();
    }
  };

  // A 'T' suffix is appended to make these type aliases distinct from the
  // ones in the tensorflow::serving namespace, because some compiler
  // versions complain about the meaning of the symbols changing.
  using BatcherT = SharedBatchScheduler<BatchResourceBase::BatchTask>;
  using AdaptiveBatcherT =
      AdaptiveSharedBatchScheduler<BatchResourceBase::BatchTask>;
  using BatcherQueueT = BatchScheduler<BatchResourceBase::BatchTask>;
  using BatchT = Batch<BatchResourceBase::BatchTask>;

  BatchResourceBase(bool has_process_batch_function,
                    std::shared_ptr<BatcherT> batcher,
                    const BatcherT::QueueOptions& batcher_queue_options,
                    std::vector<int32> allowed_batch_sizes)
      : has_process_batch_function_(has_process_batch_function),
        batcher_(std::move(batcher)),
        batcher_queue_options_(batcher_queue_options),
        allowed_batch_sizes_(std::move(allowed_batch_sizes)) {
    allowed_batch_sizes_str_ = absl::StrJoin(allowed_batch_sizes_, ",");
  }

  BatchResourceBase(bool has_process_batch_function,
                    std::shared_ptr<AdaptiveBatcherT> batcher,
                    const AdaptiveBatcherT::QueueOptions& batcher_queue_options,
                    std::vector<int32> allowed_batch_sizes)
      : has_process_batch_function_(has_process_batch_function),
        adaptive_batcher_(std::move(batcher)),
        adaptive_batcher_queue_options_(batcher_queue_options),
        allowed_batch_sizes_(std::move(allowed_batch_sizes)) {}

  static BatcherT::QueueOptions GetBatcherQueueOptions(
      int32_t num_batch_threads, int32_t max_batch_size,
      int32_t batch_timeout_micros, int32_t max_enqueued_batches,
      const std::vector<int32>& allowed_batch_sizes,
      bool enable_large_batch_splitting);

  static AdaptiveBatcherT::QueueOptions GetAdaptiveBatcherQueueOptions(
      int32_t max_batch_size, int32_t batch_timeout_micros,
      int32_t max_enqueued_batches, bool enable_large_batch_splitting,
      const std::vector<int32>& allowed_batch_sizes);

  // Splits the 'input' of 'input_task_ptr' along the 0th dimension into a
  // list of 'output_tasks'.
  // Task sizes are determined by
  // 1) open_batch_remaining_slot,
  // 2) max_batch_size, and
  // 3) the size of the input task,
  // such that
  // 1) the task sizes add up to the size of the input task, and
  // 2) from left to right, the task sizes are
  //    [open_batch_remaining_slot, max_batch_size, max_batch_size, ...,
  //     size-of-input-task - sum-of-previous-elements].
  //
  // REQUIRES:
  // The caller should make sure the size of the input task is greater than
  // open_batch_remaining_slot.
  static Status SplitInputTask(
      std::unique_ptr<BatchTask>* input_task_ptr,
      int open_batch_remaining_slot, int max_batch_size,
      std::vector<std::unique_ptr<BatchTask>>* output_tasks);
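
  // A small worked example (illustrative numbers only): splitting an input
  // task of size 10 with open_batch_remaining_slot == 3 and
  // max_batch_size == 4 produces output tasks of sizes [3, 4, 3]: the first
  // task fills the remaining slots of the open batch, subsequent tasks take
  // max_batch_size each, and the last task takes the remainder.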

  // Splits the batch costs among the tasks.
  //
  // Inputs:
  // 1) batch_cost_measurements, which provides the total cost of each type;
  // 2) processed_size, the batch size plus the padding amount;
  // 3) batch, which provides the batch size.
  //
  // Outputs:
  // The request_cost of each batch task will be updated. This function
  // splits the batch cost (if it is non-zero) using two approaches, so two
  // costs are output:
  // 1) smeared cost: the batch cost is split proportionally to each task's
  //    size, and the padding does not share any cost;
  // 2) non-smeared cost: the batch cost is split proportionally to the size
  //    of each task or padding; the padding's share of the cost is not
  //    assigned to any task.
  static void SplitBatchCosts(
      std::vector<std::unique_ptr<CostMeasurement>>& batch_cost_measurements,
      const int64_t processed_size, BatchT& batch);
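
  // A small worked example (illustrative numbers only): for a total batch
  // cost of 16ms, two tasks of size 6 each, and processed_size == 16 (i.e.,
  // 4 units of padding):
  //   smeared cost:     each task is charged 16ms * 6/12 = 8ms, so the whole
  //                     16ms is attributed to the tasks;
  //   non-smeared cost: each task is charged 16ms * 6/16 = 6ms, and the
  //                     padding's 16ms * 4/16 = 4ms is left unassigned.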

 private:
  // Implementation of calling the process-batch function.
  virtual void ProcessFuncBatchImpl(
      const BatchResourceBase::BatchTask& last_task,
      absl::Span<const Tensor> inputs, std::vector<Tensor>* combined_outputs,
      std::function<void(const Status&)> done) const = 0;

  // Factory method for creating a BatchTask, overridable by subclasses.
  virtual Status CreateBatchTask(
      OpKernelContext* context,
      std::unique_ptr<BatchResourceBase::BatchTask>* output) const;

  // Validates that it's legal to combine the tasks in 'batch' into a batch.
  // Assumes the batch is non-empty.
  static Status ValidateBatch(const BatchT& batch);

  // Returns the smallest entry in 'allowed_batch_sizes_' that is greater
  // than or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty,
  // simply returns 'batch_size'.
  int RoundToLowestAllowedBatchSize(int batch_size) const;

  Status ConcatInputTensors(const BatchT& batch, OpKernelContext* context,
                            std::vector<Tensor>* concatenated_tensors) const;

  Status SplitOutputTensors(const std::vector<Tensor>& combined_outputs,
                            BatchT* batch) const;

  void ProcessFuncBatch(std::unique_ptr<BatchT> batch) const;

  // Processes a batch of one or more BatchTask entries.
  void ProcessBatch(std::unique_ptr<BatchT> batch) const;

  // Emits an index tensor, which the Unbatch op will use to un-concatenate
  // the tensor and attribute the pieces to the right batch keys. The index
  // tensor contains, for each input, [batch_key, start_offset, end_offset],
  // where start_offset and end_offset denote the range of entries in the
  // concatenated tensors that belong to that input.
  //
  // Emits the result to the output at 'output_index' using 'context'.
  static Status EmitIndexTensor(OpKernelContext* context, const BatchT& batch,
                                int output_index);

  // Looks up the batcher queue for 'queue_name'. If it didn't previously
  // exist, creates it.
  Status LookupOrCreateBatcherQueue(const string& queue_name,
                                    BatcherQueueT** queue);

  // True if the user specified a batch processing function for this
  // resource.
  const bool has_process_batch_function_;

  // A batch scheduler, and options for creating queues.
  std::shared_ptr<BatcherT> batcher_;
  BatcherT::QueueOptions batcher_queue_options_;

  // An adaptive batch scheduler, and options for creating queues.
  std::shared_ptr<AdaptiveBatcherT> adaptive_batcher_;
  AdaptiveBatcherT::QueueOptions adaptive_batcher_queue_options_;

  // A collection of batcher queues, keyed on queue name.
  // TODO(olston): Garbage-collect unused queues (perhaps simply remove empty
  // ones (with a time delay?); it's okay if they get recreated later).
  mutable mutex batcher_queues_mu_;
  std::map<string, std::unique_ptr<BatcherQueueT>> batcher_queues_
      TF_GUARDED_BY(batcher_queues_mu_);

  std::vector<int32> allowed_batch_sizes_;
  // A concatenated string of <allowed_batch_sizes_>, separated by ",". This
  // is used to record the batching parameter.
  string allowed_batch_sizes_str_;
};

}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_