/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TFRT_UTILS_TFRT_GRAPH_EXECUTION_STATE_H_
#define TENSORFLOW_CORE_TFRT_UTILS_TFRT_GRAPH_EXECUTION_STATE_H_

#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/common_runtime/graph_execution_state.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/tfrt/fallback/fallback_state.h"

namespace tensorflow {
namespace tfrt_stub {

// This is a TFRT variant of `tensorflow::GraphExecutionState`. It wraps
// `tensorflow::GraphExecutionState` and adds TFRT-specific adjustments.
//
// Responsible for generating an executable `Graph` from the original `GraphDef`
// that specifies the complete graph and from `GraphImportConfig` that specifies
// input/output nodes.
//
// Thread-safe.
45 class TfrtGraphExecutionState { 46 public: 47 struct OptimizationResult { 48 std::unique_ptr<tensorflow::Graph> graph; 49 absl::Duration functionalization_duration; 50 absl::Duration grappler_duration; 51 }; 52 53 struct Options { 54 bool run_placer_grappler_on_functions = false; 55 bool enable_tfrt_gpu = false; 56 }; 57 58 // Creates a `GraphExecutionState` given `graph_def` and `fallback_state`. 59 static StatusOr<std::unique_ptr<TfrtGraphExecutionState>> Create( 60 const Options& options, tensorflow::GraphDef graph_def, 61 const FallbackState& fallback_state); 62 63 // Ctor. Do not use directly. Public only for `std::make_unique<>()`. TfrtGraphExecutionState(const Options & options,std::unique_ptr<tensorflow::GraphExecutionState> graph_execution_state,const FallbackState & fallback_state,absl::flat_hash_set<std::string> functions_to_optimize)64 TfrtGraphExecutionState( 65 const Options& options, 66 std::unique_ptr<tensorflow::GraphExecutionState> graph_execution_state, 67 const FallbackState& fallback_state, 68 absl::flat_hash_set<std::string> functions_to_optimize) 69 : options_(options), 70 graph_execution_state_(std::move(graph_execution_state)), 71 fallback_state_(fallback_state), 72 functions_to_optimize_(std::move(functions_to_optimize)) {} 73 74 // Creates an optimized graph by pruning with `graph_import_config` and 75 // best-effort Grappler run. 76 StatusOr<OptimizationResult> CreateOptimizedGraph( 77 tensorflow::GraphImportConfig& graph_import_config); 78 79 // Extends the current graph by `graph`. 80 Status Extend(const GraphDef& graph); 81 82 // Return the preprocessed full graph. Note that it does not contain the 83 // function library in the original graph. graph()84 const tensorflow::Graph& graph() const { 85 absl::MutexLock lock(&graph_execution_state_mu_); 86 DCHECK(graph_execution_state_->full_graph()); 87 return *graph_execution_state_->full_graph(); 88 } 89 90 // The original graph. 
original_graph_def()91 const GraphDef* original_graph_def() const { 92 absl::MutexLock lock(&graph_execution_state_mu_); 93 return graph_execution_state_->original_graph_def(); 94 } 95 96 private: 97 // Return the function library in the original graph. flib_def()98 const FunctionLibraryDefinition& flib_def() const { 99 absl::MutexLock lock(&graph_execution_state_mu_); 100 return graph_execution_state_->flib_def(); 101 } 102 103 StatusOr<std::unique_ptr<tensorflow::Graph>> OptimizeGraph( 104 const tensorflow::Graph& graph, 105 const tensorflow::BuildGraphOptions& build_graph_options); 106 107 Options options_; 108 109 std::unique_ptr<tensorflow::GraphExecutionState> graph_execution_state_ 110 ABSL_GUARDED_BY(graph_execution_state_mu_); 111 // We need this mutex even thought `GraphExecutionState` is thread-safe, 112 // because `swap()` is not thread-safe. 113 mutable absl::Mutex graph_execution_state_mu_; 114 115 const FallbackState& fallback_state_; 116 // Only valid if `options_.run_placer_grappler_on_functions` is true. 117 absl::flat_hash_set<std::string> functions_to_optimize_; 118 }; 119 120 // Prunes the `graph_def` using the feed/fetch nodes specified in 121 // `callable_options`. It is a TFRT-specific version that it performs more 122 // pruning (e.g., prunes the input edges to the feed nodes) than 123 // `ComputeTransitiveFanin()` so that the graph can be functionalized properly 124 // later. 125 Status PruneGraphDef(GraphDef& graph_def, 126 const CallableOptions& callable_options); 127 128 // Eliminates ref variables in V1 control flow, which is required for 129 // functionalization. Current strategy is to insert an identity node between 130 // each ref node and its ref input and in-place update the ref node to its 131 // non-ref counterpart. 132 Status EliminateRefVariablesFromV1ControlFlow(GraphDef& graph_def); 133 134 // Removes the "_input_shapes" attribute of functions in the graph. 
135 void RemoveInputShapesInFunctions(tensorflow::GraphDef& graph_def); 136 137 // Replaces partitioned calls in the graph that have _XlaMustCompile attribute 138 // set to true with XlaLaunch op. 139 // TODO(b/239089915): Clean this up after the logic is implemented in TFXLA 140 // bridge. 141 Status BuildXlaLaunchOps(Graph* graph); 142 143 } // namespace tfrt_stub 144 } // namespace tensorflow 145 146 #endif // TENSORFLOW_CORE_TFRT_UTILS_TFRT_GRAPH_EXECUTION_STATE_H_ 147