/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_COMPAT_REQUEST_STATE_H__
#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_COMPAT_REQUEST_STATE_H__

#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/support/pointer_util.h"  // from @tf_runtime

namespace tensorflow {
namespace tfd {

// FallbackResourceArray holds tensors that are computed only once during
// initialization and are read-only afterwards.
class FallbackResourceArray {
 public:
  // Sets `tensor` in the array at `index`. `index` should be dense and
  // duplicate indices are not allowed.
  void SetResource(int index, tensorflow::tfrt_stub::ImmutableTensor tensor);

  // Returns the resource tensor at `index`, wrapped in an AsyncValue.
  tfrt::UnRefCountedAsyncValue<tensorflow::tfrt_stub::FallbackTensor>*
  GetResource(int index) const {
    return resource_async_values_.at(index).get();
  }

  // Returns the resource tensor at `index`.
  const tensorflow::tfrt_stub::FallbackTensor& GetResourceAsFallbackTensor(
      int index) const {
    return resource_async_values_.at(index)->get();
  }

 private:
  // `resources_` holds the ownership of all the resource tensors. Note that
  // there may not be a one-to-one mapping between `resources_` and
  // `resource_async_values_`.
  std::vector<std::unique_ptr<tensorflow::tfrt_stub::ImmutableTensor>>
      resources_;
  // `resource_async_values_` holds the UnRefCountedAsyncValues of the fallback
  // tensors that can be directly used by fallback kernels in the graph.
  std::vector<std::unique_ptr<
      tfrt::UnRefCountedAsyncValue<tensorflow::tfrt_stub::FallbackTensor>>>
      resource_async_values_;
};
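
// Illustrative sketch (not part of this header): initialization code typically
// populates a FallbackResourceArray once, and fallback kernels read the stored
// tensors afterwards. The ImmutableTensor construction shown below is an
// assumption based on fallback_tensor.h and may differ in your version:
//
//   FallbackResourceArray resource_array;
//   tensorflow::Tensor init_value(tensorflow::DT_FLOAT,
//                                 tensorflow::TensorShape({2}));
//   resource_array.SetResource(
//       /*index=*/0,
//       tensorflow::tfrt_stub::ImmutableTensor::Create(init_value));
//
//   // Read-only access afterwards, e.g. from a fallback kernel.
//   const auto& tensor = resource_array.GetResourceAsFallbackTensor(0);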

// Per-request state in kernel fallback compat mode.
class KernelFallbackCompatRequestState {
 public:
  // NOTE: This is the constructor for training.
  KernelFallbackCompatRequestState(
      std::function<void(std::function<void()>)>* runner,
      const tensorflow::DeviceMgr* device_manager, int64_t step_id,
      tfrt::OwnedOrUnownedPtr<ScopedStepContainer> step_container,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      core::RefCountPtr<Rendezvous> rendezvous,
      tfrt_stub::OpKernelRunnerTable* runner_table,
      FallbackResourceArray* resource_array,
      tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
      const absl::optional<SessionMetadata>& model_metadata,
      const tensorflow::ProcessFunctionLibraryRuntime* pflr);

  // NOTE: This is the constructor for inference.
  KernelFallbackCompatRequestState(
      std::function<void(std::function<void()>)>* runner,
      const tensorflow::DeviceMgr* device_manager, int64_t step_id,
      tfrt_stub::OpKernelRunnerTable* runner_table,
      FallbackResourceArray* resource_array,
      tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
      const absl::optional<SessionMetadata>& model_metadata,
      const tensorflow::ProcessFunctionLibraryRuntime* pflr);

  // Returns the user-specified custom device corresponding to the given
  // device. It is currently only used for configuring the per-request intra-op
  // threadpool.
  tensorflow::Device* custom_device(const tensorflow::Device* device) const {
    auto it = custom_device_.find(device);
    if (it == custom_device_.end()) return nullptr;
    return it->second.get();
  }

  ScopedStepContainer* step_container() const { return step_container_.get(); }

  const tensorflow::DeviceMgr& device_manager() const {
    return *device_manager_;
  }

  const tensorflow::ProcessFunctionLibraryRuntime&
  process_function_library_runtime() const {
    return *pflr_;
  }

  CollectiveExecutor* collective_executor() const {
    return collective_executor_;
  }

  tfrt_stub::OpKernelRunnerTable* runner_table() const {
    return runner_table_;
  }

  FallbackResourceArray* resource_array() const { return resource_array_; }

  std::function<void(std::function<void()>)>* runner() const {
    return runner_;
  }

  CancellationManager* cancellation_manager() const {
    return default_cancellation_manager_;
  }

  RendezvousInterface* rendezvous() const { return rendezvous_.get(); }

  void set_log_device_placement(bool log) { log_device_placement_ = log; }
  bool log_device_placement() const { return log_device_placement_; }

  tensorflow::thread::ThreadPoolInterface* intra_op_threadpool() const {
    return intra_op_threadpool_;
  }

  const SessionMetadata& session_metadata() const { return session_metadata_; }
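
  // Illustrative sketch (not part of this class): a fallback kernel
  // implementation typically receives a pointer to this request state and uses
  // the accessors above to populate tensorflow::OpKernelContext::Params before
  // running a prepopulated kernel, roughly as follows. The exact wiring is an
  // assumption; see the fallback kernel utilities for the actual setup.
  //
  //   tensorflow::OpKernelContext::Params params;
  //   params.step_container = state.step_container();
  //   params.rendezvous = state.rendezvous();
  //   params.cancellation_manager = state.cancellation_manager();
  //   params.collective_executor = state.collective_executor();
  //   params.runner = state.runner();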

 private:
  // Below are the resources needed by the current TensorFlow runtime.
  std::function<void(std::function<void()>)>* runner_ = nullptr;
  ::tfrt::OwnedOrUnownedPtr<ScopedStepContainer> step_container_;
  absl::flat_hash_map<const tensorflow::Device*,
                      std::unique_ptr<tensorflow::Device>>
      custom_device_;
  std::unique_ptr<CollectiveExecutor::Handle> collective_executor_handle_;
  CollectiveExecutor* collective_executor_ = nullptr;
  core::RefCountPtr<Rendezvous> rendezvous_;
  CancellationManager* default_cancellation_manager_ = nullptr;

  const tensorflow::DeviceMgr* device_manager_ = nullptr;

  // `runner_table_` holds the prepopulated tensorflow::OpKernel instances for
  // kernel fallback compat mode.
  tfrt_stub::OpKernelRunnerTable* runner_table_ = nullptr;

  // The resource array is used for keeping static values in the runtime. It is
  // accessed through the tfrt_fallback_async.set_resource and
  // tfrt_fallback_async.get_resource kernels.
  FallbackResourceArray* resource_array_ = nullptr;

  tensorflow::thread::ThreadPoolInterface* intra_op_threadpool_ = nullptr;

  // Model metadata used for monitoring and tracing purposes.
  SessionMetadata session_metadata_;

  const tensorflow::ProcessFunctionLibraryRuntime* pflr_ = nullptr;

  bool log_device_placement_ = false;
};

}  // namespace tfd
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_COMPAT_REQUEST_STATE_H__