/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_
#define TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_

#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/tfrt/fallback/fallback_state.h"
#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tensorflow/core/tfrt/tpu/tpu_resources.h"  // NOLINT(unused-includes): For tfrt::tpu::TpuModelResource
#include "tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h"
#include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
#include "tfrt/bef_executor/bef_file.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/function.h"  // from @tf_runtime
#include "tfrt/host_context/request_deadline_tracker.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime

namespace tensorflow {
namespace tfrt_stub {

// Contains request related info.
43 struct RequestInfo { 44 tfrt::RCReference<tfrt::RequestContext> tfrt_request_context; 45 std::unique_ptr<WorkQueueInterface> request_queue; 46 std::function<void(std::function<void()>)> runner; 47 }; 48 49 // Creates a `RequestInfo` given relative data. 50 StatusOr<std::unique_ptr<RequestInfo>> SetUpRequestContext( 51 const GraphExecutionRunOptions& run_options, 52 const SessionMetadata& model_metadata, tfrt::HostContext* host, 53 tensorflow::tfrt_stub::WorkQueueInterface* work_queue, 54 tfrt::ResourceContext* resource_context, 55 const FallbackState& fallback_state); 56 57 // Runs on a function given input/output and other info. 58 tensorflow::Status GraphExecutionRunOnFunction( 59 const GraphExecutionOptions& options, 60 const GraphExecutionRunOptions& run_options, 61 absl::string_view signature_name, const tfrt::Function& func, 62 absl::Span<const tensorflow::Tensor> inputs, 63 absl::Span<const tensorflow::Tensor> captures, 64 std::vector<tensorflow::Tensor>* outputs, 65 tfrt::ResourceContext* resource_context, const Runtime& runtime, 66 const FallbackState& fallback_state, 67 tfrt::RequestDeadlineTracker& req_deadline_tracker); 68 69 // Creates a ResourceContext and populate it with per model resource from 70 // Runtime. If `tpu_target` is set to kTpurt, also call a special 71 // `AddTpuResources` function to populate TPU related resources for tpurt. 72 // 73 // TODO(b/178227859): Remove the need for the special handling for TPU here. 74 std::unique_ptr<tfrt::ResourceContext> CreateResourceContext( 75 const Runtime& runtime, tfrt::tpu::TpuModelResource* tpu_model_resource, 76 tensorflow::TfrtTpuInfraTarget tpu_target); 77 78 // Loads (if not yet) and runs a subgraph in a graph as per each request. 79 class GraphExecutor { 80 public: 81 using Options = GraphExecutionOptions; 82 using RunOptions = GraphExecutionRunOptions; 83 84 // The loading result of a `ClientGraph`. 
85 struct LoadedClientGraph { 86 std::string name; 87 tfrt::BefBuffer bef; 88 tfrt::RCReference<tfrt::BEFFile> bef_file; 89 std::unique_ptr<tfrt::ResourceContext> resource_context; 90 }; 91 92 // A subgraph constructed by specifying input/output tensors. 93 struct ClientGraph { 94 // A unique name by joining all the input/output/target names. 95 std::string name; 96 // The feed nodes for the corresponding inputs, but they might not be in the 97 // original order and if there are more than one original inputs mapped to 98 // the same feed node, only one is picked here. 99 tensorflow::GraphImportConfig::InputArrays input_nodes; 100 // The fetch nodes for the outputs, which should be in the original order. 101 std::vector<std::string> output_nodes; 102 // The target nodes that should be run but not returned as outputs. 103 std::vector<std::string> target_nodes; 104 }; 105 106 // Creates a `GraphExecutor` given the args. 107 static StatusOr<std::unique_ptr<GraphExecutor>> Create( 108 Options options, const FallbackState& fallback_state, 109 tfrt::tpu::TpuModelResource* tpu_model_resource, 110 tensorflow::GraphDef graph_def); 111 112 // Ctor. Public for `Create()`. Do not use directly. GraphExecutor(Options options,const FallbackState & fallback_state,tfrt::tpu::TpuModelResource * tpu_model_resource,std::unique_ptr<tensorflow::tfrt_stub::TfrtGraphExecutionState> graph_execution_state)113 GraphExecutor(Options options, const FallbackState& fallback_state, 114 tfrt::tpu::TpuModelResource* tpu_model_resource, 115 std::unique_ptr<tensorflow::tfrt_stub::TfrtGraphExecutionState> 116 graph_execution_state) 117 : options_(std::move(options)), 118 fallback_state_(fallback_state), 119 tpu_model_resource_(tpu_model_resource), 120 graph_execution_state_(std::move(graph_execution_state)), 121 req_deadline_tracker_( 122 options_.runtime->core_runtime()->GetHostContext()) {} 123 124 // Runs on the graph according to given input/output. 
125 tensorflow::Status Run( 126 const RunOptions& run_options, 127 absl::Span<const std::pair<std::string, tensorflow::Tensor>> inputs, 128 absl::Span<const std::string> output_tensor_names, 129 absl::Span<const std::string> target_tensor_names, 130 std::vector<tensorflow::Tensor>* outputs); 131 132 // Extends the current graph by `graph`. 133 tensorflow::Status Extend(const GraphDef& graph); 134 graph_execution_state()135 tensorflow::tfrt_stub::TfrtGraphExecutionState& graph_execution_state() 136 const { 137 return *graph_execution_state_; 138 } 139 140 // Compiles and returns a graph that is specified by `client_graph`. 141 StatusOr<std::unique_ptr<GraphExecutor::LoadedClientGraph>> 142 ImportAndCompileClientGraph(const GraphExecutor::ClientGraph& client_graph); 143 144 // Returns the underlying runtime. runtime()145 const tensorflow::tfrt_stub::Runtime& runtime() const { 146 DCHECK(options_.runtime); 147 return *options_.runtime; 148 } 149 150 private: 151 // A set of methods to load a client graph. 152 StatusOr<std::unique_ptr<GraphExecutor::LoadedClientGraph>> LoadClientGraph( 153 const GraphExecutor::ClientGraph& client_graph, 154 tensorflow::tfrt_stub::WorkQueueInterface* work_queue); 155 tensorflow::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> 156 ImportClientGraphToMlirModule(const GraphExecutor::ClientGraph& client_graph, 157 mlir::MLIRContext* context) const; 158 StatusOr<tfrt::BefBuffer> CompileMlirModuleToBef(mlir::ModuleOp module) const; 159 tensorflow::Status InitBef( 160 tfrt::BEFFile* bef_file, tfrt::ResourceContext* resource_context, 161 tensorflow::tfrt_stub::WorkQueueInterface* work_queue); 162 163 // Returns a `LoadedClientGraph` given input/output tensor info. If there is 164 // no existing one yet, creates one first. 
165 StatusOr<std::reference_wrapper<const GraphExecutor::LoadedClientGraph>> 166 GetOrCreateLoadedClientGraph( 167 absl::Span<const std::string> input_tensor_names, 168 absl::Span<const tensorflow::DataType> input_tensor_dtypes, 169 absl::Span<const std::string> output_tensor_names, 170 absl::Span<const std::string> target_tensor_names, 171 tensorflow::tfrt_stub::WorkQueueInterface* work_queue) 172 TF_LOCKS_EXCLUDED(loaded_client_graphs_mu_); 173 174 Options options_; 175 std::reference_wrapper<const FallbackState> fallback_state_; 176 tfrt::tpu::TpuModelResource* tpu_model_resource_; // NOT owned. 177 178 std::unique_ptr<tensorflow::tfrt_stub::TfrtGraphExecutionState> 179 graph_execution_state_; 180 181 tfrt::RequestDeadlineTracker req_deadline_tracker_; 182 183 tensorflow::mutex loaded_client_graphs_mu_; 184 // Caches `LoadedClientGraph` by the joined name. 185 // For pointer stability of values in `absl::flat_hash_map<>`, additional 186 // `std::unique_ptr<>` is necessary. (See https://abseil.io/tips/136.) 187 absl::flat_hash_map<std::string /*joined_name*/, 188 std::unique_ptr<LoadedClientGraph>> 189 loaded_client_graphs_ TF_GUARDED_BY(loaded_client_graphs_mu_); 190 }; 191 192 } // namespace tfrt_stub 193 } // namespace tensorflow 194 195 #endif // TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_ 196