/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCH_KERNELS_H_
#define TENSORFLOW_CORE_KERNELS_BATCH_KERNELS_H_

#include "absl/types/optional.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {

// Per-model inflight batches parameters: bounds and tuning knobs used as the
// defaults of `AdaptiveBatchSchedulerOptions` below. Declared `extern` here;
// defined out-of-line (in the corresponding .cc file).
ABSL_CONST_INIT extern const int64_t kMinInflightBatches;
ABSL_CONST_INIT extern const int64_t kInitialInflightBatches;
ABSL_CONST_INIT extern const int64_t kBatchesToAverageOver;
ABSL_CONST_INIT extern const int64_t kMaxInflightBatches;

namespace internal {
// Forward declaration of the test-access helper; it is befriended by
// `BatchFunctionKernel` so tests can reach its private members.
class BatchFunctionKernelTestAccess;
}

// `BatchFunctionKernel` is the implementation of op `BatchFunction`.
//
// `BatchFunctionKernel` batches (tensor) inputs by concatenating them
// along the 0-th dimension, schedules a user-defined computation, and then
// splits the returned tensors as batch output.
//
// In particular, an instance of `BatchFunctionKernel` creates or re-uses a
// batch scheduler instance based on op attributes, pre-processes and enqueues
// concatenated inputs to the scheduler, which invokes the user-defined
// function, and then splits the function output as op output.
//
// The user-defined function is named by attribute `f` and defined in the
// graph.
class BatchFunctionKernel : public AsyncOpKernel {
 public:
  explicit BatchFunctionKernel(OpKernelConstruction* c);

  // See OpKernel::IsExpensive.
  bool IsExpensive() override;

  // Asynchronous kernel entry point; `done` is the completion callback.
  void ComputeAsync(OpKernelContext* c, DoneCallback done) final;

 private:
  // Lets tests inspect private state without widening the public interface.
  friend class internal::BatchFunctionKernelTestAccess;

  // Validates 'allowed_batch_sizes_'. The entries must increase monotonically.
  // If large batch split is not enabled, the last one must equal
  // `max_batch_size_`; otherwise the last element must be smaller than or
  // equal to `max_batch_size_`.
  Status ValidateAllowedBatchSizes() const;

  // Creates the function handle if it isn't initialized yet, and re-uses it
  // afterwards (cached in `fhandle_`).
  Status GetOrCreateFunctionHandle(OpKernelContext* c,
                                   FunctionLibraryRuntime::Handle* handle);

  // Instantiates the user-defined function and emits `handle`.
  Status InstantiateFunction(OpKernelContext* c,
                             FunctionLibraryRuntime::Handle* handle) const;

  // Initializes vars by reading from op-kernel-construction.
  // Vars
  // - enable_adaptive_batch_threads_
  //   true if value of attribute `kEnableAdaptiveSchedulerAttr` is true, or
  //   if `num_batch_threads` is not positive.
  // - adaptive_batch_scheduler_options_
  //   Read from corresponding attributes as long as they are set.
  void SetAdaptiveBatchSchedulerOptions(OpKernelConstruction* c,
                                        int32_t num_batch_threads);

  // Op attributes; presumably the usual TF resource-sharing pair
  // (container/shared_name) — confirm against the .cc implementation.
  string container_;
  string shared_name_;
  // Name of the batcher queue this kernel enqueues work into.
  string batcher_queue_;
  // Batching parameters, read from op attributes at construction.
  int32 num_batch_threads_;
  int32 max_batch_size_;
  int32 batch_timeout_micros_;
  int32 max_enqueued_batches_;
  std::vector<int32> allowed_batch_sizes_;
  // The user-defined function to run on each batch (op attribute `f`).
  NameAttrList func_;
  // Lazily created and cached by GetOrCreateFunctionHandle; unset until the
  // first use.
  absl::optional<FunctionLibraryRuntime::Handle> fhandle_ TF_GUARDED_BY(mu_);
  FunctionLibraryRuntime* flib_;
  bool enable_large_batch_splitting_;
  bool has_attribute_enable_large_batch_splitting_;
  bool enable_adaptive_batch_threads_ = false;

  // Guards `fhandle_` (see TF_GUARDED_BY above).
  mutex mu_;

  // Parameters for adaptive batch scheduler only.
  // Note 'num_batch_threads_' above is shared by two implementations of batch
  // scheduler.
  struct AdaptiveBatchSchedulerOptions {
    int32 min_in_flight_batches_limit = kMinInflightBatches;
    int32 initial_in_flight_batches_limit = kInitialInflightBatches;
    int32 max_in_flight_batches_limit = kMaxInflightBatches;
    int32 batches_to_average_over = kBatchesToAverageOver;
  };
  // Engaged only when the adaptive scheduler is configured; see
  // SetAdaptiveBatchSchedulerOptions.
  absl::optional<AdaptiveBatchSchedulerOptions>
      adaptive_batch_scheduler_options_ = absl::nullopt;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCH_KERNELS_H_