xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h"
16 
17 #include <algorithm>
18 #include <memory>
19 #include <string>
20 #include <unordered_map>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/synchronization/mutex.h"
27 #include "absl/time/clock.h"
28 #include "absl/types/span.h"
29 #include "tensorflow/compiler/jit/defs.h"
30 #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
31 #include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
32 #include "tensorflow/core/common_runtime/function_body.h"
33 #include "tensorflow/core/common_runtime/function_def_utils.h"
34 #include "tensorflow/core/common_runtime/graph_constructor.h"
35 #include "tensorflow/core/common_runtime/lower_functional_ops.h"
36 #include "tensorflow/core/common_runtime/optimization_registry.h"
37 #include "tensorflow/core/common_runtime/placer.h"
38 #include "tensorflow/core/framework/attr_value.pb.h"
39 #include "tensorflow/core/framework/function.h"
40 #include "tensorflow/core/framework/function.pb.h"
41 #include "tensorflow/core/framework/graph.pb.h"
42 #include "tensorflow/core/framework/graph_to_functiondef.h"
43 #include "tensorflow/core/framework/node_def_util.h"
44 #include "tensorflow/core/framework/op.h"
45 #include "tensorflow/core/framework/op_def.pb.h"
46 #include "tensorflow/core/framework/versions.pb.h"
47 #include "tensorflow/core/graph/graph.h"
48 #include "tensorflow/core/graph/node_builder.h"
49 #include "tensorflow/core/grappler/utils.h"
50 #include "tensorflow/core/platform/errors.h"
51 #include "tensorflow/core/platform/status.h"
52 #include "tensorflow/core/platform/statusor.h"
53 #include "tensorflow/core/protobuf/config.pb.h"
54 #include "tensorflow/core/tfrt/fallback/fallback_state.h"
55 #include "tensorflow/core/tfrt/utils/graph_partition.h"
56 #include "tensorflow/core/util/dump_graph.h"
57 
58 namespace tensorflow {
59 namespace tfrt_stub {
60 
61 namespace {
62 
63 // Finds the names of functions that are safe to optimize.
FindFunctionsToOptimize(const GraphDef & graph_def)64 absl::flat_hash_set<std::string> FindFunctionsToOptimize(
65     const GraphDef& graph_def) {
66   // TODO(b/203689805): Add more functional ops.
67   static const auto* const kOpWhitelist = new absl::flat_hash_set<std::string>{
68       "PartitionedCall", "StatefulPartitionedCall"};
69   absl::flat_hash_map<
70       std::string /*function_name*/,
71       absl::flat_hash_set<std::string> /*ops_using_the_function*/>
72       function_to_ops;
73 
74   auto build_map = [&](const auto& node_defs) {
75     for (const auto& node_def : node_defs) {
76       for (const auto& p : node_def.attr()) {
77         const AttrValue& attr_value = p.second;
78         if (!attr_value.has_func()) continue;
79         function_to_ops[attr_value.func().name()].insert(node_def.op());
80       }
81     }
82   };
83 
84   build_map(graph_def.node());
85   for (const auto& function_def : graph_def.library().function()) {
86     build_map(function_def.node_def());
87   }
88 
89   absl::flat_hash_set<std::string> functions_to_optimize;
90   for (const auto& p : function_to_ops) {
91     const std::string& function_name = p.first;
92     const absl::flat_hash_set<std::string>& ops = p.second;
93     // Optimize a function iff all the ops that use it are whitelisted.
94     if (std::all_of(ops.begin(), ops.end(), [](const auto& op) {
95           return kOpWhitelist->contains(op);
96         })) {
97       functions_to_optimize.insert(function_name);
98     }
99   }
100 
101   return functions_to_optimize;
102 }
103 
104 // Preprocesses `graph_def`, returns the functions to optimize if
105 // `run_placer_grappler_on_functions` is true.
PreprocessGraph(tensorflow::GraphDef & graph_def,bool run_placer_grappler_on_functions)106 StatusOr<absl::flat_hash_set<std::string>> PreprocessGraph(
107     tensorflow::GraphDef& graph_def, bool run_placer_grappler_on_functions) {
108   if (VLOG_IS_ON(1)) {
109     DumpGraphDefToFile("before_generate_resource_shared_name_graph_def",
110                        graph_def);
111   }
112 
113   TF_RETURN_IF_ERROR(tensorflow::GenerateResourceSharedNameIfEmpty(
114       graph_def, tensorflow::OpRegistry::Global()));
115 
116   if (VLOG_IS_ON(2)) {
117     DumpGraphDefToFile("after_generate_resource_shared_name_graph_def",
118                        graph_def);
119   }
120 
121   if (run_placer_grappler_on_functions) {
122     return FindFunctionsToOptimize(graph_def);
123   }
124   return absl::flat_hash_set<std::string>();
125 }
126 
127 }  // namespace
128 
129 StatusOr<std::unique_ptr<TfrtGraphExecutionState>>
Create(const TfrtGraphExecutionState::Options & options,tensorflow::GraphDef graph_def,const FallbackState & fallback_state)130 TfrtGraphExecutionState::Create(const TfrtGraphExecutionState::Options& options,
131                                 tensorflow::GraphDef graph_def,
132                                 const FallbackState& fallback_state) {
133   TF_ASSIGN_OR_RETURN(
134       auto functions_to_optimize,
135       PreprocessGraph(graph_def, options.run_placer_grappler_on_functions));
136 
137   // `CreateGraphExecutionState()` will preprocess the graph (e.g., apply
138   // Placer to the top level graph).
139   TF_ASSIGN_OR_RETURN(
140       auto graph_execution_state,
141       fallback_state.CreateGraphExecutionState(std::move(graph_def)));
142 
143   return std::make_unique<TfrtGraphExecutionState>(
144       options, std::move(graph_execution_state), fallback_state,
145       std::move(functions_to_optimize));
146 }
147 
148 namespace {
149 
PopulateCallableOptions(CallableOptions & callable_options,absl::Span<const std::string> feed_tensor_names,absl::Span<const std::string> fetch_tensor_names,absl::Span<const std::string> target_tensor_names)150 CallableOptions PopulateCallableOptions(
151     CallableOptions& callable_options,
152     absl::Span<const std::string> feed_tensor_names,
153     absl::Span<const std::string> fetch_tensor_names,
154     absl::Span<const std::string> target_tensor_names) {
155   // Configure pruning with the feed/fetch/target tensor names.
156   callable_options.mutable_feed()->Reserve(feed_tensor_names.size());
157   for (const auto& feed : feed_tensor_names) {
158     callable_options.add_feed(feed);
159   }
160   callable_options.mutable_fetch()->Reserve(fetch_tensor_names.size());
161   for (const auto& fetch : fetch_tensor_names) {
162     callable_options.add_fetch(fetch);
163   }
164   callable_options.mutable_target()->Reserve(target_tensor_names.size());
165   for (const auto& target : target_tensor_names) {
166     callable_options.add_target(target);
167   }
168 
169   return callable_options;
170 }
171 
// Serializes `graph` to a GraphDef and sets its function library to the
// contents of `flib_def`.
tensorflow::GraphDef CreateGraphDefFromGraphAndFlibDef(
    const tensorflow::Graph& graph,
    const tensorflow::FunctionLibraryDefinition& flib_def) {
  tensorflow::GraphDef serialized;
  graph.ToGraphDef(&serialized);
  // Any library written by `ToGraphDef` is overwritten by `flib_def`.
  *serialized.mutable_library() = flib_def.ToProto();
  return serialized;
}
180 
181 // Creates a pruned graph from `graph_def` according to `callable_options`.
CreatePrunedGraph(tensorflow::GraphDef graph_def,const CallableOptions & callable_options)182 StatusOr<std::unique_ptr<tensorflow::Graph>> CreatePrunedGraph(
183     tensorflow::GraphDef graph_def, const CallableOptions& callable_options) {
184   VLOG(1) << "Creating pruned graph: " << callable_options.DebugString();
185 
186   // Prune the graph with `callable_options`. Although
187   // grappler has model_pruner stage, it may leave v1 control flows in an
188   // invalid state that cannot be functionalized. So we perform additional
189   // pruning before functionalization.
190   TF_RETURN_IF_ERROR(PruneGraphDef(graph_def, callable_options));
191 
192   if (VLOG_IS_ON(2)) {
193     DumpGraphDefToFile("before_eliminate_ref_variables_graph_def", graph_def);
194   }
195 
196   // Ref variables in V1 Control flow prevent it from being functionalized. So
197   // we eliminate them first.
198   TF_RETURN_IF_ERROR(EliminateRefVariablesFromV1ControlFlow(graph_def));
199 
200   // The "_input_shapes" attributes will be not be correct after function
201   // optimizer in grappler, we need to remove them. Note that "_input_shapes" is
202   // not used except as a debug hint (somehow this debug hint is used by MLIR
203   // graphdef importer, which is not expected).
204   RemoveInputShapesInFunctions(graph_def);
205 
206   auto pruned_graph =
207       std::make_unique<tensorflow::Graph>(tensorflow::OpRegistry::Global());
208   tensorflow::GraphConstructorOptions options;
209   options.allow_internal_ops = true;
210   options.add_default_attributes = true;
211   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(options, std::move(graph_def),
212                                             pruned_graph.get()));
213   return pruned_graph;
214 }
215 
216 // Creates a new identity node to replace an operand of a given `node`.
CreateNewIdentityNode(const NodeDef & node,const std::string & input_name,const std::string & identity_name)217 NodeDef CreateNewIdentityNode(const NodeDef& node,
218                               const std::string& input_name,
219                               const std::string& identity_name) {
220   NodeDef identity;
221   identity.set_name(identity_name);
222   identity.set_op("Identity");
223   identity.add_input(input_name);
224   identity.set_device(node.device());
225   for (const auto& name_and_attr : node.attr()) {
226     if (name_and_attr.first == "T") {
227       identity.mutable_attr()->insert(name_and_attr);
228       break;
229     }
230   }
231   return identity;
232 }
233 
234 // Inlines functions into the top level graph.
InlineFunctions(std::unique_ptr<Graph> * graph,const DeviceSet * device_set)235 Status InlineFunctions(std::unique_ptr<Graph>* graph,
236                        const DeviceSet* device_set) {
237   GraphOptimizationPassOptions optimization_options;
238   SessionOptions session_options;
239   // We don't lower v2 control flow to v1 for now.
240   session_options.config.mutable_experimental()->set_use_tfrt(true);
241   session_options.config.mutable_graph_options()
242       ->mutable_optimizer_options()
243       ->set_do_function_inlining(true);
244   optimization_options.session_options = &session_options;
245   optimization_options.graph = graph;
246   optimization_options.flib_def = (*graph)->mutable_flib_def();
247   optimization_options.device_set = device_set;
248   optimization_options.is_function_graph = false;
249 
250   LowerFunctionalOpsPass pass;
251   return pass.Run(optimization_options);
252 }
253 
254 // Assigns input/output nodes to the host.
PlaceInputOutputNodesOnHost(const std::vector<std::string> & inputs,const std::vector<std::string> & outputs,const Device * cpu_device,Graph * graph)255 Status PlaceInputOutputNodesOnHost(const std::vector<std::string>& inputs,
256                                    const std::vector<std::string>& outputs,
257                                    const Device* cpu_device, Graph* graph) {
258   std::unordered_map<std::string, Node*> name_to_node_map =
259       graph->BuildNodeNameIndex();
260   for (const auto& input : inputs) {
261     name_to_node_map.at(grappler::NodeName(input))
262         ->set_assigned_device_name(cpu_device->name());
263   }
264 
265   // Collect all output nodes.
266   absl::flat_hash_set<Node*> output_nodes;
267   for (const auto& output : outputs) {
268     output_nodes.insert(name_to_node_map.at(grappler::NodeName(output)));
269   }
270   for (const auto& output_node : output_nodes) {
271     // Append an IdentityN node to the original output node if it is not
272     // assigned to the host.
273     if (!output_node->IsIdentity() &&
274         output_node->type_string() != "IdentityN" &&
275         output_node->assigned_device_name() != cpu_device->name()) {
276       // Rename the original output node.
277       std::string output_node_name = output_node->name();
278       output_node->set_name(output_node_name + "/tfrt_renamed");
279 
280       // Append an IdentityN node with the original output node name.
281       std::vector<NodeBuilder::NodeOut> output_tensors;
282       output_tensors.reserve(output_node->num_outputs());
283       for (int i = 0; i < output_node->num_outputs(); i++) {
284         output_tensors.push_back(NodeBuilder::NodeOut(output_node, i));
285       }
286       TF_RETURN_IF_ERROR(NodeBuilder(output_node_name, "IdentityN")
287                              .AssignedDevice(cpu_device->name())
288                              .Input(output_tensors)
289                              .Finalize(graph, /*created_node=*/nullptr));
290     } else {
291       output_node->set_assigned_device_name(cpu_device->name());
292     }
293   }
294   return OkStatus();
295 }
296 
AdjustDeviceAssignment(const std::vector<std::string> & inputs,const std::vector<std::string> & outputs,const std::vector<std::string> & control_outputs,const Device * cpu_device,Graph * graph)297 Status AdjustDeviceAssignment(const std::vector<std::string>& inputs,
298                               const std::vector<std::string>& outputs,
299                               const std::vector<std::string>& control_outputs,
300                               const Device* cpu_device, Graph* graph) {
301   // TODO(b/232299232): We don't inline and partition v2 control flow currently.
302   // All ops within control flow are placed on CPU for now. Figure out a better
303   // way to handle v2 control flow.
304   for (Node* node : graph->op_nodes()) {
305     if (node->IsWhileNode() || node->IsIfNode()) {
306       LOG(WARNING) << "The control flow node " << node->name()
307                    << " is placed on CPU.";
308       node->set_assigned_device_name(cpu_device->name());
309     }
310   }
311 
312   TF_RETURN_IF_ERROR(
313       PlaceInputOutputNodesOnHost(inputs, outputs, cpu_device, graph));
314   return Status::OK();
315 }
316 
IsTpuGraph(const Graph * graph)317 bool IsTpuGraph(const Graph* graph) {
318   static const auto* const kTpuOps = new absl::flat_hash_set<std::string>{
319       "TPUPartitionedCall", "TPUCompile", "TPUReplicateMetadata"};
320   for (const Node* node : graph->nodes()) {
321     if (kTpuOps->contains(node->type_string())) {
322       return true;
323     }
324   }
325   for (const std::string& func_name : graph->flib_def().ListFunctionNames()) {
326     const FunctionDef* func_def = graph->flib_def().Find(func_name);
327     for (const NodeDef& node_def : func_def->node_def()) {
328       if (kTpuOps->contains(node_def.op())) return true;
329     }
330   }
331   return false;
332 }
333 
334 // Adds Send/Recv ops to `graph` for data transfer, if ops are run on different
335 // devices. Returns a new graph with the added Send/Recv ops.
336 // This is done by partitioning `graph` and add Send/Recv ops on the edges
337 // across devices.
BuildXlaOpsAndMaybeInsertTransferOps(const std::string & graph_func_name,const FallbackState & fallback_state,const std::vector<std::string> & inputs,const std::vector<std::string> & outputs,const std::vector<std::string> & control_outputs,std::unique_ptr<Graph> graph)338 StatusOr<std::unique_ptr<Graph>> BuildXlaOpsAndMaybeInsertTransferOps(
339     const std::string& graph_func_name, const FallbackState& fallback_state,
340     const std::vector<std::string>& inputs,
341     const std::vector<std::string>& outputs,
342     const std::vector<std::string>& control_outputs,
343     std::unique_ptr<Graph> graph) {
344   // Skip inserting transfer ops if this is a TPU graph.
345   // Our stack currently cannot run the old bridge on TPU graphs, as it will
346   // generate ops that are not supported by the subsequent MLIR passes.
347   // In the case where TPU related ops are not wrapped in TPUPartitionedCall,
348   // running placer and partitioning on such graphs will fail. So we skip TPU
349   // graphs for now.
350   // TODO(b/228510957): In the long term, we will want a unified way for data
351   // transfer, i.e., using Send/Recv ops for data transfer for TPU as well.
352   if (IsTpuGraph(graph.get())) {
353     return graph;
354   }
355 
356   // Inline functions to facilitate partitioning nodes in the functions.
357   TF_RETURN_IF_ERROR(InlineFunctions(&graph, &fallback_state.device_set()));
358   if (VLOG_IS_ON(1)) {
359     DumpGraphToFile("after_inlining", *graph);
360   }
361 
362   // Replace the StatefulPartitionedCall op that should be compiled to an
363   // XlaLaunch op.
364   // TODO(b/239089915): Clean this up after the logic is implemented in TFXLA
365   // bridge.
366   TF_RETURN_IF_ERROR(BuildXlaLaunchOps(graph.get()));
367   if (VLOG_IS_ON(1)) {
368     DumpGraphToFile("after_build_xla_launch", *graph);
369   }
370 
371   // Run placer.
372   const Device* cpu_device = fallback_state.device_manager().HostCPU();
373   if (cpu_device == nullptr) {
374     return errors::Internal("No CPU device found.");
375   }
376   Placer placer(graph.get(), /*function_name=*/"", &graph->flib_def(),
377                 &fallback_state.device_set(), cpu_device,
378                 /*allow_soft_placement=*/true,
379                 /*log_device_placement=*/false);
380   TF_RETURN_IF_ERROR(placer.Run());
381   if (VLOG_IS_ON(1)) {
382     DumpGraphToFile("after_placer", *graph);
383   }
384 
385   TF_RETURN_IF_ERROR(AdjustDeviceAssignment(inputs, outputs, control_outputs,
386                                             cpu_device, graph.get()));
387 
388   // Insert send/recv ops to the graph.
389   TF_ASSIGN_OR_RETURN(
390       std::unique_ptr<Graph> new_graph,
391       InsertTransferOps(graph_func_name, fallback_state.device_set(),
392                         cpu_device, inputs, outputs, control_outputs,
393                         std::move(graph)));
394   if (VLOG_IS_ON(1)) {
395     DumpGraphToFile("after_transfer_ops_insertion", *new_graph);
396   }
397 
398   return new_graph;
399 }
400 
401 }  // namespace
402 
403 StatusOr<TfrtGraphExecutionState::OptimizationResult>
CreateOptimizedGraph(tensorflow::GraphImportConfig & graph_import_config)404 TfrtGraphExecutionState::CreateOptimizedGraph(
405     tensorflow::GraphImportConfig& graph_import_config) {
406   OptimizationResult result;
407 
408   tensorflow::BuildGraphOptions build_graph_options;
409 
410   std::vector<std::string> inputs;
411   inputs.reserve(graph_import_config.inputs.size());
412   for (const auto& input : graph_import_config.inputs) {
413     inputs.push_back(input.first);
414   }
415   PopulateCallableOptions(build_graph_options.callable_options, inputs,
416                           graph_import_config.outputs,
417                           graph_import_config.control_outputs);
418 
419   auto graph_def = CreateGraphDefFromGraphAndFlibDef(graph(), flib_def());
420 
421   if (VLOG_IS_ON(1)) {
422     DumpGraphDefToFile("before_pruning", graph_def);
423   }
424 
425   TF_ASSIGN_OR_RETURN(
426       result.graph,
427       CreatePrunedGraph(graph_def, build_graph_options.callable_options));
428   DCHECK(result.graph);
429 
430   if (VLOG_IS_ON(1)) {
431     DumpGraphToFile("after_pruning", *result.graph);
432   }
433 
434   const auto functionalization_start_time = absl::Now();
435 
436   // Perform functionalization to convert v1 control flow to v2 control flow. It
437   // should be applied to the unoptimized graph, because Grappler may cause
438   // unfunctionalizablity.
439   TF_RETURN_IF_ERROR(tensorflow::UpgradeLegacyGraph(
440       result.graph.get(),
441       const_cast<tensorflow::FunctionLibraryDefinition*>(
442           &result.graph->flib_def()),
443       /*restrict_functionalization_to_compiled_nodes=*/false));
444 
445   if (VLOG_IS_ON(1)) {
446     DumpGraphToFile("after_functionalization", *result.graph);
447   }
448 
449   auto grappler_start_time = absl::Now();
450   result.functionalization_duration =
451       grappler_start_time - functionalization_start_time;
452 
453   auto status_or_optimized_graph =
454       OptimizeGraph(*result.graph, build_graph_options);
455   if (status_or_optimized_graph.ok()) {
456     result.graph = std::move(status_or_optimized_graph.ValueOrDie());
457   } else {
458     LOG(WARNING) << "TFRT failed to optimize graph: "
459                  << status_or_optimized_graph.status();
460   }
461 
462   if (VLOG_IS_ON(1)) {
463     DumpGraphToFile("after_grappler", *result.graph);
464   }
465 
466   result.grappler_duration = absl::Now() - grappler_start_time;
467 
468   if (options_.enable_tfrt_gpu) {
469     TF_ASSIGN_OR_RETURN(
470         result.graph,
471         BuildXlaOpsAndMaybeInsertTransferOps(
472             graph_import_config.graph_func_name, fallback_state_, inputs,
473             graph_import_config.outputs, graph_import_config.control_outputs,
474             std::move(result.graph)));
475 
476     // Update `control_outputs` as there might be newly added Send ops.
477     for (const Node* node : result.graph->nodes()) {
478       if (node->IsSend()) {
479         graph_import_config.control_outputs.push_back(node->name());
480       }
481     }
482   }
483 
484   return result;
485 }
486 
Extend(const GraphDef & graph)487 Status TfrtGraphExecutionState::Extend(const GraphDef& graph) {
488   std::unique_ptr<GraphExecutionState> new_state;
489   absl::MutexLock lock(&graph_execution_state_mu_);
490   TF_RETURN_IF_ERROR(graph_execution_state_->Extend(graph, &new_state));
491   graph_execution_state_.swap(new_state);
492 
493   auto* graph_def = graph_execution_state_->original_graph_def();
494   DCHECK_NE(graph_def, nullptr);
495   TF_ASSIGN_OR_RETURN(
496       functions_to_optimize_,
497       PreprocessGraph(*graph_def, options_.run_placer_grappler_on_functions));
498 
499   return OkStatus();
500 }
501 
502 namespace {
503 
504 // Given an "Exit" node, finds its corresponding "LoopCond" node.
FindLoopCondFromExitNode(const NodeDef & exit_node,const absl::flat_hash_map<std::string,NodeDef * > & name_to_node)505 StatusOr<const NodeDef*> FindLoopCondFromExitNode(
506     const NodeDef& exit_node,
507     const absl::flat_hash_map<std::string, NodeDef*>& name_to_node) {
508   const NodeDef* switch_node = nullptr;
509   for (const std::string& tensor_name : exit_node.input()) {
510     const std::string node_name = grappler::NodeName(tensor_name);
511     if (!name_to_node.contains(node_name)) {
512       return errors::InvalidArgument("Graph does not contain input ", node_name,
513                                      " of exit node ", exit_node.name());
514     }
515     const NodeDef* node = name_to_node.at(node_name);
516     if (node->op() == "Switch") {
517       switch_node = node;
518       break;
519     }
520   }
521   if (switch_node == nullptr) {
522     return errors::InvalidArgument("Exit node ", exit_node.name(),
523                                    " does not have a Switch node as its ",
524                                    "predecessor.");
525   }
526   for (const std::string& tensor_name : switch_node->input()) {
527     const std::string node_name = grappler::NodeName(tensor_name);
528     if (!name_to_node.contains(node_name)) {
529       return errors::InvalidArgument("Graph does not contain input ", node_name,
530                                      " of switch node ", switch_node->name());
531     }
532 
533     const NodeDef* node = name_to_node.at(node_name);
534     if (node->op() == "LoopCond") {
535       return node;
536     }
537   }
538 
539   return errors::InvalidArgument("Switch node ", switch_node->name(),
540                                  " does not have a LoopCond node as its ",
541                                  "predecessor.");
542 }
543 
544 }  // namespace
545 
PruneGraphDef(GraphDef & graph_def,const CallableOptions & callable_options)546 Status PruneGraphDef(GraphDef& graph_def,
547                      const CallableOptions& callable_options) {
548   // Gather node names and create a map from names to NodeDefs.
549   absl::flat_hash_map<std::string, NodeDef*> name_to_node;
550   // All exit nodes in order to track all while loops.
551   absl::flat_hash_set<const NodeDef*> exit_nodes;
552   for (auto& node : *graph_def.mutable_node()) {
553     name_to_node[node.name()] = &node;
554     if (node.op() == "Exit") {
555       exit_nodes.insert(&node);
556     }
557 
558     // TODO(tfrt-devs): Add support for _Send and _Recv ops.
559     if (node.op() == "_Send" || node.op() == "_Recv") {
560       return errors::InvalidArgument(
561           "TFRT prune graphdef cannot handle graphs contains _Send and _Recv "
562           "ops.");
563     }
564   }
565 
566   // Find all LoopCond -> Exit nodes mapping. So when we traverse to a LoopCond
567   // node, we can add corresponding Exit nodes to the traversal queue in order
568   // to maintain complete structure of a while loop.
569   absl::flat_hash_map<const NodeDef*, absl::flat_hash_set<const NodeDef*>>
570       loop_cond_to_exit_nodes;
571   for (const NodeDef* exit_node : exit_nodes) {
572     TF_ASSIGN_OR_RETURN(const NodeDef* loop_cond_node,
573                         FindLoopCondFromExitNode(*exit_node, name_to_node));
574     loop_cond_to_exit_nodes[loop_cond_node].insert(exit_node);
575   }
576 
577   // `queue` is for candidate nodes we want to visit in the graph.
578   std::vector<const NodeDef*> queue;
579 
580   // Add fetch nodes to the queue.
581   absl::flat_hash_set<std::string> fetch_node_names;
582   for (const std::string& tensor_name : callable_options.fetch()) {
583     const NodeDef* node = name_to_node[grappler::NodeName(tensor_name)];
584     if (!node) {
585       return errors::InvalidArgument("Graph does not contain fetch node ",
586                                      tensor_name, ".");
587     }
588     queue.push_back(node);
589     fetch_node_names.insert(node->name());
590   }
591 
592   // Add control target nodes to the queue.
593   for (const std::string& tensor_name : callable_options.target()) {
594     const NodeDef* node = name_to_node[grappler::NodeName(tensor_name)];
595     if (!node) {
596       return errors::InvalidArgument("Graph does not contain target node ",
597                                      tensor_name, ".");
598     }
599     queue.push_back(node);
600     fetch_node_names.insert(node->name());
601   }
602 
603   absl::flat_hash_set<NodeDef*> feed_node_defs;
604 
605   // Add feed nodes to the queue. In addition, perform necessary rewrites to
606   // remove unnecessary input edges.
607   for (const std::string& tensor_name : callable_options.feed()) {
608     NodeDef* node = name_to_node[grappler::NodeName(tensor_name)];
609     if (!node) {
610       return errors::InvalidArgument("Graph does not contain feed node ",
611                                      tensor_name, ".");
612     }
613 
614     // If a feed node is a Const, we don't need its inputs at all.
615     //
616     // TODO(tfrt-devs): Consider a general solution that we could just rewrite
617     // all feed nodes to Placeholder nodes.
618     if (node->op() == "Const") {
619       node->clear_input();
620     }
621 
622     queue.push_back(node);
623     feed_node_defs.insert(node);
624   }
625 
626   absl::flat_hash_set<const NodeDef*> visited;
627   std::vector<NodeDef> keep;
628 
629   // Perform graph traversal to find out connected nodes from fetches.
630   while (!queue.empty()) {
631     const NodeDef* node = queue.back();
632     queue.pop_back();
633 
634     if (!visited.insert(node).second) {
635       continue;
636     }
637 
638     keep.push_back(*node);
639     if (node->op() == "LoopCond") {
640       for (const NodeDef* exit_node : loop_cond_to_exit_nodes[node]) {
641         queue.push_back(exit_node);
642       }
643     }
644 
645     for (const std::string& tensor_name : node->input()) {
646       const NodeDef* in = name_to_node[grappler::NodeName(tensor_name)];
647       if (!in) {
648         return errors::InvalidArgument("Graph does not contain input ",
649                                        grappler::NodeName(tensor_name),
650                                        " of node ", node->name(), ".");
651       }
652       queue.push_back(in);
653     }
654   }
655 
656   graph_def.clear_node();
657   for (auto& node : keep) {
658     if (fetch_node_names.contains(node.name())) {
659       // If the fetch node is an Exit op, we insert an Identity op right after
660       // it and rename it to be the new fetch node. This is to prevent
661       // functionalization from removing the fetch nodes.
662       if (node.op() == "Exit") {
663         auto renamed_exit_node = node;
664         renamed_exit_node.set_name(
665             absl::StrCat(renamed_exit_node.name(), "/tfrt_renamed"));
666         node.set_op("Identity");
667         *node.mutable_input(0) = renamed_exit_node.name();
668         *graph_def.add_node() = std::move(renamed_exit_node);
669       }
670     }
671 
672     *graph_def.add_node() = std::move(node);
673   }
674 
675   return OkStatus();
676 }
677 
EliminateRefVariablesFromV1ControlFlow(tensorflow::GraphDef & graph_def)678 Status EliminateRefVariablesFromV1ControlFlow(tensorflow::GraphDef& graph_def) {
679   auto* op_factory = OpRegistry::Global();
680 
681   absl::flat_hash_set<std::string> ref_nodes;
682   for (const auto& node : graph_def.node()) {
683     if (node.op() == "RefEnter" || node.op() == "RefSwitch") {
684       ref_nodes.insert(node.name());
685     }
686   }
687 
688   tensorflow::GraphDef updated_graph_def;
689   absl::flat_hash_set<std::string> new_identities;
690   // Insert an identity node between each "RefEnter" or "RefSwitch" node and its
691   // ref input. Then modify each "RefEnter"/"RefSwitch" node in-place to an
692   // "Enter"/"Switch" node.
693   for (auto& node : *graph_def.mutable_node()) {
694     // First find the ref input name to this RefEnter or RefSwitch.
695     std::string* ref_input_name = nullptr;
696     if (node.op() == "RefEnter") {
697       node.set_op("Enter");
698       if (node.input_size() != 1) {
699         return errors::InvalidArgument("RefEnter node ", node.name(),
700                                        " does not have exactly 1 input.");
701       }
702       ref_input_name = node.mutable_input(0);
703     } else if (node.op() == "RefSwitch") {
704       node.set_op("Switch");
705       if (node.input_size() != 2) {
706         return errors::InvalidArgument("RefSwitch node", node.name(),
707                                        " does not have exactly 2 inputs.");
708       }
709       ref_input_name = node.mutable_input(0);
710     } else {
711       // For other ops, check if their inputs are the ref ops we want to
712       // eliminate, and if so, these ops must not require their inputs to be
713       // refs.
714       std::string ref_input;
715       for (const auto& tensor_name : node.input()) {
716         std::string input = grappler::NodeName(tensor_name);
717         if (ref_nodes.contains(input)) {
718           ref_input = std::move(input);
719           break;
720         }
721       }
722       if (!ref_input.empty()) {
723         const OpDef* op_def;
724         TF_RETURN_IF_ERROR(op_factory->LookUpOpDef(node.op(), &op_def));
725         // TODO(tfrt-devs): How to match input_args to input names in NodeDef?
726         for (const auto& input_arg : op_def->input_arg()) {
727           if (input_arg.is_ref()) {
728             return errors::Unimplemented(
729                 "Cannot in-place update ref node ", ref_input,
730                 " to the non-ref counterpart since its user node ", node.name(),
731                 " requires its input to be refs.");
732           }
733         }
734       }
735     }
736 
737     if (ref_input_name != nullptr) {
738       std::string identity_name =
739           absl::StrCat(grappler::NodeName(*ref_input_name), "/identity");
740       if (!new_identities.contains(identity_name)) {
741         *updated_graph_def.add_node() =
742             CreateNewIdentityNode(node, *ref_input_name, identity_name);
743         new_identities.insert(identity_name);
744       }
745       *ref_input_name = std::move(identity_name);
746     }
747 
748     *updated_graph_def.add_node() = std::move(node);
749   }
750 
751   graph_def.mutable_node()->Swap(updated_graph_def.mutable_node());
752   return OkStatus();
753 }
754 
RemoveInputShapesInFunctions(tensorflow::GraphDef & graph_def)755 void RemoveInputShapesInFunctions(tensorflow::GraphDef& graph_def) {
756   for (tensorflow::FunctionDef& function_def :
757        *graph_def.mutable_library()->mutable_function()) {
758     function_def.mutable_attr()->erase("_input_shapes");
759   }
760 }
761 
762 namespace {
763 
764 // Optimizes the functions in `flib_proto` (filtering with
765 // `functions_to_optimize`) using `flib` and `fallback_state`. Each
766 // function is converted to a graph and optimized with Placer and Grappler, then
767 // converted back to a function to replace the old one.
OptimizeFunctions(FunctionDefLibrary & flib_proto,const FunctionLibraryDefinition & flib,const FallbackState & fallback_state,const absl::flat_hash_set<std::string> & functions_to_optimize)768 Status OptimizeFunctions(
769     FunctionDefLibrary& flib_proto, const FunctionLibraryDefinition& flib,
770     const FallbackState& fallback_state,
771     const absl::flat_hash_set<std::string>& functions_to_optimize) {
772   for (FunctionDef& fdef : *flib_proto.mutable_function()) {
773     if (!functions_to_optimize.contains(fdef.signature().name())) {
774       continue;
775     }
776 
777     // Convert function to graph.
778     std::unique_ptr<FunctionBody> fbody;
779     TF_RETURN_IF_ERROR(
780         FunctionDefToBodyHelper(fdef, AttrSlice(), &flib, &fbody));
781 
782     tensorflow::Graph* graph = fbody->graph;
783     tensorflow::GraphDef graph_def;
784     graph->ToGraphDef(&graph_def);
785     // We need to manually add the flib because it's not added in
786     // `FunctionDefToBodyHelper()`.
787     *graph_def.mutable_library() = flib.ToProto();
788 
789     // `CreateGraphExecutionState()` will preprocess the graph (e.g., apply
790     // Placer).
791     TF_ASSIGN_OR_RETURN(
792         auto graph_execution_state,
793         fallback_state.CreateGraphExecutionState(std::move(graph_def)));
794 
795     // Invoke Grappler to optimize the graph.
796     std::unique_ptr<tensorflow::Graph> optimized_graph;
797     std::unique_ptr<tensorflow::FunctionLibraryDefinition> optimized_flib;
798     tensorflow::BuildGraphOptions build_graph_options;
799     std::vector<std::string> args;
800     args.reserve(fbody->arg_nodes.size());
801     for (const auto& arg : fbody->arg_nodes) args.push_back(arg->name());
802     std::vector<std::string> rets;
803     rets.reserve(fbody->ret_nodes.size());
804     for (const auto& ret : fbody->ret_nodes) rets.push_back(ret->name());
805     std::vector<std::string> control_rets;
806     control_rets.reserve(fbody->control_ret_nodes.size());
807     for (const auto& control_ret : fbody->control_ret_nodes) {
808       control_rets.push_back(control_ret->name());
809     }
810     PopulateCallableOptions(build_graph_options.callable_options, args, rets,
811                             control_rets);
812     auto status = graph_execution_state->OptimizeGraph(
813         build_graph_options, *graph_execution_state->full_graph(), &flib,
814         &optimized_graph, &optimized_flib);
815 
816     if (!status.ok()) {
817       LOG(ERROR) << "TFRT failed to optimize graph (converted from function: "
818                  << fdef.signature().name() << "): " << status;
819       continue;
820     }
821 
822     TF_RETURN_IF_ERROR(
823         optimized_graph->AddFunctionLibrary(optimized_flib->ToProto()));
824 
825     // Convert graph back to function.
826     // We need to store the conversion result into a new `FunctionDef` first to
827     // avoid errors.
828     FunctionDef new_fdef;
829     TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph,
830                                           fdef.signature().name(), &new_fdef));
831 
832     fdef = std::move(new_fdef);
833   }
834   return OkStatus();
835 }
836 
837 }  // namespace
838 
839 StatusOr<std::unique_ptr<tensorflow::Graph>>
OptimizeGraph(const tensorflow::Graph & graph,const tensorflow::BuildGraphOptions & build_graph_options)840 TfrtGraphExecutionState::OptimizeGraph(
841     const tensorflow::Graph& graph,
842     const tensorflow::BuildGraphOptions& build_graph_options) {
843   std::unique_ptr<tensorflow::Graph> optimized_graph;
844   std::unique_ptr<tensorflow::FunctionLibraryDefinition> optimized_flib;
845 
846   {
847     absl::MutexLock lock(&graph_execution_state_mu_);
848     // Invoke Grappler to optimize the graph.
849     TF_RETURN_IF_ERROR(graph_execution_state_->OptimizeGraph(
850         build_graph_options, graph, &graph.flib_def(), &optimized_graph,
851         &optimized_flib));
852   }
853 
854   FunctionDefLibrary optimized_flib_proto = optimized_flib->ToProto();
855   if (options_.run_placer_grappler_on_functions) {
856     TF_RETURN_IF_ERROR(OptimizeFunctions(optimized_flib_proto, *optimized_flib,
857                                          fallback_state_,
858                                          functions_to_optimize_));
859     // Any optimized function is altered but still has the previous name. To
860     // avoid errors when adding the optimized flib, we should clear the current
861     // flib first.
862     optimized_graph->mutable_flib_def()->Clear();
863   }
864 
865   TF_RETURN_IF_ERROR(optimized_graph->AddFunctionLibrary(optimized_flib_proto));
866 
867   return optimized_graph;
868 }
869 
870 // TODO(b/239089915): Clean this up after the logic is implemented in TFXLA
871 // bridge.
BuildXlaLaunchOps(Graph * graph)872 Status BuildXlaLaunchOps(Graph* graph) {
873   const auto is_xla_launch_node = [](const Node& n) -> StatusOr<bool> {
874     if (!n.IsPartitionedCall()) {
875       return false;
876     }
877     bool xla_must_compile = false;
878     const bool has_attribute =
879         TryGetNodeAttr(n.attrs(), kXlaMustCompileAttr, &xla_must_compile);
880     return has_attribute && xla_must_compile;
881   };
882 
883   const auto get_xla_function_info = [](const Node& launch)
884       -> StatusOr<EncapsulateXlaComputationsPass::XlaFunctionInfo> {
885     EncapsulateXlaComputationsPass::XlaFunctionInfo result;
886     std::vector<DataType> tin_dtypes;
887     TF_RETURN_IF_ERROR(GetNodeAttr(launch.def(), "Tin", &tin_dtypes));
888     int variable_start_index = 0;
889     for (; variable_start_index < tin_dtypes.size(); ++variable_start_index) {
890       if (tin_dtypes.at(variable_start_index) == DT_RESOURCE) break;
891     }
892     result.variable_start_index = variable_start_index;
893 
894     NameAttrList func;
895     TF_RETURN_IF_ERROR(GetNodeAttr(launch.attrs(), "f", &func));
896     result.function_name = func.name();
897 
898     return result;
899   };
900 
901   return EncapsulateXlaComputationsPass::BuildXlaLaunchOps(
902       graph, is_xla_launch_node, get_xla_function_info,
903       /*add_edges_to_output_of_downstream_nodes=*/false);
904 }
905 
906 }  // namespace tfrt_stub
907 }  // namespace tensorflow
908