xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_UTILS_TFRT_GRAPH_EXECUTION_STATE_H_
16 #define TENSORFLOW_CORE_TFRT_UTILS_TFRT_GRAPH_EXECUTION_STATE_H_
17 
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/common_runtime/graph_execution_state.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/tfrt/fallback/fallback_state.h"
33 
34 namespace tensorflow {
35 namespace tfrt_stub {
36 
37 // This is a TFRT variant of `tensorflow::GraphExecutionState`. It wraps
38 // `tensorflow::GraphExecutionState` and adds TFRT-specific adjustments.
39 //
40 // Responsible for generating an executable `Graph` from the original `GraphDef`
41 // that specifies the complete graph and from `GraphImportConfig` that specifies
42 // input/output nodes.
43 //
44 // Thread-safe.
45 class TfrtGraphExecutionState {
46  public:
47   struct OptimizationResult {
48     std::unique_ptr<tensorflow::Graph> graph;
49     absl::Duration functionalization_duration;
50     absl::Duration grappler_duration;
51   };
52 
53   struct Options {
54     bool run_placer_grappler_on_functions = false;
55     bool enable_tfrt_gpu = false;
56   };
57 
58   // Creates a `GraphExecutionState` given `graph_def` and `fallback_state`.
59   static StatusOr<std::unique_ptr<TfrtGraphExecutionState>> Create(
60       const Options& options, tensorflow::GraphDef graph_def,
61       const FallbackState& fallback_state);
62 
63   // Ctor. Do not use directly. Public only for `std::make_unique<>()`.
TfrtGraphExecutionState(const Options & options,std::unique_ptr<tensorflow::GraphExecutionState> graph_execution_state,const FallbackState & fallback_state,absl::flat_hash_set<std::string> functions_to_optimize)64   TfrtGraphExecutionState(
65       const Options& options,
66       std::unique_ptr<tensorflow::GraphExecutionState> graph_execution_state,
67       const FallbackState& fallback_state,
68       absl::flat_hash_set<std::string> functions_to_optimize)
69       : options_(options),
70         graph_execution_state_(std::move(graph_execution_state)),
71         fallback_state_(fallback_state),
72         functions_to_optimize_(std::move(functions_to_optimize)) {}
73 
74   // Creates an optimized graph by pruning with `graph_import_config` and
75   // best-effort Grappler run.
76   StatusOr<OptimizationResult> CreateOptimizedGraph(
77       tensorflow::GraphImportConfig& graph_import_config);
78 
79   // Extends the current graph by `graph`.
80   Status Extend(const GraphDef& graph);
81 
82   // Return the preprocessed full graph. Note that it does not contain the
83   // function library in the original graph.
graph()84   const tensorflow::Graph& graph() const {
85     absl::MutexLock lock(&graph_execution_state_mu_);
86     DCHECK(graph_execution_state_->full_graph());
87     return *graph_execution_state_->full_graph();
88   }
89 
90   // The original graph.
original_graph_def()91   const GraphDef* original_graph_def() const {
92     absl::MutexLock lock(&graph_execution_state_mu_);
93     return graph_execution_state_->original_graph_def();
94   }
95 
96  private:
97   // Return the function library in the original graph.
flib_def()98   const FunctionLibraryDefinition& flib_def() const {
99     absl::MutexLock lock(&graph_execution_state_mu_);
100     return graph_execution_state_->flib_def();
101   }
102 
103   StatusOr<std::unique_ptr<tensorflow::Graph>> OptimizeGraph(
104       const tensorflow::Graph& graph,
105       const tensorflow::BuildGraphOptions& build_graph_options);
106 
107   Options options_;
108 
109   std::unique_ptr<tensorflow::GraphExecutionState> graph_execution_state_
110       ABSL_GUARDED_BY(graph_execution_state_mu_);
111   // We need this mutex even thought `GraphExecutionState` is thread-safe,
112   // because `swap()` is not thread-safe.
113   mutable absl::Mutex graph_execution_state_mu_;
114 
115   const FallbackState& fallback_state_;
116   // Only valid if `options_.run_placer_grappler_on_functions` is true.
117   absl::flat_hash_set<std::string> functions_to_optimize_;
118 };
119 
120 // Prunes the `graph_def` using the feed/fetch nodes specified in
121 // `callable_options`. It is a TFRT-specific version that it performs more
122 // pruning (e.g., prunes the input edges to the feed nodes) than
123 // `ComputeTransitiveFanin()` so that the graph can be functionalized properly
124 // later.
125 Status PruneGraphDef(GraphDef& graph_def,
126                      const CallableOptions& callable_options);
127 
128 // Eliminates ref variables in V1 control flow, which is required for
129 // functionalization. Current strategy is to insert an identity node between
130 // each ref node and its ref input and in-place update the ref node to its
131 // non-ref counterpart.
132 Status EliminateRefVariablesFromV1ControlFlow(GraphDef& graph_def);
133 
134 // Removes the "_input_shapes" attribute of functions in the graph.
135 void RemoveInputShapesInFunctions(tensorflow::GraphDef& graph_def);
136 
137 // Replaces partitioned calls in the graph that have _XlaMustCompile attribute
138 // set to true with XlaLaunch op.
139 // TODO(b/239089915): Clean this up after the logic is implemented in TFXLA
140 // bridge.
141 Status BuildXlaLaunchOps(Graph* graph);
142 
143 }  // namespace tfrt_stub
144 }  // namespace tensorflow
145 
146 #endif  // TENSORFLOW_CORE_TFRT_UTILS_TFRT_GRAPH_EXECUTION_STATE_H_
147