xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tfrt/fallback/op_kernel_runner.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_FALLBACK_OP_KERNEL_RUNNER_H_
16 #define TENSORFLOW_CORE_TFRT_FALLBACK_OP_KERNEL_RUNNER_H_
17 
18 #include <assert.h>
19 #include <stddef.h>
20 
21 #include <memory>
22 #include <string>
23 #include <utility>
24 
25 #include "absl/container/inlined_vector.h"
26 #include "tensorflow/core/common_runtime/device_mgr.h"
27 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
28 #include "tensorflow/core/framework/device.h"
29 #include "tensorflow/core/framework/node_def.pb.h"
30 #include "tensorflow/core/framework/op_def.pb.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/framework/tensor.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/platform/status.h"
36 
37 namespace tensorflow {
38 namespace tfrt_stub {
39 
40 class OpKernelRunner {
41  public:
42   static StatusOr<OpKernelRunner> Create(
43       absl::string_view op_name, absl::string_view device_name, int num_args,
44       const std::function<Status(tensorflow::AttrValueMap*)>& attr_builder,
45       const tensorflow::DeviceMgr& device_manager,
46       const tensorflow::ProcessFunctionLibraryRuntime&
47           process_function_library_runtime);
48 
49   static StatusOr<OpKernelRunner> Create(
50       absl::string_view op_name, int num_args,
51       const std::function<Status(tensorflow::AttrValueMap*)>& attr_builder,
52       const tensorflow::ProcessFunctionLibraryRuntime&
53           process_function_library_runtime,
54       tensorflow::Device* device);
55 
56   OpKernelRunner() = default;
57 
58   explicit operator bool() const { return op_kernel_ != nullptr; }
59 
60   void Run(OpKernelContext* context) const;
61 
62   void RunAsync(OpKernelContext* context,
63                 AsyncOpKernel::DoneCallback done_callback) const;
64 
IsAsync()65   bool IsAsync() const { return is_async_; }
66 
op_kernel()67   tensorflow::OpKernel* op_kernel() const { return op_kernel_.get(); }
device()68   tensorflow::Device* device() const { return device_; }
function_library_runtime()69   tensorflow::FunctionLibraryRuntime* function_library_runtime() const {
70     return function_library_runtime_;
71   }
resource_manager()72   tensorflow::ResourceMgr* resource_manager() const {
73     return resource_manager_;
74   }
75 
input_alloc_attrs()76   const gtl::InlinedVector<AllocatorAttributes, 4>& input_alloc_attrs() const {
77     return input_alloc_attrs_;
78   }
output_alloc_attrs()79   const gtl::InlinedVector<AllocatorAttributes, 1>& output_alloc_attrs() const {
80     return output_alloc_attrs_;
81   }
82 
83  private:
84   explicit OpKernelRunner(
85       tensorflow::Device* device,
86       tensorflow::FunctionLibraryRuntime* function_library_runtime,
87       std::unique_ptr<OpKernel> op_kernel);
88 
89   tensorflow::Device* device_ = nullptr;
90   tensorflow::FunctionLibraryRuntime* function_library_runtime_ = nullptr;
91   tensorflow::ResourceMgr* resource_manager_ = nullptr;
92   std::unique_ptr<OpKernel> op_kernel_;
93   bool is_async_ = false;
94   gtl::InlinedVector<AllocatorAttributes, 4> input_alloc_attrs_;
95   gtl::InlinedVector<AllocatorAttributes, 1> output_alloc_attrs_;
96 };
97 
98 // OpKernelRunState keeps the states needed for per-kernel execution.
99 struct OpKernelRunState {
100   gtl::InlinedVector<tensorflow::Tensor, 4> input_tf_tensors;
101   gtl::InlinedVector<tensorflow::TensorValue, 4> input_tf_tensor_values;
102   OpKernelContext::Params params;
103 
104   OpKernelRunState() = default;
OpKernelRunStateOpKernelRunState105   OpKernelRunState(
106       const gtl::InlinedVector<tensorflow::TensorValue, 4>& tensor_values,
107       const OpKernelContext::Params& p) {
108     // `input_tf_tensor_values` contains the reference to all tensor used,
109     // while `input_tf_tensors` only contains those needs ownership so their
110     // sizes may not match. For this copy assignment, we conservatively copy all
111     // tensors.
112     input_tf_tensors.reserve(tensor_values.size());
113     for (const auto& tensor_value : tensor_values) {
114       input_tf_tensors.push_back(*tensor_value.tensor);
115     }
116     for (auto& tensor : input_tf_tensors) {
117       input_tf_tensor_values.emplace_back(&tensor);
118     }
119 
120     // Since `input_tf_tensor_values` and `params` contains pointers to
121     // `input_tf_tensors`, we need to change those pointers to the correct ones
122     // after copying.
123     params = p;
124     params.inputs = input_tf_tensor_values;
125     // Clear eigen_gpu_device to ensure OpKernelContext constructor will make a
126     // new eigen GPU device.
127     params.eigen_gpu_device = nullptr;
128   }
129 
130   OpKernelRunState(const OpKernelRunState& other) = delete;
131   OpKernelRunState& operator=(const OpKernelRunState& other) = delete;
132 
133   ~OpKernelRunState() = default;
134 };
135 
136 // OpKernelRunnerTable for keeping OpKernelRunner instances to avoid expensive
137 // reinstantiation of OpKernel and other repeated setup per kernel execution.
138 // OpKernelRunnerTable is thread-compatible.
139 class OpKernelRunnerTable {
140  public:
141   OpKernelRunnerTable() = default;
142 
143   // Return true if it successfully inserts `runner`. `index` is supposed to be
144   // dense.
Insert(int64_t index,OpKernelRunner runner)145   bool Insert(int64_t index, OpKernelRunner runner) {
146     if (runners_.size() <= index) runners_.resize(index + 1);
147     if (runners_[index].has_value()) return false;
148     runners_[index] = std::move(runner);
149     return true;
150   }
151 
152   // Return the OpKernelRunner at the corresponding `index` in the table. The
153   // result can never be nullptr. It is a fatal error to use an index that is
154   // not in the table. Note that the returned pointer will be invalidated if
155   // Insert() is called.
Get(int64_t index)156   const OpKernelRunner* Get(int64_t index) const {
157     // Out of bounds vector access will throw an exception and anyway will crash
158     // the binary, prefer a more readable error message.
159     CHECK_GT(runners_.size(), index)  // Crash OK
160         << "runner index is out of bounds: index=" << index
161         << " size=" << runners_.size();
162     auto& result = runners_.at(index);
163     CHECK(result.has_value())  // Crash OK
164         << "runner is not available: index=" << index;
165     return &(*result);
166   }
167 
168  private:
169   std::vector<absl::optional<OpKernelRunner>> runners_;
170 };
171 
172 }  // namespace tfrt_stub
173 }  // namespace tensorflow
174 
175 #endif  // TENSORFLOW_CORE_TFRT_FALLBACK_OP_KERNEL_RUNNER_H_
176