/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h"

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/clock.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/common_runtime/function_def_utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/placer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/tfrt/fallback/fallback_state.h"
#include "tensorflow/core/tfrt/utils/graph_partition.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {
namespace tfrt_stub {

namespace {

// Finds the names of functions that are safe to optimize.
absl::flat_hash_set<std::string> FindFunctionsToOptimize(
    const GraphDef& graph_def) {
  // TODO(b/203689805): Add more functional ops.
  static const auto* const kOpWhitelist = new absl::flat_hash_set<std::string>{
      "PartitionedCall", "StatefulPartitionedCall"};
  absl::flat_hash_map<
      std::string /*function_name*/,
      absl::flat_hash_set<std::string> /*ops_using_the_function*/>
      function_to_ops;

  auto build_map = [&](const auto& node_defs) {
    for (const auto& node_def : node_defs) {
      for (const auto& p : node_def.attr()) {
        const AttrValue& attr_value = p.second;
        if (!attr_value.has_func()) continue;
        function_to_ops[attr_value.func().name()].insert(node_def.op());
      }
    }
  };

  build_map(graph_def.node());
  for (const auto& function_def : graph_def.library().function()) {
    build_map(function_def.node_def());
  }

  absl::flat_hash_set<std::string> functions_to_optimize;
  for (const auto& p : function_to_ops) {
    const std::string& function_name = p.first;
    const absl::flat_hash_set<std::string>& ops = p.second;
    // Optimize a function iff all the ops that use it are whitelisted.
    if (std::all_of(ops.begin(), ops.end(), [](const auto& op) {
          return kOpWhitelist->contains(op);
        })) {
      functions_to_optimize.insert(function_name);
    }
  }

  return functions_to_optimize;
}

// Preprocesses `graph_def`, returns the functions to optimize if
// `run_placer_grappler_on_functions` is true.
StatusOr<absl::flat_hash_set<std::string>> PreprocessGraph(
    tensorflow::GraphDef& graph_def, bool run_placer_grappler_on_functions) {
  if (VLOG_IS_ON(1)) {
    DumpGraphDefToFile("before_generate_resource_shared_name_graph_def",
                       graph_def);
  }

  TF_RETURN_IF_ERROR(tensorflow::GenerateResourceSharedNameIfEmpty(
      graph_def, tensorflow::OpRegistry::Global()));

  if (VLOG_IS_ON(2)) {
    DumpGraphDefToFile("after_generate_resource_shared_name_graph_def",
                       graph_def);
  }

  if (run_placer_grappler_on_functions) {
    return FindFunctionsToOptimize(graph_def);
  }
  return absl::flat_hash_set<std::string>();
}

}  // namespace

StatusOr<std::unique_ptr<TfrtGraphExecutionState>>
TfrtGraphExecutionState::Create(const TfrtGraphExecutionState::Options& options,
                                tensorflow::GraphDef graph_def,
                                const FallbackState& fallback_state) {
  TF_ASSIGN_OR_RETURN(
      auto functions_to_optimize,
      PreprocessGraph(graph_def, options.run_placer_grappler_on_functions));

  // `CreateGraphExecutionState()` will preprocess the graph (e.g., apply
  // Placer to the top level graph).
  TF_ASSIGN_OR_RETURN(
      auto graph_execution_state,
      fallback_state.CreateGraphExecutionState(std::move(graph_def)));

  return std::make_unique<TfrtGraphExecutionState>(
      options, std::move(graph_execution_state), fallback_state,
      std::move(functions_to_optimize));
}

namespace {

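// Fills `callable_options` with the given feed, fetch, and target tensor
// names, and returns a copy of the populated proto.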
CallableOptions PopulateCallableOptions(
    CallableOptions& callable_options,
    absl::Span<const std::string> feed_tensor_names,
    absl::Span<const std::string> fetch_tensor_names,
    absl::Span<const std::string> target_tensor_names) {
  // Configure pruning with the feed/fetch/target tensor names.
  callable_options.mutable_feed()->Reserve(feed_tensor_names.size());
  for (const auto& feed : feed_tensor_names) {
    callable_options.add_feed(feed);
  }
  callable_options.mutable_fetch()->Reserve(fetch_tensor_names.size());
  for (const auto& fetch : fetch_tensor_names) {
    callable_options.add_fetch(fetch);
  }
  callable_options.mutable_target()->Reserve(target_tensor_names.size());
  for (const auto& target : target_tensor_names) {
    callable_options.add_target(target);
  }

  return callable_options;
}

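// Serializes `graph` into a GraphDef and attaches `flib_def` as its function
// library.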
tensorflow::GraphDef CreateGraphDefFromGraphAndFlibDef(
    const tensorflow::Graph& graph,
    const tensorflow::FunctionLibraryDefinition& flib_def) {
  tensorflow::GraphDef graph_def;
  graph.ToGraphDef(&graph_def);
  *graph_def.mutable_library() = flib_def.ToProto();
  return graph_def;
}

// Creates a pruned graph from `graph_def` according to `callable_options`.
StatusOr<std::unique_ptr<tensorflow::Graph>> CreatePrunedGraph(
    tensorflow::GraphDef graph_def, const CallableOptions& callable_options) {
  VLOG(1) << "Creating pruned graph: " << callable_options.DebugString();

  // Prune the graph with `callable_options`. Although grappler has a
  // model_pruner stage, it may leave v1 control flow in an invalid state that
  // cannot be functionalized, so we perform additional pruning before
  // functionalization.
  TF_RETURN_IF_ERROR(PruneGraphDef(graph_def, callable_options));

  if (VLOG_IS_ON(2)) {
    DumpGraphDefToFile("before_eliminate_ref_variables_graph_def", graph_def);
  }

  // Ref variables in v1 control flow prevent the graph from being
  // functionalized, so we eliminate them first.
  TF_RETURN_IF_ERROR(EliminateRefVariablesFromV1ControlFlow(graph_def));

  // The "_input_shapes" attributes will not be correct after grappler's
  // function optimizer runs, so we need to remove them. Note that
  // "_input_shapes" is not used except as a debug hint (somehow this debug
  // hint is used by the MLIR graphdef importer, which is not expected).
  RemoveInputShapesInFunctions(graph_def);

  auto pruned_graph =
      std::make_unique<tensorflow::Graph>(tensorflow::OpRegistry::Global());
  tensorflow::GraphConstructorOptions options;
  options.allow_internal_ops = true;
  options.add_default_attributes = true;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(options, std::move(graph_def),
                                            pruned_graph.get()));
  return pruned_graph;
}

// Creates a new identity node to replace an operand of a given `node`.
NodeDef CreateNewIdentityNode(const NodeDef& node,
                              const std::string& input_name,
                              const std::string& identity_name) {
  NodeDef identity;
  identity.set_name(identity_name);
  identity.set_op("Identity");
  identity.add_input(input_name);
  identity.set_device(node.device());
  for (const auto& name_and_attr : node.attr()) {
    if (name_and_attr.first == "T") {
      identity.mutable_attr()->insert(name_and_attr);
      break;
    }
  }
  return identity;
}

// Inlines functions into the top level graph.
Status InlineFunctions(std::unique_ptr<Graph>* graph,
                       const DeviceSet* device_set) {
  GraphOptimizationPassOptions optimization_options;
  SessionOptions session_options;
  // We don't lower v2 control flow to v1 for now.
  session_options.config.mutable_experimental()->set_use_tfrt(true);
  session_options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(true);
  optimization_options.session_options = &session_options;
  optimization_options.graph = graph;
  optimization_options.flib_def = (*graph)->mutable_flib_def();
  optimization_options.device_set = device_set;
  optimization_options.is_function_graph = false;

  LowerFunctionalOpsPass pass;
  return pass.Run(optimization_options);
}

// Assigns input/output nodes to the host.
Status PlaceInputOutputNodesOnHost(const std::vector<std::string>& inputs,
                                   const std::vector<std::string>& outputs,
                                   const Device* cpu_device, Graph* graph) {
  std::unordered_map<std::string, Node*> name_to_node_map =
      graph->BuildNodeNameIndex();
  for (const auto& input : inputs) {
    name_to_node_map.at(grappler::NodeName(input))
        ->set_assigned_device_name(cpu_device->name());
  }

  // Collect all output nodes.
  absl::flat_hash_set<Node*> output_nodes;
  for (const auto& output : outputs) {
    output_nodes.insert(name_to_node_map.at(grappler::NodeName(output)));
  }
  for (const auto& output_node : output_nodes) {
    // Append an IdentityN node to the original output node if it is not
    // assigned to the host.
    if (!output_node->IsIdentity() &&
        output_node->type_string() != "IdentityN" &&
        output_node->assigned_device_name() != cpu_device->name()) {
      // Rename the original output node.
      std::string output_node_name = output_node->name();
      output_node->set_name(output_node_name + "/tfrt_renamed");

      // Append an IdentityN node with the original output node name.
      std::vector<NodeBuilder::NodeOut> output_tensors;
      output_tensors.reserve(output_node->num_outputs());
      for (int i = 0; i < output_node->num_outputs(); i++) {
        output_tensors.push_back(NodeBuilder::NodeOut(output_node, i));
      }
      TF_RETURN_IF_ERROR(NodeBuilder(output_node_name, "IdentityN")
                             .AssignedDevice(cpu_device->name())
                             .Input(output_tensors)
                             .Finalize(graph, /*created_node=*/nullptr));
    } else {
      output_node->set_assigned_device_name(cpu_device->name());
    }
  }
  return OkStatus();
}

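// Adjusts device assignment after placement: pins v2 control flow nodes (While
// and If) as well as input/output nodes to the host CPU.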
Status AdjustDeviceAssignment(const std::vector<std::string>& inputs,
                              const std::vector<std::string>& outputs,
                              const std::vector<std::string>& control_outputs,
                              const Device* cpu_device, Graph* graph) {
  // TODO(b/232299232): We don't inline and partition v2 control flow
  // currently. All ops within control flow are placed on CPU for now. Figure
  // out a better way to handle v2 control flow.
  for (Node* node : graph->op_nodes()) {
    if (node->IsWhileNode() || node->IsIfNode()) {
      LOG(WARNING) << "The control flow node " << node->name()
                   << " is placed on CPU.";
      node->set_assigned_device_name(cpu_device->name());
    }
  }

  TF_RETURN_IF_ERROR(
      PlaceInputOutputNodesOnHost(inputs, outputs, cpu_device, graph));
  return OkStatus();
}

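// Returns true if `graph`, or any function in its function library, contains a
// TPU-related op.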
bool IsTpuGraph(const Graph* graph) {
  static const auto* const kTpuOps = new absl::flat_hash_set<std::string>{
      "TPUPartitionedCall", "TPUCompile", "TPUReplicateMetadata"};
  for (const Node* node : graph->nodes()) {
    if (kTpuOps->contains(node->type_string())) {
      return true;
    }
  }
  for (const std::string& func_name : graph->flib_def().ListFunctionNames()) {
    const FunctionDef* func_def = graph->flib_def().Find(func_name);
    for (const NodeDef& node_def : func_def->node_def()) {
      if (kTpuOps->contains(node_def.op())) return true;
    }
  }
  return false;
}

// Adds Send/Recv ops to `graph` for data transfer, if ops are run on different
// devices. Returns a new graph with the added Send/Recv ops.
// This is done by partitioning `graph` and adding Send/Recv ops on the edges
// across devices.
StatusOr<std::unique_ptr<Graph>> BuildXlaOpsAndMaybeInsertTransferOps(
    const std::string& graph_func_name, const FallbackState& fallback_state,
    const std::vector<std::string>& inputs,
    const std::vector<std::string>& outputs,
    const std::vector<std::string>& control_outputs,
    std::unique_ptr<Graph> graph) {
  // Skip inserting transfer ops if this is a TPU graph.
  // Our stack currently cannot run the old bridge on TPU graphs, as it will
  // generate ops that are not supported by the subsequent MLIR passes.
  // In the case where TPU related ops are not wrapped in TPUPartitionedCall,
  // running placer and partitioning on such graphs will fail. So we skip TPU
  // graphs for now.
  // TODO(b/228510957): In the long term, we will want a unified way for data
  // transfer, i.e., using Send/Recv ops for data transfer for TPU as well.
  if (IsTpuGraph(graph.get())) {
    return graph;
  }

  // Inline functions to facilitate partitioning nodes in the functions.
  TF_RETURN_IF_ERROR(InlineFunctions(&graph, &fallback_state.device_set()));
  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("after_inlining", *graph);
  }

  // Replace StatefulPartitionedCall ops that should be XLA-compiled with
  // XlaLaunch ops.
  // TODO(b/239089915): Clean this up after the logic is implemented in TFXLA
  // bridge.
  TF_RETURN_IF_ERROR(BuildXlaLaunchOps(graph.get()));
  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("after_build_xla_launch", *graph);
  }

  // Run placer.
  const Device* cpu_device = fallback_state.device_manager().HostCPU();
  if (cpu_device == nullptr) {
    return errors::Internal("No CPU device found.");
  }
  Placer placer(graph.get(), /*function_name=*/"", &graph->flib_def(),
                &fallback_state.device_set(), cpu_device,
                /*allow_soft_placement=*/true,
                /*log_device_placement=*/false);
  TF_RETURN_IF_ERROR(placer.Run());
  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("after_placer", *graph);
  }

  TF_RETURN_IF_ERROR(AdjustDeviceAssignment(inputs, outputs, control_outputs,
                                            cpu_device, graph.get()));

  // Insert Send/Recv ops into the graph.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Graph> new_graph,
      InsertTransferOps(graph_func_name, fallback_state.device_set(),
                        cpu_device, inputs, outputs, control_outputs,
                        std::move(graph)));
  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("after_transfer_ops_insertion", *new_graph);
  }

  return new_graph;
}

}  // namespace

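// Builds the optimized graph for the given import config: the graph is pruned
// with the requested feeds/fetches/targets, functionalized (v1 -> v2 control
// flow), optimized with Grappler, and, if TFRT GPU is enabled, further
// processed to build XlaLaunch ops and insert transfer ops.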
StatusOr<TfrtGraphExecutionState::OptimizationResult>
TfrtGraphExecutionState::CreateOptimizedGraph(
    tensorflow::GraphImportConfig& graph_import_config) {
  OptimizationResult result;

  tensorflow::BuildGraphOptions build_graph_options;

  std::vector<std::string> inputs;
  inputs.reserve(graph_import_config.inputs.size());
  for (const auto& input : graph_import_config.inputs) {
    inputs.push_back(input.first);
  }
  PopulateCallableOptions(build_graph_options.callable_options, inputs,
                          graph_import_config.outputs,
                          graph_import_config.control_outputs);

  auto graph_def = CreateGraphDefFromGraphAndFlibDef(graph(), flib_def());

  if (VLOG_IS_ON(1)) {
    DumpGraphDefToFile("before_pruning", graph_def);
  }

  TF_ASSIGN_OR_RETURN(
      result.graph,
      CreatePrunedGraph(graph_def, build_graph_options.callable_options));
  DCHECK(result.graph);

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("after_pruning", *result.graph);
  }

  const auto functionalization_start_time = absl::Now();

  // Perform functionalization to convert v1 control flow to v2 control flow.
  // It should be applied to the unoptimized graph, because Grappler may make
  // the graph unfunctionalizable.
  TF_RETURN_IF_ERROR(tensorflow::UpgradeLegacyGraph(
      result.graph.get(),
      const_cast<tensorflow::FunctionLibraryDefinition*>(
          &result.graph->flib_def()),
      /*restrict_functionalization_to_compiled_nodes=*/false));

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("after_functionalization", *result.graph);
  }

  auto grappler_start_time = absl::Now();
  result.functionalization_duration =
      grappler_start_time - functionalization_start_time;

  auto status_or_optimized_graph =
      OptimizeGraph(*result.graph, build_graph_options);
  if (status_or_optimized_graph.ok()) {
    result.graph = std::move(status_or_optimized_graph.ValueOrDie());
  } else {
    LOG(WARNING) << "TFRT failed to optimize graph: "
                 << status_or_optimized_graph.status();
  }

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("after_grappler", *result.graph);
  }

  result.grappler_duration = absl::Now() - grappler_start_time;

  if (options_.enable_tfrt_gpu) {
    TF_ASSIGN_OR_RETURN(
        result.graph,
        BuildXlaOpsAndMaybeInsertTransferOps(
            graph_import_config.graph_func_name, fallback_state_, inputs,
            graph_import_config.outputs, graph_import_config.control_outputs,
            std::move(result.graph)));

    // Update `control_outputs` as there might be newly added Send ops.
    for (const Node* node : result.graph->nodes()) {
      if (node->IsSend()) {
        graph_import_config.control_outputs.push_back(node->name());
      }
    }
  }

  return result;
}

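// Extends the underlying GraphExecutionState with `graph` and refreshes the
// set of functions to optimize from the extended graph def.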
Status TfrtGraphExecutionState::Extend(const GraphDef& graph) {
  std::unique_ptr<GraphExecutionState> new_state;
  absl::MutexLock lock(&graph_execution_state_mu_);
  TF_RETURN_IF_ERROR(graph_execution_state_->Extend(graph, &new_state));
  graph_execution_state_.swap(new_state);

  auto* graph_def = graph_execution_state_->original_graph_def();
  DCHECK_NE(graph_def, nullptr);
  TF_ASSIGN_OR_RETURN(
      functions_to_optimize_,
      PreprocessGraph(*graph_def, options_.run_placer_grappler_on_functions));

  return OkStatus();
}

namespace {

// Given an "Exit" node, finds its corresponding "LoopCond" node.
StatusOr<const NodeDef*> FindLoopCondFromExitNode(
    const NodeDef& exit_node,
    const absl::flat_hash_map<std::string, NodeDef*>& name_to_node) {
  const NodeDef* switch_node = nullptr;
  for (const std::string& tensor_name : exit_node.input()) {
    const std::string node_name = grappler::NodeName(tensor_name);
    if (!name_to_node.contains(node_name)) {
      return errors::InvalidArgument("Graph does not contain input ", node_name,
                                     " of exit node ", exit_node.name());
    }
    const NodeDef* node = name_to_node.at(node_name);
    if (node->op() == "Switch") {
      switch_node = node;
      break;
    }
  }
  if (switch_node == nullptr) {
    return errors::InvalidArgument("Exit node ", exit_node.name(),
                                   " does not have a Switch node as its ",
                                   "predecessor.");
  }
  for (const std::string& tensor_name : switch_node->input()) {
    const std::string node_name = grappler::NodeName(tensor_name);
    if (!name_to_node.contains(node_name)) {
      return errors::InvalidArgument("Graph does not contain input ", node_name,
                                     " of switch node ", switch_node->name());
    }

    const NodeDef* node = name_to_node.at(node_name);
    if (node->op() == "LoopCond") {
      return node;
    }
  }

  return errors::InvalidArgument("Switch node ", switch_node->name(),
                                 " does not have a LoopCond node as its ",
                                 "predecessor.");
}

}  // namespace

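// Prunes `graph_def` so that only nodes reachable from the feed, fetch, and
// target nodes in `callable_options` are kept. While loops are kept intact by
// also traversing from each reachable LoopCond node to its Exit nodes, and
// fetched Exit nodes are rewritten behind an Identity node so that
// functionalization does not remove them.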
Status PruneGraphDef(GraphDef& graph_def,
                     const CallableOptions& callable_options) {
  // Gather node names and create a map from names to NodeDefs.
  absl::flat_hash_map<std::string, NodeDef*> name_to_node;
  // All exit nodes in order to track all while loops.
  absl::flat_hash_set<const NodeDef*> exit_nodes;
  for (auto& node : *graph_def.mutable_node()) {
    name_to_node[node.name()] = &node;
    if (node.op() == "Exit") {
      exit_nodes.insert(&node);
    }

    // TODO(tfrt-devs): Add support for _Send and _Recv ops.
    if (node.op() == "_Send" || node.op() == "_Recv") {
      return errors::InvalidArgument(
          "TFRT prune graphdef cannot handle graphs that contain _Send and "
          "_Recv ops.");
    }
  }

  // Find the mapping from each LoopCond node to its Exit nodes, so that when
  // we traverse to a LoopCond node, we can add the corresponding Exit nodes to
  // the traversal queue in order to maintain the complete structure of a while
  // loop.
  absl::flat_hash_map<const NodeDef*, absl::flat_hash_set<const NodeDef*>>
      loop_cond_to_exit_nodes;
  for (const NodeDef* exit_node : exit_nodes) {
    TF_ASSIGN_OR_RETURN(const NodeDef* loop_cond_node,
                        FindLoopCondFromExitNode(*exit_node, name_to_node));
    loop_cond_to_exit_nodes[loop_cond_node].insert(exit_node);
  }

  // `queue` is for candidate nodes we want to visit in the graph.
  std::vector<const NodeDef*> queue;

  // Add fetch nodes to the queue.
  absl::flat_hash_set<std::string> fetch_node_names;
  for (const std::string& tensor_name : callable_options.fetch()) {
    const NodeDef* node = name_to_node[grappler::NodeName(tensor_name)];
    if (!node) {
      return errors::InvalidArgument("Graph does not contain fetch node ",
                                     tensor_name, ".");
    }
    queue.push_back(node);
    fetch_node_names.insert(node->name());
  }

  // Add control target nodes to the queue.
  for (const std::string& tensor_name : callable_options.target()) {
    const NodeDef* node = name_to_node[grappler::NodeName(tensor_name)];
    if (!node) {
      return errors::InvalidArgument("Graph does not contain target node ",
                                     tensor_name, ".");
    }
    queue.push_back(node);
    fetch_node_names.insert(node->name());
  }

  absl::flat_hash_set<NodeDef*> feed_node_defs;

  // Add feed nodes to the queue. In addition, perform necessary rewrites to
  // remove unnecessary input edges.
  for (const std::string& tensor_name : callable_options.feed()) {
    NodeDef* node = name_to_node[grappler::NodeName(tensor_name)];
    if (!node) {
      return errors::InvalidArgument("Graph does not contain feed node ",
                                     tensor_name, ".");
    }

    // If a feed node is a Const, we don't need its inputs at all.
    //
    // TODO(tfrt-devs): Consider a general solution that we could just rewrite
    // all feed nodes to Placeholder nodes.
    if (node->op() == "Const") {
      node->clear_input();
    }

    queue.push_back(node);
    feed_node_defs.insert(node);
  }

  absl::flat_hash_set<const NodeDef*> visited;
  std::vector<NodeDef> keep;

  // Perform graph traversal to find out connected nodes from fetches.
  while (!queue.empty()) {
    const NodeDef* node = queue.back();
    queue.pop_back();

    if (!visited.insert(node).second) {
      continue;
    }

    keep.push_back(*node);
    if (node->op() == "LoopCond") {
      for (const NodeDef* exit_node : loop_cond_to_exit_nodes[node]) {
        queue.push_back(exit_node);
      }
    }

    for (const std::string& tensor_name : node->input()) {
      const NodeDef* in = name_to_node[grappler::NodeName(tensor_name)];
      if (!in) {
        return errors::InvalidArgument("Graph does not contain input ",
                                       grappler::NodeName(tensor_name),
                                       " of node ", node->name(), ".");
      }
      queue.push_back(in);
    }
  }

  graph_def.clear_node();
  for (auto& node : keep) {
    if (fetch_node_names.contains(node.name())) {
      // If the fetch node is an Exit op, we insert an Identity op right after
      // it and rename it to be the new fetch node. This is to prevent
      // functionalization from removing the fetch nodes.
      if (node.op() == "Exit") {
        auto renamed_exit_node = node;
        renamed_exit_node.set_name(
            absl::StrCat(renamed_exit_node.name(), "/tfrt_renamed"));
        node.set_op("Identity");
        *node.mutable_input(0) = renamed_exit_node.name();
        *graph_def.add_node() = std::move(renamed_exit_node);
      }
    }

    *graph_def.add_node() = std::move(node);
  }

  return OkStatus();
}

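// Rewrites "RefEnter"/"RefSwitch" nodes in `graph_def` into "Enter"/"Switch"
// nodes, inserting an Identity node between each such node and its ref input,
// so that v1 control flow no longer consumes ref tensors directly.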
Status EliminateRefVariablesFromV1ControlFlow(tensorflow::GraphDef& graph_def) {
  auto* op_factory = OpRegistry::Global();

  absl::flat_hash_set<std::string> ref_nodes;
  for (const auto& node : graph_def.node()) {
    if (node.op() == "RefEnter" || node.op() == "RefSwitch") {
      ref_nodes.insert(node.name());
    }
  }

  tensorflow::GraphDef updated_graph_def;
  absl::flat_hash_set<std::string> new_identities;
  // Insert an identity node between each "RefEnter" or "RefSwitch" node and
  // its ref input. Then modify each "RefEnter"/"RefSwitch" node in-place to an
  // "Enter"/"Switch" node.
  for (auto& node : *graph_def.mutable_node()) {
    // First find the ref input name to this RefEnter or RefSwitch.
    std::string* ref_input_name = nullptr;
    if (node.op() == "RefEnter") {
      node.set_op("Enter");
      if (node.input_size() != 1) {
        return errors::InvalidArgument("RefEnter node ", node.name(),
                                       " does not have exactly 1 input.");
      }
      ref_input_name = node.mutable_input(0);
    } else if (node.op() == "RefSwitch") {
      node.set_op("Switch");
      if (node.input_size() != 2) {
        return errors::InvalidArgument("RefSwitch node ", node.name(),
                                       " does not have exactly 2 inputs.");
      }
      ref_input_name = node.mutable_input(0);
    } else {
      // For other ops, check if their inputs are the ref ops we want to
      // eliminate, and if so, these ops must not require their inputs to be
      // refs.
      std::string ref_input;
      for (const auto& tensor_name : node.input()) {
        std::string input = grappler::NodeName(tensor_name);
        if (ref_nodes.contains(input)) {
          ref_input = std::move(input);
          break;
        }
      }
      if (!ref_input.empty()) {
        const OpDef* op_def;
        TF_RETURN_IF_ERROR(op_factory->LookUpOpDef(node.op(), &op_def));
        // TODO(tfrt-devs): How to match input_args to input names in NodeDef?
        for (const auto& input_arg : op_def->input_arg()) {
          if (input_arg.is_ref()) {
            return errors::Unimplemented(
                "Cannot in-place update ref node ", ref_input,
                " to the non-ref counterpart since its user node ", node.name(),
                " requires its input to be refs.");
          }
        }
      }
    }

    if (ref_input_name != nullptr) {
      std::string identity_name =
          absl::StrCat(grappler::NodeName(*ref_input_name), "/identity");
      if (!new_identities.contains(identity_name)) {
        *updated_graph_def.add_node() =
            CreateNewIdentityNode(node, *ref_input_name, identity_name);
        new_identities.insert(identity_name);
      }
      *ref_input_name = std::move(identity_name);
    }

    *updated_graph_def.add_node() = std::move(node);
  }

  graph_def.mutable_node()->Swap(updated_graph_def.mutable_node());
  return OkStatus();
}

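// Removes the "_input_shapes" attribute from every function in the function
// library of `graph_def`.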
void RemoveInputShapesInFunctions(tensorflow::GraphDef& graph_def) {
  for (tensorflow::FunctionDef& function_def :
       *graph_def.mutable_library()->mutable_function()) {
    function_def.mutable_attr()->erase("_input_shapes");
  }
}

namespace {

// Optimizes the functions in `flib_proto` (filtering with
// `functions_to_optimize`) using `flib` and `fallback_state`. Each
// function is converted to a graph and optimized with Placer and Grappler, then
// converted back to a function to replace the old one.
Status OptimizeFunctions(
    FunctionDefLibrary& flib_proto, const FunctionLibraryDefinition& flib,
    const FallbackState& fallback_state,
    const absl::flat_hash_set<std::string>& functions_to_optimize) {
  for (FunctionDef& fdef : *flib_proto.mutable_function()) {
    if (!functions_to_optimize.contains(fdef.signature().name())) {
      continue;
    }

    // Convert function to graph.
    std::unique_ptr<FunctionBody> fbody;
    TF_RETURN_IF_ERROR(
        FunctionDefToBodyHelper(fdef, AttrSlice(), &flib, &fbody));

    tensorflow::Graph* graph = fbody->graph;
    tensorflow::GraphDef graph_def;
    graph->ToGraphDef(&graph_def);
    // We need to manually add the flib because it's not added in
    // `FunctionDefToBodyHelper()`.
    *graph_def.mutable_library() = flib.ToProto();

    // `CreateGraphExecutionState()` will preprocess the graph (e.g., apply
    // Placer).
    TF_ASSIGN_OR_RETURN(
        auto graph_execution_state,
        fallback_state.CreateGraphExecutionState(std::move(graph_def)));

    // Invoke Grappler to optimize the graph.
    std::unique_ptr<tensorflow::Graph> optimized_graph;
    std::unique_ptr<tensorflow::FunctionLibraryDefinition> optimized_flib;
    tensorflow::BuildGraphOptions build_graph_options;
    std::vector<std::string> args;
    args.reserve(fbody->arg_nodes.size());
    for (const auto& arg : fbody->arg_nodes) args.push_back(arg->name());
    std::vector<std::string> rets;
    rets.reserve(fbody->ret_nodes.size());
    for (const auto& ret : fbody->ret_nodes) rets.push_back(ret->name());
    std::vector<std::string> control_rets;
    control_rets.reserve(fbody->control_ret_nodes.size());
    for (const auto& control_ret : fbody->control_ret_nodes) {
      control_rets.push_back(control_ret->name());
    }
    PopulateCallableOptions(build_graph_options.callable_options, args, rets,
                            control_rets);
    auto status = graph_execution_state->OptimizeGraph(
        build_graph_options, *graph_execution_state->full_graph(), &flib,
        &optimized_graph, &optimized_flib);

    if (!status.ok()) {
      LOG(ERROR) << "TFRT failed to optimize graph (converted from function: "
                 << fdef.signature().name() << "): " << status;
      continue;
    }

    TF_RETURN_IF_ERROR(
        optimized_graph->AddFunctionLibrary(optimized_flib->ToProto()));

    // Convert graph back to function.
    // We need to store the conversion result into a new `FunctionDef` first to
    // avoid errors.
    FunctionDef new_fdef;
    TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph,
                                          fdef.signature().name(), &new_fdef));

    fdef = std::move(new_fdef);
  }
  return OkStatus();
}

}  // namespace

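// Runs Grappler on `graph` via the underlying GraphExecutionState and, if
// `run_placer_grappler_on_functions` is enabled, also optimizes the eligible
// functions in the resulting function library.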
StatusOr<std::unique_ptr<tensorflow::Graph>>
TfrtGraphExecutionState::OptimizeGraph(
    const tensorflow::Graph& graph,
    const tensorflow::BuildGraphOptions& build_graph_options) {
  std::unique_ptr<tensorflow::Graph> optimized_graph;
  std::unique_ptr<tensorflow::FunctionLibraryDefinition> optimized_flib;

  {
    absl::MutexLock lock(&graph_execution_state_mu_);
    // Invoke Grappler to optimize the graph.
    TF_RETURN_IF_ERROR(graph_execution_state_->OptimizeGraph(
        build_graph_options, graph, &graph.flib_def(), &optimized_graph,
        &optimized_flib));
  }

  FunctionDefLibrary optimized_flib_proto = optimized_flib->ToProto();
  if (options_.run_placer_grappler_on_functions) {
    TF_RETURN_IF_ERROR(OptimizeFunctions(optimized_flib_proto, *optimized_flib,
                                         fallback_state_,
                                         functions_to_optimize_));
    // Any optimized function is altered but still has the previous name. To
    // avoid errors when adding the optimized flib, we should clear the current
    // flib first.
    optimized_graph->mutable_flib_def()->Clear();
  }

  TF_RETURN_IF_ERROR(optimized_graph->AddFunctionLibrary(optimized_flib_proto));

  return optimized_graph;
}

// TODO(b/239089915): Clean this up after the logic is implemented in TFXLA
// bridge.
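// Rewrites (Stateful)PartitionedCall nodes that carry the `kXlaMustCompileAttr`
// attribute into XlaLaunch ops via EncapsulateXlaComputationsPass.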
Status BuildXlaLaunchOps(Graph* graph) {
  const auto is_xla_launch_node = [](const Node& n) -> StatusOr<bool> {
    if (!n.IsPartitionedCall()) {
      return false;
    }
    bool xla_must_compile = false;
    const bool has_attribute =
        TryGetNodeAttr(n.attrs(), kXlaMustCompileAttr, &xla_must_compile);
    return has_attribute && xla_must_compile;
  };

  const auto get_xla_function_info = [](const Node& launch)
      -> StatusOr<EncapsulateXlaComputationsPass::XlaFunctionInfo> {
    EncapsulateXlaComputationsPass::XlaFunctionInfo result;
    std::vector<DataType> tin_dtypes;
    TF_RETURN_IF_ERROR(GetNodeAttr(launch.def(), "Tin", &tin_dtypes));
    int variable_start_index = 0;
    for (; variable_start_index < tin_dtypes.size(); ++variable_start_index) {
      if (tin_dtypes.at(variable_start_index) == DT_RESOURCE) break;
    }
    result.variable_start_index = variable_start_index;

    NameAttrList func;
    TF_RETURN_IF_ERROR(GetNodeAttr(launch.attrs(), "f", &func));
    result.function_name = func.name();

    return result;
  };

  return EncapsulateXlaComputationsPass::BuildXlaLaunchOps(
      graph, is_xla_launch_node, get_xla_function_info,
      /*add_edges_to_output_of_downstream_nodes=*/false);
}

}  // namespace tfrt_stub
}  // namespace tensorflow