/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_COMPAT_REQUEST_STATE_H__
#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_COMPAT_REQUEST_STATE_H__

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/support/pointer_util.h"  // from @tf_runtime

namespace tensorflow {
namespace tfd {

37 // FallbackResourceArray holds the tensors that are computed only once during
38 // initialization and read-only afterwards.
39 class FallbackResourceArray {
40  public:
41   // Sets `tensor` in the array at `index`. `index` should be dense and
42   // duplicate indices are not allowed.
43   void SetResource(int index, tensorflow::tfrt_stub::ImmutableTensor tensor);
44 
45   // Returns the resource tensor wrapped in AsyncValue value at `index`.
46   tfrt::UnRefCountedAsyncValue<tensorflow::tfrt_stub::FallbackTensor>*
GetResource(int index)47   GetResource(int index) const {
48     return resource_async_values_.at(index).get();
49   }
50 
51   // Returns the resource tensor at `index`.
GetResourceAsFallbackTensor(int index)52   const tensorflow::tfrt_stub::FallbackTensor& GetResourceAsFallbackTensor(
53       int index) const {
54     return resource_async_values_.at(index)->get();
55   }
56 
57  private:
58   // `resources_` holds the ownership of all the resource tensors. Note that it
59   // may not be a one-to-one mapping between `resources_` and
60   // `resource_async_values_`.
61   std::vector<std::unique_ptr<tensorflow::tfrt_stub::ImmutableTensor>>
62       resources_;
63   // `resource_async_values_` holds the UnRefCountedAsyncValue of the fallback
64   // tensors that can be directly used by fallback kernels in the graph.
65   std::vector<std::unique_ptr<
66       tfrt::UnRefCountedAsyncValue<tensorflow::tfrt_stub::FallbackTensor>>>
67       resource_async_values_;
68 };
69 
70 // Per-request state in kernel falllback compat mode.
71 class KernelFallbackCompatRequestState {
72  public:
73   // NOTE: This is the constructor for training.
74   KernelFallbackCompatRequestState(
75       std::function<void(std::function<void()>)>* runner,
76       const tensorflow::DeviceMgr* device_manager, int64_t step_id,
77       tfrt::OwnedOrUnownedPtr<ScopedStepContainer> step_container,
78       std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
79       core::RefCountPtr<Rendezvous> rendezvous,
80       tfrt_stub::OpKernelRunnerTable* runner_table,
81       FallbackResourceArray* resource_array,
82       tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
83       const absl::optional<SessionMetadata>& model_metadata,
84       const tensorflow::ProcessFunctionLibraryRuntime* pflr);
85 
86   // NOTE: This is the constructor for inference.
87   KernelFallbackCompatRequestState(
88       std::function<void(std::function<void()>)>* runner,
89       const tensorflow::DeviceMgr* device_manager, int64_t step_id,
90       tfrt_stub::OpKernelRunnerTable* runner_table,
91       FallbackResourceArray* resource_array,
92       tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
93       const absl::optional<SessionMetadata>& model_metadata,
94       const tensorflow::ProcessFunctionLibraryRuntime* pflr);
95 
96   // Returns the user-specified custom device corresponding to the given device.
97   // It is currently only used for configure per-request intra op threadpool.
custom_device(const tensorflow::Device * device)98   tensorflow::Device* custom_device(const tensorflow::Device* device) const {
99     auto it = custom_device_.find(device);
100     if (it == custom_device_.end()) return nullptr;
101     return it->second.get();
102   }
103 
step_container()104   ScopedStepContainer* step_container() const { return step_container_.get(); }
105 
device_manager()106   const tensorflow::DeviceMgr& device_manager() const {
107     return *device_manager_;
108   }
109 
110   const tensorflow::ProcessFunctionLibraryRuntime&
process_function_library_runtime()111   process_function_library_runtime() const {
112     return *pflr_;
113   }
114 
collective_executor()115   CollectiveExecutor* collective_executor() const {
116     return collective_executor_;
117   }
118 
runner_table()119   tfrt_stub::OpKernelRunnerTable* runner_table() const { return runner_table_; }
120 
resource_array()121   FallbackResourceArray* resource_array() const { return resource_array_; }
122 
runner()123   std::function<void(std::function<void()>)>* runner() const { return runner_; }
124 
cancellation_manager()125   CancellationManager* cancellation_manager() const {
126     return default_cancellation_manager_;
127   }
128 
rendezvous()129   RendezvousInterface* rendezvous() const { return rendezvous_.get(); }
130 
set_log_device_placement(bool log)131   void set_log_device_placement(bool log) { log_device_placement_ = log; }
log_device_placement()132   bool log_device_placement() const { return log_device_placement_; }
133 
intra_op_threadpool()134   tensorflow::thread::ThreadPoolInterface* intra_op_threadpool() const {
135     return intra_op_threadpool_;
136   }
137 
session_metadata()138   const SessionMetadata& session_metadata() const { return session_metadata_; }
139 
140  private:
141   // Below are resources needed by current tensorflow.
142   std::function<void(std::function<void()>)>* runner_ = nullptr;
143   ::tfrt::OwnedOrUnownedPtr<ScopedStepContainer> step_container_;
144   absl::flat_hash_map<const tensorflow::Device*,
145                       std::unique_ptr<tensorflow::Device>>
146       custom_device_;
147   std::unique_ptr<CollectiveExecutor::Handle> collective_executor_handle_;
148   CollectiveExecutor* collective_executor_ = nullptr;
149   core::RefCountPtr<Rendezvous> rendezvous_;
150   CancellationManager* default_cancellation_manager_ = nullptr;
151 
152   const tensorflow::DeviceMgr* device_manager_ = nullptr;
153 
154   // `runner_table` holds the prepopulated tensorflow::OpKernel instances for
155   // kernel fallback compat mode.
156   tfrt_stub::OpKernelRunnerTable* runner_table_ = nullptr;
157 
158   // Resource array is used for keeping static values in the runtime. It is
159   // accessed through tfrt_fallback_async.set_resource and
160   // tfrt_fallback_async.get_resource kernels.
161   FallbackResourceArray* resource_array_ = nullptr;
162 
163   tensorflow::thread::ThreadPoolInterface* intra_op_threadpool_ = nullptr;
164 
165   // Model metadata used for monitoring and tracing purpose.
166   SessionMetadata session_metadata_;
167 
168   const tensorflow::ProcessFunctionLibraryRuntime* pflr_ = nullptr;
169 
170   bool log_device_placement_ = false;
171 };

}  // namespace tfd
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_COMPAT_REQUEST_STATE_H__