xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/graph_execution_state.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
18 
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
23 
24 #include "tensorflow/core/common_runtime/build_graph_options.h"
25 #include "tensorflow/core/common_runtime/device.h"
26 #include "tensorflow/core/common_runtime/device_set.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/graph/costmodel.h"
30 #include "tensorflow/core/graph/graph.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/platform/macros.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 namespace tensorflow {
36 struct SessionOptions;
37 
38 namespace subgraph {
39 struct RewriteGraphMetadata;
40 }
41 
42 struct GraphExecutionStateOptions {
43   const DeviceSet* device_set = nullptr;
44   const SessionOptions* session_options = nullptr;
45   // Unique session identifier. Can be empty.
46   string session_handle;
47   // A map from node name to device name, representing the unchangeable
48   // placement of stateful nodes.
49   std::unordered_map<string, string> stateful_placements;
50 };
51 
52 // A ClientGraph is simply a sub-graph of the full graph as induced by
53 // BuildGraphOptions.
54 struct ClientGraph {
ClientGraphClientGraph55   explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
56                        DataTypeVector feed_types, DataTypeVector fetch_types,
57                        int64_t collective_graph_key)
58       : flib_def(std::move(flib)),
59         graph(flib_def.get()),
60         feed_types(std::move(feed_types)),
61         fetch_types(std::move(fetch_types)),
62         collective_graph_key(collective_graph_key) {}
63   // Each client-graph gets its own function library since optimization passes
64   // post rewrite for execution might want to introduce new functions.
65   std::unique_ptr<FunctionLibraryDefinition> flib_def;
66   Graph graph;
67   DataTypeVector feed_types;
68   DataTypeVector fetch_types;
69   int64_t collective_graph_key;
70 };
71 
72 // GraphExecutionState is responsible for generating an
73 // executable ClientGraph from the original GraphDef that specifies
74 // the complete graph and from BuildGraphOptions which specifies
75 // input/output nodes.
76 //
77 // An executable Graph differs from a GraphDef by being Placed,
78 // meaning that each Node is assigned to a single Device in the
79 // available set.
80 //
// When GraphExecutionState is first constructed it instantiates
// a full Graph from the provided GraphDef, and places it, using only
// the static device assignments from the GraphDef.  Nodes without static
// device assignments are currently placed in a very naive way.  Since
// stateful Nodes cannot be moved after initial placement, it is important
// that stateful Nodes get sensible initial device assignments in the graph
// definition.
88 //
// Subsequently, GraphExecutionState generates a ClientGraph on
// demand, which is a sub-graph of the latest placement of the full
// Graph.  MasterSession uses such a ClientGraph to execute one or
// more similar client requests.
93 //
94 // GraphExecutionState is thread-safe.
95 
96 class GraphExecutionState {
97  public:
98   virtual ~GraphExecutionState();
99 
100   // Creates a new `GraphExecutionState` for the given
101   // `graph_def`, which represents the entire graph for a session.
102   static Status MakeForBaseGraph(
103       GraphDef&& graph_def, const GraphExecutionStateOptions& options,
104       std::unique_ptr<GraphExecutionState>* out_state);
105 
106   // Creates a new `GraphExecutionState` and `SimpleClientGraph`
107   // for the subgraph of `original_graph_def` defined by
108   // `subgraph_options`.
109   static Status MakeForPrunedGraph(
110       const GraphExecutionState& base_execution_state,
111       const GraphExecutionStateOptions& options,
112       const BuildGraphOptions& subgraph_options,
113       std::unique_ptr<GraphExecutionState>* out_state,
114       std::unique_ptr<ClientGraph>* out_client_graph);
115 
116   // Creates a new GraphExecutionState representing the
117   // concatenation of this graph, and the graph defined by
118   // "extension_def". The same name may not be used to define a node
119   // in both this graph and "extension_def".
120   //
121   // If successful, returns OK and the caller takes ownership of "*out".
122   // Otherwise returns an error and does not modify "*out".
123   //
124   // After calling `old_state->Extend()`, `old_state` may no longer be
125   // used.
126   //
127   // NOTE(mrry): This method respects the placement of stateful nodes in
128   // in *this, but currently does not transfer any other placement
129   // or cost model information to the new graph.
130   Status Extend(const GraphDef& extension_def,
131                 std::unique_ptr<GraphExecutionState>* out) const;
132 
133   // Builds a ClientGraph (a sub-graph of the full graph as induced by
134   // the Node set specified in "options").  If successful, returns OK
135   // and the caller takes the ownership of "*out". Otherwise, returns
136   // an error.
137   Status BuildGraph(const BuildGraphOptions& options,
138                     std::unique_ptr<ClientGraph>* out);
139 
140   // Optimize the graph with the node set specified in `options`.
141   Status OptimizeGraph(
142       const BuildGraphOptions& options, const Graph& graph,
143       const FunctionLibraryDefinition* flib_def,
144       std::unique_ptr<Graph>* optimized_graph,
145       std::unique_ptr<FunctionLibraryDefinition>* optimized_flib);
146 
147   // The graph returned by BuildGraph may contain only the pruned
148   // graph, whereas some clients may want access to the full graph.
full_graph()149   const Graph* full_graph() { return graph_; }
150 
151   // The original graph.
original_graph_def()152   GraphDef* original_graph_def() { return original_graph_def_.get(); }
153 
154   // The original function library of this graph.
flib_def()155   const FunctionLibraryDefinition& flib_def() const { return *flib_def_; }
156 
157   // Returns the node with the given name, or null if it does not exist.
get_node_by_name(const string & name)158   const Node* get_node_by_name(const string& name) const {
159     NodeNameToCostIdMap::const_iterator iter =
160         node_name_to_cost_id_map_.find(name);
161     if (iter != node_name_to_cost_id_map_.end()) {
162       return graph_->FindNodeId(iter->second);
163     } else {
164       return nullptr;
165     }
166   }
167 
168   // Returns the map of stateful placements as a map of
169   // node name to placement string.
GetStatefulPlacements()170   std::unordered_map<string, string> GetStatefulPlacements() const {
171     return stateful_placements_;
172   }
173 
174  private:
175   GraphExecutionState(std::unique_ptr<GraphDef>&& graph_def,
176                       std::unique_ptr<FunctionLibraryDefinition>&& flib_def,
177                       const GraphExecutionStateOptions& options);
178 
179   Status InitBaseGraph(std::unique_ptr<Graph>&& graph);
180 
181   // Map of placed stateful nodes, i.e. nodes for which is_stateful()
182   // is true, such as "params" and "queue" nodes.  Once placed these
183   // nodes can not be moved to a different device.  Maps node names to
184   // device names.
185   std::unordered_map<string, string> stateful_placements_;  // Immutable after
186                                                             // ctor.
187   void SaveStatefulNodes(Graph* graph);
188   void RestoreStatefulNodes(Graph* graph);
189 
190   // Extract the subset of the graph that needs to be run, adding feed/fetch
191   // ops as needed.
192   Status PruneGraph(const BuildGraphOptions& options, Graph* graph,
193                     subgraph::RewriteGraphMetadata* out_rewrite_metadata);
194 
195   // The GraphExecutionState must store a copy of the original GraphDef if
196   // either of the following conditions holds:
197   //
198   // * `session_options_.config.graph_options().place_pruned_graph()` is true.
199   // * `session_options_.config.experimental().optimize_for_static_graph()` is
200   //   false.
201   const std::unique_ptr<GraphDef> original_graph_def_;
202 
203   const DeviceSet* device_set_;            // Not owned
204   const SessionOptions* session_options_;  // Not owned
205   // Unique session identifier. Can be empty.
206   string session_handle_;
207 
208   // Map from name to Node for the full graph in placed_.
209   NodeNameToCostIdMap node_name_to_cost_id_map_;
210 
211   // 'flib_def_' is initialized from the initial graph def's library,
212   // and may be updated by a graph optimization pass.
213   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
214 
215   // `rewrite_metadata_` is only set for GraphExecutionState
216   // objects created by `MakeForPrunedGraph()`.
217   std::unique_ptr<subgraph::RewriteGraphMetadata> rewrite_metadata_;
218 
219   // The dataflow graph owned by this object.
220   Graph* graph_;
221 
222   TF_DISALLOW_COPY_AND_ASSIGN(GraphExecutionState);
223 };
224 
225 }  // namespace tensorflow
226 
227 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
228