1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_TFRT_FALLBACK_OP_KERNEL_RUNNER_H_ 16 #define TENSORFLOW_CORE_TFRT_FALLBACK_OP_KERNEL_RUNNER_H_ 17 18 #include <assert.h> 19 #include <stddef.h> 20 21 #include <memory> 22 #include <string> 23 #include <utility> 24 25 #include "absl/container/inlined_vector.h" 26 #include "tensorflow/core/common_runtime/device_mgr.h" 27 #include "tensorflow/core/common_runtime/process_function_library_runtime.h" 28 #include "tensorflow/core/framework/device.h" 29 #include "tensorflow/core/framework/node_def.pb.h" 30 #include "tensorflow/core/framework/op_def.pb.h" 31 #include "tensorflow/core/framework/op_kernel.h" 32 #include "tensorflow/core/framework/tensor.h" 33 #include "tensorflow/core/framework/types.h" 34 #include "tensorflow/core/platform/errors.h" 35 #include "tensorflow/core/platform/status.h" 36 37 namespace tensorflow { 38 namespace tfrt_stub { 39 40 class OpKernelRunner { 41 public: 42 static StatusOr<OpKernelRunner> Create( 43 absl::string_view op_name, absl::string_view device_name, int num_args, 44 const std::function<Status(tensorflow::AttrValueMap*)>& attr_builder, 45 const tensorflow::DeviceMgr& device_manager, 46 const tensorflow::ProcessFunctionLibraryRuntime& 47 process_function_library_runtime); 48 49 static StatusOr<OpKernelRunner> Create( 50 absl::string_view op_name, int num_args, 51 const std::function<Status(tensorflow::AttrValueMap*)>& attr_builder, 52 const tensorflow::ProcessFunctionLibraryRuntime& 53 process_function_library_runtime, 54 tensorflow::Device* device); 55 56 OpKernelRunner() = default; 57 58 explicit operator bool() const { return op_kernel_ != nullptr; } 59 60 void Run(OpKernelContext* context) const; 61 62 void RunAsync(OpKernelContext* context, 63 AsyncOpKernel::DoneCallback done_callback) const; 64 IsAsync()65 bool IsAsync() const { return is_async_; } 66 op_kernel()67 tensorflow::OpKernel* op_kernel() const { return op_kernel_.get(); } device()68 tensorflow::Device* device() const { return device_; } function_library_runtime()69 tensorflow::FunctionLibraryRuntime* function_library_runtime() const { 70 return function_library_runtime_; 71 } resource_manager()72 tensorflow::ResourceMgr* resource_manager() const { 73 return resource_manager_; 74 } 75 input_alloc_attrs()76 const gtl::InlinedVector<AllocatorAttributes, 4>& input_alloc_attrs() const { 77 return input_alloc_attrs_; 78 } output_alloc_attrs()79 const gtl::InlinedVector<AllocatorAttributes, 1>& output_alloc_attrs() const { 80 return output_alloc_attrs_; 81 } 82 83 private: 84 explicit OpKernelRunner( 85 tensorflow::Device* device, 86 tensorflow::FunctionLibraryRuntime* function_library_runtime, 87 std::unique_ptr<OpKernel> op_kernel); 88 89 tensorflow::Device* device_ = nullptr; 90 tensorflow::FunctionLibraryRuntime* function_library_runtime_ = nullptr; 91 tensorflow::ResourceMgr* resource_manager_ = nullptr; 92 std::unique_ptr<OpKernel> op_kernel_; 93 bool is_async_ = false; 94 gtl::InlinedVector<AllocatorAttributes, 4> input_alloc_attrs_; 95 gtl::InlinedVector<AllocatorAttributes, 1> output_alloc_attrs_; 96 }; 97 98 // OpKernelRunState keeps the states needed for per-kernel execution. 99 struct OpKernelRunState { 100 gtl::InlinedVector<tensorflow::Tensor, 4> input_tf_tensors; 101 gtl::InlinedVector<tensorflow::TensorValue, 4> input_tf_tensor_values; 102 OpKernelContext::Params params; 103 104 OpKernelRunState() = default; OpKernelRunStateOpKernelRunState105 OpKernelRunState( 106 const gtl::InlinedVector<tensorflow::TensorValue, 4>& tensor_values, 107 const OpKernelContext::Params& p) { 108 // `input_tf_tensor_values` contains the reference to all tensor used, 109 // while `input_tf_tensors` only contains those needs ownership so their 110 // sizes may not match. For this copy assignment, we conservatively copy all 111 // tensors. 112 input_tf_tensors.reserve(tensor_values.size()); 113 for (const auto& tensor_value : tensor_values) { 114 input_tf_tensors.push_back(*tensor_value.tensor); 115 } 116 for (auto& tensor : input_tf_tensors) { 117 input_tf_tensor_values.emplace_back(&tensor); 118 } 119 120 // Since `input_tf_tensor_values` and `params` contains pointers to 121 // `input_tf_tensors`, we need to change those pointers to the correct ones 122 // after copying. 123 params = p; 124 params.inputs = input_tf_tensor_values; 125 // Clear eigen_gpu_device to ensure OpKernelContext constructor will make a 126 // new eigen GPU device. 127 params.eigen_gpu_device = nullptr; 128 } 129 130 OpKernelRunState(const OpKernelRunState& other) = delete; 131 OpKernelRunState& operator=(const OpKernelRunState& other) = delete; 132 133 ~OpKernelRunState() = default; 134 }; 135 136 // OpKernelRunnerTable for keeping OpKernelRunner instances to avoid expensive 137 // reinstantiation of OpKernel and other repeated setup per kernel execution. 138 // OpKernelRunnerTable is thread-compatible. 139 class OpKernelRunnerTable { 140 public: 141 OpKernelRunnerTable() = default; 142 143 // Return true if it successfully inserts `runner`. `index` is supposed to be 144 // dense. Insert(int64_t index,OpKernelRunner runner)145 bool Insert(int64_t index, OpKernelRunner runner) { 146 if (runners_.size() <= index) runners_.resize(index + 1); 147 if (runners_[index].has_value()) return false; 148 runners_[index] = std::move(runner); 149 return true; 150 } 151 152 // Return the OpKernelRunner at the corresponding `index` in the table. The 153 // result can never be nullptr. It is a fatal error to use an index that is 154 // not in the table. Note that the returned pointer will be invalidated if 155 // Insert() is called. Get(int64_t index)156 const OpKernelRunner* Get(int64_t index) const { 157 // Out of bounds vector access will throw an exception and anyway will crash 158 // the binary, prefer a more readable error message. 159 CHECK_GT(runners_.size(), index) // Crash OK 160 << "runner index is out of bounds: index=" << index 161 << " size=" << runners_.size(); 162 auto& result = runners_.at(index); 163 CHECK(result.has_value()) // Crash OK 164 << "runner is not available: index=" << index; 165 return &(*result); 166 } 167 168 private: 169 std::vector<absl::optional<OpKernelRunner>> runners_; 170 }; 171 172 } // namespace tfrt_stub 173 } // namespace tensorflow 174 175 #endif // TENSORFLOW_CORE_TFRT_FALLBACK_OP_KERNEL_RUNNER_H_ 176