xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/batch_kernels.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_BATCH_KERNELS_H_
17 #define TENSORFLOW_CORE_KERNELS_BATCH_KERNELS_H_
18 
#include <cstdint>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/types/optional.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
24 
25 namespace tensorflow {
26 
// Per-model inflight batches parameters.
//
// These are the bounds/defaults used by the adaptive batch scheduler (see
// `AdaptiveBatchSchedulerOptions` below, which initializes its fields from
// them). NOTE(review): presumably defined in the corresponding
// batch_kernels.cc — confirm the concrete values there.
ABSL_CONST_INIT extern const int64_t kMinInflightBatches;
ABSL_CONST_INIT extern const int64_t kInitialInflightBatches;
ABSL_CONST_INIT extern const int64_t kBatchesToAverageOver;
ABSL_CONST_INIT extern const int64_t kMaxInflightBatches;
32 
namespace internal {
// Forward declaration only; this class is befriended by
// `BatchFunctionKernel` below so tests can access its private members.
class BatchFunctionKernelTestAccess;
}  // namespace internal
36 
// `BatchFunctionKernel` is the implementation of op `BatchFunction`.
//
// `BatchFunctionKernel` will batch (tensor) inputs by concatenating them
// along the 0-th dimension, schedule a user-defined computation, and then
// splits the returned tensors as batch output.
//
// In particular, an instance of `BatchFunctionKernel` creates or re-uses a
// batch scheduler instance based on op attributes, pre-processes and enqueues
// concatenated inputs to the scheduler which invokes user-defined function,
// and then splits function output as op output.
//
// User defined function is named by attribute `f` and defined in the graph.
class BatchFunctionKernel : public AsyncOpKernel {
 public:
  // Reads op attributes (batching parameters, the function `f`, and the
  // optional adaptive-scheduler attributes) from `c`.
  explicit BatchFunctionKernel(OpKernelConstruction* c);

  // See OpKernel::IsExpensive(). NOTE(review): the return value is defined in
  // the .cc — not visible from this header.
  bool IsExpensive() override;

  // Enqueues the (to-be-concatenated) inputs on the batch scheduler and
  // invokes `done` once the user-defined function's outputs have been split
  // back into op outputs.
  void ComputeAsync(OpKernelContext* c, DoneCallback done) final;

 private:
  // Gives tests access to the private members of this kernel.
  friend class internal::BatchFunctionKernelTestAccess;

  // Validates 'allowed_batch_sizes_'. The entries must increase monotonically.
  // If large batch split is not enabled, the last one must equal
  // `max_batch_size_`. otherwise the last element must be smaller than or equal
  // to `max_batch_size_`.
  Status ValidateAllowedBatchSizes() const;

  // Creates the function handle if it isn't initialized yet, and re-uses it
  // afterwards.
  Status GetOrCreateFunctionHandle(OpKernelContext* c,
                                   FunctionLibraryRuntime::Handle* handle);

  // Instantiates the user-defined function `func_` and emits `handle`.
  Status InstantiateFunction(OpKernelContext* c,
                             FunctionLibraryRuntime::Handle* handle) const;

  // Initialize vars by reading from op-kernel-construction.
  // Vars
  // - enable_adaptive_batch_threads_
  //   true if value of attribute `kEnableAdaptiveSchedulerAttr` is true, or
  //   if `num_batch_threads` is not positive.
  // - adaptive_batch_scheduler_options_
  //   Read from corresponding attributes as long as they are set.
  void SetAdaptiveBatchSchedulerOptions(OpKernelConstruction* c,
                                        int32_t num_batch_threads);

  // Batching parameters, read from the op's attributes at construction time.
  // NOTE(review): attribute names are presumably the like-named op attrs
  // (`container`, `shared_name`, ...) — confirm against the op registration.
  string container_;
  string shared_name_;
  string batcher_queue_;
  int32 num_batch_threads_;
  int32 max_batch_size_;
  int32 batch_timeout_micros_;
  int32 max_enqueued_batches_;
  // Batch sizes the scheduler may form; see ValidateAllowedBatchSizes() for
  // the invariants enforced on this list.
  std::vector<int32> allowed_batch_sizes_;
  // Name and attrs of the user-defined function `f` run on each batch.
  NameAttrList func_;
  // Lazily-created handle for `func_`; populated by
  // GetOrCreateFunctionHandle() and re-used on subsequent invocations.
  absl::optional<FunctionLibraryRuntime::Handle> fhandle_ TF_GUARDED_BY(mu_);
  // Not owned. NOTE(review): presumably obtained from the construction
  // context and outlives this kernel — confirm in batch_kernels.cc.
  FunctionLibraryRuntime* flib_;
  bool enable_large_batch_splitting_;
  // Whether the `enable_large_batch_splitting` attribute was explicitly
  // present on the op (vs. defaulted) — TODO confirm against the .cc.
  bool has_attribute_enable_large_batch_splitting_;
  // See SetAdaptiveBatchSchedulerOptions() for how this is derived.
  bool enable_adaptive_batch_threads_ = false;

  // Guards `fhandle_`.
  mutex mu_;

  // Parameters for adaptive batch scheduler only.
  // Note 'num_batch_threads_' above is shared by two implementations of batch
  // scheduler.
  struct AdaptiveBatchSchedulerOptions {
    int32 min_in_flight_batches_limit = kMinInflightBatches;
    int32 initial_in_flight_batches_limit = kInitialInflightBatches;
    int32 max_in_flight_batches_limit = kMaxInflightBatches;
    int32 batches_to_average_over = kBatchesToAverageOver;
  };
  // Unset unless the adaptive scheduler attributes are present; see
  // SetAdaptiveBatchSchedulerOptions().
  absl::optional<AdaptiveBatchSchedulerOptions>
      adaptive_batch_scheduler_options_ = absl::nullopt;
};
113 
114 }  // namespace tensorflow
115 
116 #endif  // TENSORFLOW_CORE_KERNELS_BATCH_KERNELS_H_
117