xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tfrt/graph_executor/graph_executor.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_
16 #define TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_
17 
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
25 #include "tensorflow/core/protobuf/config.pb.h"
26 #include "tensorflow/core/tfrt/fallback/fallback_state.h"
27 #include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
28 #include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
29 #include "tensorflow/core/tfrt/tpu/tpu_resources.h"  // NOLINT(unused-includes): For tfrt::tpu::TpuModelResource
30 #include "tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h"
31 #include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
32 #include "tfrt/bef_executor/bef_file.h"  // from @tf_runtime
33 #include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
34 #include "tfrt/host_context/execution_context.h"  // from @tf_runtime
35 #include "tfrt/host_context/function.h"  // from @tf_runtime
36 #include "tfrt/host_context/request_deadline_tracker.h"  // from @tf_runtime
37 #include "tfrt/support/ref_count.h"  // from @tf_runtime
38 
39 namespace tensorflow {
40 namespace tfrt_stub {
41 
42 // Contains request related info.
43 struct RequestInfo {
44   tfrt::RCReference<tfrt::RequestContext> tfrt_request_context;
45   std::unique_ptr<WorkQueueInterface> request_queue;
46   std::function<void(std::function<void()>)> runner;
47 };
48 
49 // Creates a `RequestInfo` given relative data.
50 StatusOr<std::unique_ptr<RequestInfo>> SetUpRequestContext(
51     const GraphExecutionRunOptions& run_options,
52     const SessionMetadata& model_metadata, tfrt::HostContext* host,
53     tensorflow::tfrt_stub::WorkQueueInterface* work_queue,
54     tfrt::ResourceContext* resource_context,
55     const FallbackState& fallback_state);
56 
57 // Runs on a function given input/output and other info.
58 tensorflow::Status GraphExecutionRunOnFunction(
59     const GraphExecutionOptions& options,
60     const GraphExecutionRunOptions& run_options,
61     absl::string_view signature_name, const tfrt::Function& func,
62     absl::Span<const tensorflow::Tensor> inputs,
63     absl::Span<const tensorflow::Tensor> captures,
64     std::vector<tensorflow::Tensor>* outputs,
65     tfrt::ResourceContext* resource_context, const Runtime& runtime,
66     const FallbackState& fallback_state,
67     tfrt::RequestDeadlineTracker& req_deadline_tracker);
68 
69 // Creates a ResourceContext and populate it with per model resource from
70 // Runtime. If `tpu_target` is set to kTpurt, also call a special
71 // `AddTpuResources` function to populate TPU related resources for tpurt.
72 //
73 // TODO(b/178227859): Remove the need for the special handling for TPU here.
74 std::unique_ptr<tfrt::ResourceContext> CreateResourceContext(
75     const Runtime& runtime, tfrt::tpu::TpuModelResource* tpu_model_resource,
76     tensorflow::TfrtTpuInfraTarget tpu_target);
77 
78 // Loads (if not yet) and runs a subgraph in a graph as per each request.
79 class GraphExecutor {
80  public:
81   using Options = GraphExecutionOptions;
82   using RunOptions = GraphExecutionRunOptions;
83 
84   // The loading result of a `ClientGraph`.
85   struct LoadedClientGraph {
86     std::string name;
87     tfrt::BefBuffer bef;
88     tfrt::RCReference<tfrt::BEFFile> bef_file;
89     std::unique_ptr<tfrt::ResourceContext> resource_context;
90   };
91 
92   // A subgraph constructed by specifying input/output tensors.
93   struct ClientGraph {
94     // A unique name by joining all the input/output/target names.
95     std::string name;
96     // The feed nodes for the corresponding inputs, but they might not be in the
97     // original order and if there are more than one original inputs mapped to
98     // the same feed node, only one is picked here.
99     tensorflow::GraphImportConfig::InputArrays input_nodes;
100     // The fetch nodes for the outputs, which should be in the original order.
101     std::vector<std::string> output_nodes;
102     // The target nodes that should be run but not returned as outputs.
103     std::vector<std::string> target_nodes;
104   };
105 
106   // Creates a `GraphExecutor` given the args.
107   static StatusOr<std::unique_ptr<GraphExecutor>> Create(
108       Options options, const FallbackState& fallback_state,
109       tfrt::tpu::TpuModelResource* tpu_model_resource,
110       tensorflow::GraphDef graph_def);
111 
112   // Ctor. Public for `Create()`. Do not use directly.
GraphExecutor(Options options,const FallbackState & fallback_state,tfrt::tpu::TpuModelResource * tpu_model_resource,std::unique_ptr<tensorflow::tfrt_stub::TfrtGraphExecutionState> graph_execution_state)113   GraphExecutor(Options options, const FallbackState& fallback_state,
114                 tfrt::tpu::TpuModelResource* tpu_model_resource,
115                 std::unique_ptr<tensorflow::tfrt_stub::TfrtGraphExecutionState>
116                     graph_execution_state)
117       : options_(std::move(options)),
118         fallback_state_(fallback_state),
119         tpu_model_resource_(tpu_model_resource),
120         graph_execution_state_(std::move(graph_execution_state)),
121         req_deadline_tracker_(
122             options_.runtime->core_runtime()->GetHostContext()) {}
123 
124   // Runs on the graph according to given input/output.
125   tensorflow::Status Run(
126       const RunOptions& run_options,
127       absl::Span<const std::pair<std::string, tensorflow::Tensor>> inputs,
128       absl::Span<const std::string> output_tensor_names,
129       absl::Span<const std::string> target_tensor_names,
130       std::vector<tensorflow::Tensor>* outputs);
131 
132   // Extends the current graph by `graph`.
133   tensorflow::Status Extend(const GraphDef& graph);
134 
graph_execution_state()135   tensorflow::tfrt_stub::TfrtGraphExecutionState& graph_execution_state()
136       const {
137     return *graph_execution_state_;
138   }
139 
140   // Compiles and returns a graph that is specified by `client_graph`.
141   StatusOr<std::unique_ptr<GraphExecutor::LoadedClientGraph>>
142   ImportAndCompileClientGraph(const GraphExecutor::ClientGraph& client_graph);
143 
144   // Returns the underlying runtime.
runtime()145   const tensorflow::tfrt_stub::Runtime& runtime() const {
146     DCHECK(options_.runtime);
147     return *options_.runtime;
148   }
149 
150  private:
151   // A set of methods to load a client graph.
152   StatusOr<std::unique_ptr<GraphExecutor::LoadedClientGraph>> LoadClientGraph(
153       const GraphExecutor::ClientGraph& client_graph,
154       tensorflow::tfrt_stub::WorkQueueInterface* work_queue);
155   tensorflow::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
156   ImportClientGraphToMlirModule(const GraphExecutor::ClientGraph& client_graph,
157                                 mlir::MLIRContext* context) const;
158   StatusOr<tfrt::BefBuffer> CompileMlirModuleToBef(mlir::ModuleOp module) const;
159   tensorflow::Status InitBef(
160       tfrt::BEFFile* bef_file, tfrt::ResourceContext* resource_context,
161       tensorflow::tfrt_stub::WorkQueueInterface* work_queue);
162 
163   // Returns a `LoadedClientGraph` given input/output tensor info. If there is
164   // no existing one yet, creates one first.
165   StatusOr<std::reference_wrapper<const GraphExecutor::LoadedClientGraph>>
166   GetOrCreateLoadedClientGraph(
167       absl::Span<const std::string> input_tensor_names,
168       absl::Span<const tensorflow::DataType> input_tensor_dtypes,
169       absl::Span<const std::string> output_tensor_names,
170       absl::Span<const std::string> target_tensor_names,
171       tensorflow::tfrt_stub::WorkQueueInterface* work_queue)
172       TF_LOCKS_EXCLUDED(loaded_client_graphs_mu_);
173 
174   Options options_;
175   std::reference_wrapper<const FallbackState> fallback_state_;
176   tfrt::tpu::TpuModelResource* tpu_model_resource_;  // NOT owned.
177 
178   std::unique_ptr<tensorflow::tfrt_stub::TfrtGraphExecutionState>
179       graph_execution_state_;
180 
181   tfrt::RequestDeadlineTracker req_deadline_tracker_;
182 
183   tensorflow::mutex loaded_client_graphs_mu_;
184   // Caches `LoadedClientGraph` by the joined name.
185   // For pointer stability of values in `absl::flat_hash_map<>`, additional
186   // `std::unique_ptr<>` is necessary. (See https://abseil.io/tips/136.)
187   absl::flat_hash_map<std::string /*joined_name*/,
188                       std::unique_ptr<LoadedClientGraph>>
189       loaded_client_graphs_ TF_GUARDED_BY(loaded_client_graphs_mu_);
190 };
191 
192 }  // namespace tfrt_stub
193 }  // namespace tensorflow
194 
195 #endif  // TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_
196