/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/functionalize_cond.h"

#include <algorithm>
#include <deque>
#include <stack>
#include <unordered_set>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {
namespace functionalize_cond {

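// Orders AncestorNodes lexicographically by (node id, output index, type).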
bool AncestorNode::operator<(const AncestorNode& other) const {
  return (output_tensor.node->id() < other.output_tensor.node->id()) ||
         (output_tensor.node->id() == other.output_tensor.node->id() &&
          output_tensor.index < other.output_tensor.index) ||
         (output_tensor.node->id() == other.output_tensor.node->id() &&
          output_tensor.index == other.output_tensor.index &&
          type < other.type);
}

bool AncestorNode::operator==(const AncestorNode& other) const {
  return output_tensor.node->id() == other.output_tensor.node->id() &&
         output_tensor.index == other.output_tensor.index &&
         type == other.type;
}

size_t AncestorNode::Hash::operator()(const AncestorNode& ancestor) const {
  size_t h = std::hash<int>()(ancestor.output_tensor.node->id());
  h = Hash64Combine(h, std::hash<int>()(ancestor.output_tensor.index));
  return Hash64Combine(h, std::hash<int>()(static_cast<int>(ancestor.type)));
}

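// Merge nodes that share the same cond state, ancestor state, and predicate
// can be grouped into the same If; ClusterTuple serves as that grouping key,
// and ClusterTupleLessThan orders tuples lexicographically (with predicates
// compared via StateMap::OutputTensorLess).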
typedef std::tuple<StateMap::CondId, StateMap::AncestorId, OutputTensor>
    ClusterTuple;

struct ClusterTupleLessThan {
  bool operator()(const ClusterTuple& a, const ClusterTuple& b) const {
    if (std::tie(std::get<0>(a), std::get<1>(a)) <
        std::tie(std::get<0>(b), std::get<1>(b))) {
      return true;
    } else if (std::tie(std::get<0>(a), std::get<1>(a)) ==
               std::tie(std::get<0>(b), std::get<1>(b))) {
      return StateMap::OutputTensorLess()(std::get<2>(a), std::get<2>(b));
    } else {
      return false;
    }
  }
};

// TODO(jpienaar): Move to OutputTensor.
string DebugString(const OutputTensor& tensor) {
  return absl::StrCat(tensor.node->name(), ":", tensor.index);
}

string Branch_Name(BranchType b) {
  switch (b) {
    case BranchType::kElseBranch:
      return "else";
    case BranchType::kThenBranch:
      return "then";
    case BranchType::kBoth:
      return "both";
    case BranchType::kNeither:
      return "neither";
  }
}

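// Prints a CondState as, e.g., "{s(pred:0,then)}": each s(p,b) entry records
// that predicate tensor p must take branch b, and "d" marks the dead state.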
string DebugString(StateMap::CondId cond_state) {
  if (cond_state == nullptr || cond_state->empty()) return "{}";
  using value_type = StateMap::CondState::value_type;
  return absl::StrCat(
      "{",
      absl::StrJoin(*cond_state, ", ",
                    [](string* output, const value_type& pred_branch) {
                      const OutputTensor& pred = pred_branch.first;
                      const BranchType& branch = pred_branch.second;
                      if (branch == BranchType::kNeither)
                        absl::StrAppend(output, "d");
                      else
                        absl::StrAppend(output, "s(", DebugString(pred), ",",
                                        Branch_Name(branch), ")");
                    }),
      "}");
}

// Returns the predicate of a switch.
Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) {
  const Edge* pred_edge;
  TF_RETURN_IF_ERROR(switch_node.input_edge(1, &pred_edge));
  // The predicate can be preceded by Identity nodes. Look through them to
  // find the predicate.
  while (pred_edge->src()->IsIdentity()) {
    TF_RETURN_IF_ERROR(pred_edge->src()->input_edge(0, &pred_edge));
  }
  *pred = OutputTensor(pred_edge->src(), pred_edge->src_output());
  return OkStatus();
}

// Returns the value input of a switch.
Status GetSwitchValue(const Node& switch_node, OutputTensor* val) {
  const Edge* val_edge;
  TF_RETURN_IF_ERROR(switch_node.input_edge(0, &val_edge));
  *val = OutputTensor(val_edge->src(), val_edge->src_output());
  return OkStatus();
}

bool StateMap::OutputTensorLess::operator()(const OutputTensor& lhs,
                                            const OutputTensor& rhs) const {
  return (lhs.node->id() < rhs.node->id()) ||
         (lhs.node->id() == rhs.node->id() && lhs.index < rhs.index);
}

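// Orders CondState entries by predicate tensor (via OutputTensorLess), then
// by branch.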
struct CondStateLess {
  bool operator()(const StateMap::CondState::value_type& lhs,
                  const StateMap::CondState::value_type& rhs) const {
    if (StateMap::OutputTensorLess().operator()(lhs.first, rhs.first))
      return true;
    if (lhs.first.node->id() == rhs.first.node->id() &&
        lhs.first.index == rhs.first.index)
      return lhs.second < rhs.second;
    return false;
  }
};

StateMap::StateMap(Graph* graph) {
  node_to_condid_map_.resize(graph->num_node_ids());
  node_to_ancestorid_map_.resize(graph->num_node_ids());
  // Initialize the dead state (the empty state is designated with a nullptr).
  dead_id_ = GetCondId(
      {std::make_pair(OutputTensor(nullptr, -1), BranchType::kNeither)});
}

bool StateMap::IsDead(StateMap::CondId id) const { return id == dead_id_; }

bool StateMap::IsEmpty(StateMap::CondId id) const { return id == nullptr; }

size_t StateMap::Hash::operator()(const StateMap::CondState& map) const {
  if (map.empty()) return 0;
  // Compute the hash of the front element.
  auto it = map.begin();
  size_t h = Hash64Combine(OutputTensor::Hash()(it->first),
                           hash<BranchType>()(it->second));
  for (++it; it != map.end(); ++it) {
    // Combine the hash with the remaining elements in the map.
    h = Hash64Combine(h, Hash64Combine(OutputTensor::Hash()(it->first),
                                       hash<BranchType>()(it->second)));
  }
  return h;
}

size_t StateMap::Hash::operator()(const StateMap::AncestorState& map) const {
  if (map.empty()) return 0;
  // Compute the hash of the front element.
  auto it = map.begin();
  size_t h = AncestorNode::Hash()(*it);
  for (++it; it != map.end(); ++it) {
    // Combine the hash with the remaining elements in the map.
    h = Hash64Combine(h, AncestorNode::Hash()(*it));
  }
  return h;
}

// CondArgNode represents an input to the conditional and its corresponding
// switch nodes.
struct CondArgNode {
  explicit CondArgNode(Node* src, int src_output)
      : src(src), src_output(src_output) {}

  string ToString() const {
    return absl::StrCat("src=", src->name(), ":", src_output,
                        " switches=", NodesToString(switches));
  }

  Node* src;
  int src_output;
  std::array<Node*, 2> branch_copy;
  std::vector<Node*> switches;
};
using CondArgNodes = std::vector<CondArgNode>;

string DebugString(const CondArgNodes& nodes) {
  return absl::StrCat(
      "[",
      absl::StrJoin(nodes, ", ",
                    [](string* output, const CondArgNode& node) {
                      absl::StrAppend(output, node.ToString());
                    }),
      "]");
}

StateMap::CondId StateMap::LookupCondId(const Node* node) const {
  const int64_t map_size = node_to_condid_map_.size();
  if (node->id() < map_size) return node_to_condid_map_[node->id()];
  return added_node_condid_mapping_.at(node->id());
}

StateMap::CondId StateMap::GetCondId(const StateMap::CondState& state) {
  if (state.empty()) return nullptr;
  return &*condstate_set_.insert(state).first;
}

void StateMap::ResetCondId(const Node* node, StateMap::CondId id) {
  const int64_t map_size = node_to_condid_map_.size();
  if (node->id() < map_size)
    node_to_condid_map_[node->id()] = id;
  else
    added_node_condid_mapping_[node->id()] = id;
}

StateMap::AncestorId StateMap::LookupAncestorId(const Node* node) const {
  const int64_t map_size = node_to_ancestorid_map_.size();
  if (node->id() < map_size) return node_to_ancestorid_map_[node->id()];
  return added_node_ancestorid_mapping_.at(node->id());
}

StateMap::AncestorId StateMap::GetAncestorId(
    const StateMap::AncestorState& state) {
  if (state.empty()) return nullptr;
  return &*ancestorstate_set_.insert(state).first;
}

void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) {
  const int64_t map_size = node_to_ancestorid_map_.size();
  if (node->id() < map_size)
    node_to_ancestorid_map_[node->id()] = id;
  else
    added_node_ancestorid_mapping_[node->id()] = id;
}

void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); }

string StateMap::CondStateToString(const Node* node) const {
  return CondStateToString(LookupCondId(node));
}

string StateMap::CondStateToString(StateMap::CondId id) const {
  return DebugString(id);
}

string StateMap::AncestorStateToString(const Node* node) const {
  if (auto id = LookupAncestorId(node)) {
    return absl::StrCat(
        "{",
        absl::StrJoin(*id, ",",
                      [](string* output, const AncestorNode& ancestor) {
                        absl::StrAppend(output,
                                        ancestor.output_tensor.node->name(),
                                        ":", ancestor.output_tensor.index);
                      }),
        "}");
  }
  return "{}";
}

FunctionalizeCond::FunctionalizeCond(Graph* graph,
                                     FunctionLibraryDefinition* library,
                                     const NodeFilter& node_filter)
    : state_map_(graph),
      library_(library),
      graph_(graph),
      node_filter_(node_filter) {}

// Class representing the merge/switch nodes that will become a conditional.
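// A Conditional is populated incrementally (via AddMerge/AddSwitch) as merge
// and switch nodes sharing a predicate are discovered; BuildAndReplace then
// extracts the branch bodies, builds the If node, and rewires the graph.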
class Conditional {
 public:
  Conditional(OutputTensor predicate, FunctionalizeCond* parent,
              StateMap* cond_state_map, const ShapeRefiner& refiner);

  // Adds a merge node that is part of this conditional.
  Status AddMerge(Node* m);

  // Constructs an If node from the merge nodes.
  Status BuildAndReplace(
      Graph* graph, FunctionLibraryDefinition* library,
      std::unordered_map<Node*, OutputTensor>* merge_to_replacement);

 private:
  // Extracts the then/else bodies: creates new graphs containing the nodes of
  // the then/else branches of this conditional, to be used as function bodies.
  Status ExtractBodies(Graph* graph);

  // Builds the arguments that are the input to the If.
  Status BuildArgumentNodes();

  // Builds the If node for the extracted bodies with the given predicate.
  Status BuildIfNode(Graph* graph, FunctionLibraryDefinition* library);

  // Adds input edges to the If node.
  Status AddInputEdges(
      Graph* graph,
      const std::unordered_map<Node*, OutputTensor>& merge_to_replacement);

  // Adds output edges from the If node and records the new output tensor for
  // all Merge nodes in 'merge_to_replacement'.
  Status AddOutputEdges(
      Graph* graph,
      std::unordered_map<Node*, OutputTensor>* merge_to_replacement);

  // Adds a switch node that is part of this conditional.
  Status AddSwitch(Node* s);

  // Adds a switch node along the edge and rewires the edge to go via the
  // switch.
  Status AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
                                Graph* graph);

  // Internal name of the conditional. The name is based on the first merge
  // node added.
  string name() const;

  // The FunctionalizeCond instance that created this.
  FunctionalizeCond* parent_;

  // Mapping between nodes and their cond state.
  StateMap* state_map_;

  // The predicate of the conditional.
  OutputTensor predicate_;

  // Shape refiner of ops in the graph.
  const ShapeRefiner& refiner_;

  // The predicate of the switches of the conditional. This may be different
  // from predicate_ (which is initialized from the original graph) as the
  // predicate could be the output of a newly created If node.
  OutputTensor switch_predicate_;

  // Switch nodes in graph that are part of this conditional.
  std::set<Node*, NodeCmpByNameResourcesLast> switches_;

  // Merge nodes in graph that are part of this conditional.
  std::set<Node*, NodeCmpByNameResourcesLast> merges_;

  // Control inputs from outside the conditional to a node inside it, and
  // control outputs from a node inside it to outside.
  std::vector<Node*> external_control_inputs_;
  std::vector<Node*> external_control_outputs_;

  // Graphs corresponding to the then and else branch.
  std::array<std::unique_ptr<Graph>, 2> bodies_;

  // Maps node ids in graph_ to the corresponding copied nodes in each branch
  // body.
  std::array<std::vector<Node*>, 2> node_maps_;

  // The argument nodes created for the switches.
  CondArgNodes cond_arg_nodes_;

  // The constructed If node.
  Node* if_node_ = nullptr;

  // Whether the merge nodes of this conditional have been replaced.
  bool replaced_ = false;
};

Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent,
                         StateMap* cond_state_map, const ShapeRefiner& refiner)
    : parent_(parent),
      state_map_(cond_state_map),
      predicate_(predicate),
      refiner_(refiner) {}

Status Conditional::AddMerge(Node* m) {
  merges_.insert(m);
  return OkStatus();
}

Status Conditional::AddSwitch(Node* s) {
  VLOG(5) << "Adding switch " << s->DebugString();
  OutputTensor predicate;
  TF_RETURN_IF_ERROR(GetSwitchPredicate(*s, &predicate));
  if (switch_predicate_.node == nullptr) switch_predicate_ = predicate;
  if (!(switch_predicate_ == predicate)) {
    return errors::InvalidArgument(
        "Merge nodes ", NodesToString(merges_),
        " directly dominated by switch nodes with different predicates (",
        DebugString(switch_predicate_), " vs ", DebugString(predicate), ").");
  }
  switches_.insert(s);
  parent_->AddSwitchId(s->id());
  return OkStatus();
}

Status Conditional::BuildArgumentNodes() {
  VLOG(1) << "Build function arguments";
  struct Hash {
    size_t operator()(const std::pair<Node*, int>& item) const {
      return Hash64Combine(hash<Node*>()(item.first),
                           std::hash<int>()(item.second));
    }
  };

  std::unordered_map<std::pair<Node*, int>, int, Hash> input_index;
  for (Node* switch_node : switches_) {
    const Edge* e;
    TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e));
    std::pair<Node*, int> key = std::make_pair(e->src(), e->src_output());
    if (input_index.find(key) == input_index.end()) {
      input_index[key] = cond_arg_nodes_.size();
      cond_arg_nodes_.emplace_back(key.first, key.second);
    }
    cond_arg_nodes_.at(input_index.at(key)).switches.push_back(switch_node);
  }
  VLOG(5) << "CondArg nodes created: " << DebugString(cond_arg_nodes_);

  int arg_count = 0;
  for (CondArgNode& cond_arg_node : cond_arg_nodes_) {
    DataType dtype = cond_arg_node.src->output_type(cond_arg_node.src_output);
    for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
      int branch_index = static_cast<int>(branch);
      TF_RETURN_IF_ERROR(
          NodeBuilder(absl::StrCat("_Arg", arg_count),
                      FunctionLibraryDefinition::kArgOp)
              .Attr("T", dtype)
              .Attr("index", arg_count)
              .Finalize(bodies_[branch_index].get(),
                        &cond_arg_node.branch_copy[branch_index]));
    }
    for (Node* node : cond_arg_node.switches) {
      for (const Edge* e : node->out_edges()) {
        if (e->IsControlEdge()) continue;
        int branch_index = e->src_output();
        Node* src_copy = cond_arg_node.branch_copy[branch_index];
        Node* dst_copy = node_maps_[branch_index][e->dst()->id()];

        // The graph may contain dead switch nodes, in which case the
        // destination node was never copied into the branch body; skip such
        // edges.
        if (dst_copy == nullptr) continue;

        // If the input goes directly to a merge then the merge has
        // been replaced by a retval so the dst input is 0 instead of
        // dst_input.
        int dst_input = IsMerge(e->dst()) ? 0 : e->dst_input();
        bodies_[branch_index]->AddEdge(src_copy, 0, dst_copy, dst_input);
      }
    }
    ++arg_count;
  }

  // Verify that all retvals have an input.
  // TODO(jpienaar): One could add a ZerosLike in the branch that doesn't have
  // an input.
  for (Node* m : merges_) {
    for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
      bool has_input = false;
      for (auto e : node_maps_[static_cast<int>(branch)][m->id()]->in_edges()) {
        if (!e->IsControlEdge()) {
          has_input = true;
          break;
        }
      }
      if (!has_input) {
        return errors::Internal(
            "Failed to functionalize control flow with merge ",
            FormatNodeForError(*m), " that doesn't have an input on the ",
            Branch_Name(branch), " branch.");
      }
    }
  }

  return OkStatus();
}

Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
                                           Graph* graph) {
  // Previously we had the edge:
  //   src:src_output ---- edge ----> dst:dst_input
  // After this we have (in graph):
  //   src:src_output --> switch<pred> --- new_edge --> dst:dst_input

  // TODO(jpienaar): One could keep a map caching the extra switch nodes added
  // to avoid adding another switch to feed a value for which a switch was
  // already added.
  Node* switch_node;
  Node* src = edge->src();
  int src_output = edge->src_output();
  TF_RETURN_IF_ERROR(
      NodeBuilder(graph->NewName(absl::StrCat(src->name(), "_added_switch")),
                  "Switch")
          .Input(src, src_output)
          .Input(const_cast<Node*>(predicate_.node), predicate_.index)
          .Finalize(graph, &switch_node));
  state_map_->ResetCondId(switch_node, state_map_->LookupCondId(src));
  state_map_->ResetAncestorId(switch_node, state_map_->LookupAncestorId(src));

  Node* dst = edge->dst();
  int dst_input = edge->dst_input();
  graph->RemoveEdge(edge);
  graph->AddEdge(switch_node, static_cast<int>(branch), dst, dst_input);
  return AddSwitch(switch_node);
}

Status Conditional::ExtractBodies(Graph* graph) {
  VLOG(2) << "Extracting bodies for " << name();
  for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) {
    bodies_[static_cast<int>(b)] =
        std::make_unique<Graph>(graph->op_registry());
  }

  auto find_branch = [&](const Edge* e) {
    const auto& id = state_map_->LookupCondId(e->src());
    return IsSwitch(e->src()) ? BranchType(e->src_output())
                              : state_map_->FindBranchOf(id, predicate_);
  };

  std::array<std::vector<Node*>, 2> stacks;
  VLOG(5) << "Merges: " << NodesToString(merges_);
  for (Node* m : merges_) {
    VLOG(5) << "For merge: " << m->DebugString() << " "
            << state_map_->CondStateToString(m);
    for (auto e : m->in_edges()) {
      if (e->IsControlEdge()) continue;
      BranchType branch = find_branch(e);
      TF_RET_CHECK(branch == BranchType::kThenBranch ||
                   branch == BranchType::kElseBranch)
          << "Error: " << e->src()->name()
          << " is not on either the then or else branch ("
          << Branch_Name(branch) << ") for predicate "
          << DebugString(predicate_) << " ["
          << DebugString(state_map_->LookupCondId(e->src())) << "].";
      Node* src = e->src();
      if (IsSwitch(src)) {
        // Switch node outputs and dependencies are handled separately.
        TF_RETURN_IF_ERROR(AddSwitch(src));
      } else {
        stacks[static_cast<int>(branch)].push_back(src);
      }
    }
  }

  for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
    int branch_index = static_cast<int>(branch);
    auto output = bodies_[branch_index].get();
    auto& stack = stacks[branch_index];
    VLOG(5) << "In branch: " << Branch_Name(branch) << " "
            << NodesToString(stack);
    std::vector<bool> visited(graph->num_node_ids(), false);
    node_maps_[branch_index].resize(graph->num_node_ids(), nullptr);
    auto& node_map = node_maps_[branch_index];

    while (!stack.empty()) {
      Node* n = stack.back();
      stack.pop_back();

      if (visited.at(n->id())) continue;
      visited[n->id()] = true;

      // Verify output edges and record control edges exiting the scope.
      for (const Edge* e : n->out_edges()) {
        Node* dst = e->dst();
        if (IsMerge(dst)) continue;
        Node* src = e->src();

        auto dst_id = state_map_->LookupCondId(dst);
        auto src_id = state_map_->LookupCondId(src);
        if (dst_id != src_id) {
          if (e->IsControlEdge()) {
            external_control_outputs_.push_back(e->src());
          } else {
            // Constants are treated specially to work around the case of
            // non-dominated constant nodes.
            if (!IsConstant(src)) {
              // TODO(b/78882471): A node that feeds into two different
              // CondStates is not necessarily an error so log a warning for
              // now but revisit to improve the testing to enable making this
              // an error.
              LOG(WARNING) << errors::InvalidArgument(
                  "Graph contains node ", FormatNodeForError(*src),
                  " that feeds into node ", FormatNodeForError(*dst),
                  " but these nodes are in different control contexts (",
                  DebugString(src_id), " vs ", DebugString(dst_id),
                  ") (detected during out edge testing)");
            }
          }
        }
      }

      // Copy incoming edges to the dst node. Iterate over a copy of the edges
      // as they could be mutated during iteration.
      std::vector<const Edge*> in_edges(n->in_edges().begin(),
                                        n->in_edges().end());
      // Sort in_edges to make sure nodes are copied in a deterministic order.
      std::sort(
          in_edges.begin(), in_edges.end(), [](const Edge* a, const Edge* b) {
            int a_src_output = a->src_output(), b_src_output = b->src_output();
            StringPiece a_name(a->src()->name()), b_name(b->src()->name());
            return std::tie(a_src_output, a_name) <
                   std::tie(b_src_output, b_name);
          });
      for (const Edge* e : in_edges) {
        Node* src = e->src();
        // Skip the graph's source/sink nodes (non-ops).
        if (!src->IsOp()) continue;

        Node* dst = e->dst();
        if (IsSwitch(src)) {
          // Switch node outputs and dependencies are handled separately.
          TF_RETURN_IF_ERROR(AddSwitch(src));
          continue;
        }

        // Verify input is from the same context.
        auto src_id = state_map_->LookupCondId(src);
        auto dst_id = state_map_->LookupCondId(dst);
        if (IsMerge(dst) || src_id == dst_id) {
          // TODO(jpienaar): The merge case can be more strict.
          if (node_map.at(src->id()) == nullptr) {
            node_map.at(src->id()) = output->CopyNode(src);
            stack.push_back(src);
          }
        } else if (e->IsControlEdge()) {
          // Here we have a control flow edge between src and dst that are not
          // in the same context. This is an external control dependency except
          // for one case: where the only difference between CondId of e->src()
          // and CondId of e->dst() is that e->src() has {PRED, kNeither} and
          // e->dst() has {PRED, kThenBranch/kElseBranch}. This happens in
          // gradients code for tf.cond(), where e->src() is a control pivot
          // node for a branch and e->dst() is a data node in that branch.
          bool is_external_control_input = true;
          if (!state_map_->IsEmpty(src_id) && !state_map_->IsEmpty(dst_id)) {
            std::vector<StateMap::CondState::value_type> diff;
            std::set_symmetric_difference(
                src_id->begin(), src_id->end(), dst_id->begin(), dst_id->end(),
                std::back_inserter(diff), CondStateLess());
            if (diff.size() == 2 && diff[0].first == diff[1].first &&
                (diff[0].second == BranchType::kNeither ||
                 diff[1].second == BranchType::kNeither)) {
              auto src_branch = src_id->find(diff[0].first);
              if (src_branch != src_id->end() &&
                  src_branch->second == BranchType::kNeither) {
                is_external_control_input = false;
              }
            }
          }
          if (is_external_control_input) {
            external_control_inputs_.push_back(src);
          }
        } else {
          // This shouldn't happen: it means we have an external data input
          // not entering via a switch node. Work around this as follows:
          //   * for constant nodes, copy them;
          //   * for non-constant nodes, insert a switch along the edge.
          if (IsConstant(src)) {
            // Check if the constant node was added already. It is possible to
            // have multiple uses of a constant node.
            if (node_map.at(src->id()) == nullptr) {
              node_map.at(src->id()) = output->CopyNode(src);
            }
          } else {
            StateMap::CondState state = *dst_id;
            state.erase(predicate_);
            if (state_map_->GetCondId(state) == src_id) {
              TF_RETURN_IF_ERROR(AddSwitchNodeAlongEdge(e, branch, graph));
              continue;
            } else {
              return errors::InvalidArgument(
                  "Graph contains node ", FormatNodeForError(*src),
                  " that feeds into node ", FormatNodeForError(*dst),
                  " but these nodes are in different control contexts (",
                  DebugString(src_id), " vs ", DebugString(dst_id),
                  ") (detected during in edge testing)");
            }
          }
        }

        Node* src_copy = node_map.at(e->src()->id());
        int src_output = e->src_output();
        if (node_map.at(dst->id()) == nullptr) {
          node_map.at(dst->id()) = output->CopyNode(dst);
        }
        Node* dst_copy = node_map.at(e->dst()->id());
        if (e->IsControlEdge()) {
          // Skip control inputs from the external context.
          if (src_copy != nullptr) output->AddControlEdge(src_copy, dst_copy);
        } else {
          output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
        }
      }
    }
  }

  // Build return values from the merge nodes.
  int index = 0;
  for (Node* m : merges_) {
    for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
      int branch_index = static_cast<int>(branch);
      auto& node_map = node_maps_[branch_index];
      auto output = bodies_[branch_index].get();
      TF_ASSIGN_OR_RETURN(node_map[m->id()],
                          BuildRetvalNode(output, m->output_type(0), index));
    }
    ++index;

    // Connect the inputs of the merge with the retval, except if the input is
    // a Switch node, which is handled separately.
    for (auto e : m->in_edges()) {
      if (e->IsControlEdge()) continue;
      int branch_index = static_cast<int>(find_branch(e));
      auto& node_map = node_maps_[branch_index];
      auto output = bodies_[branch_index].get();
      Node* in = e->src();
      if (!IsSwitch(in)) {
        if (node_map.at(in->id()) == nullptr) {
          node_map[in->id()] = output->CopyNode(in);
        }
        output->AddEdge(node_map[in->id()], e->src_output(),
                        node_map.at(m->id()), 0);
      }
    }
  }
  return OkStatus();
}

Status Conditional::BuildIfNode(Graph* graph,
                                FunctionLibraryDefinition* library) {
  VLOG(2) << "Build cond function for " << name();
  NodeDebugInfo debug_info((*merges_.begin())->def());
  NodeDefBuilder builder(name(), "If", library, &debug_info);
  const string branch_name[] = {"else_branch", "then_branch"};
  for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
    int branch_index = static_cast<int>(branch);

    NameAttrList body_name;
    body_name.set_name(library->UniqueFunctionName(
        absl::StrCat("_functionalize_if_", branch_name[branch_index], "_")));

    VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index]
            << "): "
            << DumpGraphToFile(
                   "functionalize_cond_body_" + branch_name[branch_index],
                   *bodies_[branch_index], nullptr);

    FunctionDef body_fdef;
    TF_RETURN_IF_ERROR(GraphToFunctionDef(*bodies_[branch_index],
                                          body_name.name(), &body_fdef));
    TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
    builder.Attr(branch_name[branch_index], body_name);
  }

  VLOG(3) << "Build input type";
  std::vector<NodeDefBuilder::NodeOut> inputs;
  DataTypeVector in_arg_types;
  for (auto& kv : cond_arg_nodes_) {
    bool inserted = false;
    for (const Node* arg : kv.switches) {
      const Edge* in_edge;
      TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
      if (in_edge->IsControlEdge()) {
        builder.ControlInput(in_edge->src()->name());
      } else {
        if (!inserted) {
          DataType dtype = arg->input_type(0);
          inputs.emplace_back(NodeDefBuilder::NodeOut(
              in_edge->src()->name(), in_edge->src_output(), dtype));
          in_arg_types.push_back(dtype);
          inserted = true;
        }
      }
    }
  }
  builder.Attr("Tin", in_arg_types);

  DataTypeVector out_type;
  std::vector<PartialTensorShape> output_shapes;
  output_shapes.reserve(merges_.size());
  for (const Node* merge : merges_) {
    DataType dtype = merge->output_type(0);
    TensorShapeProto shape;
    if (auto* shape_ctx = refiner_.GetContext(merge)) {
      shape_inference::ShapeHandle handle;
      shape_ctx->ShapeHandleToProto(shape_ctx->output(0), &shape);
    }
    out_type.push_back(dtype);
    output_shapes.push_back(shape);
  }
  builder.Attr("Tout", out_type);
  VLOG(3) << "Build output type: " << DataTypeVectorString(out_type);
  builder.Attr("output_shapes", output_shapes);
  VLOG(3) << "Build output shapes: "
          << PartialTensorShapeUtils::PartialShapeListString(output_shapes);

  builder.Attr("Tcond", DT_BOOL);
  // Add some internal attributes which need to be propagated.
  for (absl::string_view attr_name : kAttrsToPropagate) {
    string attr_val;
    if (GetNodeAttr(predicate_.node->def(), attr_name, &attr_val).ok()) {
      builder.Attr(attr_name, attr_val);
    }
  }
  builder.Device(predicate_.node->assigned_device_name());
  // The predicate should be the first input ...
  builder.Input(
      NodeDefBuilder::NodeOut(predicate_.node->name(), predicate_.index,
                              predicate_.node->output_type(predicate_.index)));
  // ... followed by the other inputs.
  builder.Input(inputs);

  VLOG(3) << "Build If node";
  NodeDef if_def;
  TF_RETURN_IF_ERROR(builder.Finalize(&if_def));
  TF_ASSIGN_OR_RETURN(if_node_,
                      parent_->AddIfNode(if_def, *merges_.begin(), predicate_));

  return OkStatus();
}

Status Conditional::AddInputEdges(
    Graph* graph,
    const std::unordered_map<Node*, OutputTensor>& merge_to_replacement) {
  VLOG(2) << "AddInputEdges for " << if_node_->name();
  int index = 0;
  // Add the predicate input.
  if (predicate_.node->IsMerge()) {
    // If the predicate is a Merge node, we should not use the Merge output as
    // the predicate. Instead, we should use the corresponding If output in
    // 'merge_to_replacement'. Otherwise, this Conditional's If node is still
    // connected to the predicate Merge node, and when we call
    // DeleteReachableAndDeadNodes(), the predicate Merge node and this
    // Conditional's If node will be removed.
    auto iter = merge_to_replacement.find(predicate_.node);
    if (iter == merge_to_replacement.end()) {
      return errors::Internal("Cannot find replacement for Merge node ",
                              predicate_.node->name());
    }
    graph->AddEdge(iter->second.node, iter->second.index, if_node_, index++);
  } else {
    graph->AddEdge(const_cast<Node*>(predicate_.node), predicate_.index,
                   if_node_, index++);
  }
  // Add the function body inputs.
  for (auto& arg : cond_arg_nodes_) {
    if (arg.src_output == Graph::kControlSlot) {
      graph->AddControlEdge(arg.src, if_node_);
    } else {
      graph->AddEdge(arg.src, arg.src_output, if_node_, index++);
    }
  }
  for (Node* n : external_control_inputs_) {
    graph->AddControlEdge(n, if_node_);
  }
  return OkStatus();
}

Status Conditional::AddOutputEdges(
    Graph* graph,
    std::unordered_map<Node*, OutputTensor>* merge_to_replacement) {
  VLOG(2) << "AddOutputEdges for " << if_node_->name();
  int i = 0;
  for (Node* node : merges_) {
    TF_RETURN_IF_ERROR(parent_->AddIdentityNode(node, if_node_, i));
    std::vector<const Edge*> edges(node->out_edges().begin(),
                                   node->out_edges().end());
    for (const Edge* edge : edges) {
      Node* dst = edge->dst();
      int dst_input = edge->dst_input();
      if (edge->src_output() > 0) {
        return errors::Unimplemented("Output of index (", edge->src_output(),
                                     ") of merge node ",
                                     FormatNodeForError(*node));
      }

      bool control_edge = edge->IsControlEdge();
      graph->RemoveEdge(edge);
      if (control_edge) {
        graph->AddControlEdge(if_node_, dst);
      } else {
        graph->AddEdge(if_node_, i, dst, dst_input);
      }
    }

    // Record the corresponding output tensor in 'merge_to_replacement'.
    (*merge_to_replacement)[node] = OutputTensor{if_node_, i};

    ++i;
  }
  for (Node* n : external_control_outputs_) {
    graph->AddControlEdge(if_node_, n);
  }

  return OkStatus();
}

Status Conditional::BuildAndReplace(
    Graph* graph, FunctionLibraryDefinition* library,
    std::unordered_map<Node*, OutputTensor>* merge_to_replacement) {
  VLOG(1) << "Build If and replace merge nodes "
          << NodesToString(this->merges_);
  if (replaced_) return OkStatus();

  TF_RETURN_IF_ERROR(ExtractBodies(graph));
  TF_RETURN_IF_ERROR(BuildArgumentNodes());

  if (VLOG_IS_ON(3)) {
    LOG(INFO) << "Extracted bodies:";
    for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
      int branch_index = static_cast<int>(branch);
      auto output = bodies_[branch_index].get();
      LOG(INFO) << Branch_Name(branch) << ": "
                << DebugString(output->ToGraphDefDebug());
    }
  }

  TF_RETURN_IF_ERROR(BuildIfNode(graph, library));
  TF_RETURN_IF_ERROR(AddInputEdges(graph, *merge_to_replacement));
  TF_RETURN_IF_ERROR(AddOutputEdges(graph, merge_to_replacement));
  TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_));

  // Check that the if_node doesn't feed into itself.
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      CheckNodeNotInCycle(if_node_, graph->num_node_ids()),
      "Converting to If failed.");

  replaced_ = true;
  return OkStatus();
}

string Conditional::name() const {
  CHECK(!merges_.empty());
  return absl::StrCat((*merges_.begin())->name(), "_if");
}

Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node,
                                          int port) {
  NodeBuilder id_builder(replacee->name(), "Identity");
  id_builder.Input(if_node, port);
  string outside_compilation;
  if (GetNodeAttr(if_node->def(), kXlaOutsideCompilationAttr,
                  &outside_compilation)
          .ok()) {
    id_builder.Attr(kXlaOutsideCompilationAttr, outside_compilation);
  }
  Node* id;
  TF_RETURN_IF_ERROR(id_builder.Finalize(graph_, &id));
  state_map_.ResetCondId(id, state_map_.LookupCondId(if_node));
  state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node));
  return OkStatus();
}

StatusOr<Node*> FunctionalizeCond::AddIfNode(const NodeDef& def,
                                             const Node* replacee,
                                             const OutputTensor& predicate) {
  TF_ASSIGN_OR_RETURN(Node * ret, graph_->AddNode(def));
  VLOG(1) << "Adding If for " << replacee->name();
  StateMap::CondId id = state_map_.LookupCondId(replacee);
  if (id) {
    StateMap::CondState state = *id;
    state.erase(predicate);
    state_map_.ResetCondId(ret, state_map_.GetCondId(state));
  } else {
    state_map_.ResetCondId(ret, nullptr);
  }

  state_map_.ResetAncestorId(ret, state_map_.LookupAncestorId(replacee));

  return ret;
}

Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) {
  VLOG(2) << "Propagating updated state for " << replacee->name() << " "
          << state_map_.CondStateToString(replacee);
  // Redo the topological sort as the order could have changed.
  // TODO(jpienaar): The original topological order could also be updated
  // dynamically if needed.
  std::vector<Node*> rev_topo_order;
  GetPostOrder(*graph_, &rev_topo_order, NodeComparatorID());

  // All the outputs of the new node could potentially be updated.
  std::unordered_set<Node*> changed;
  for (auto n : replacee->out_nodes())
    if (n->IsOp()) changed.insert(n);

  // Iterate through the changed/possibly-changed nodes in topological order.
  for (auto it = rev_topo_order.rbegin();
       it != rev_topo_order.rend() && !changed.empty(); ++it) {
    if (changed.find(*it) != changed.end()) {
      // Update the node state.
      Node* n = *it;
      StateMap::CondId old_state = state_map_.LookupCondId(n);
      state_map_.ResetCondId(n, nullptr);
      TF_RETURN_IF_ERROR(DetermineCondState(n));
      if (state_map_.LookupCondId(n) != old_state) {
        for (auto out : n->out_nodes())
          if (out->IsOp()) changed.insert(out);
      }
      changed.erase(n);
    }
  }
  return OkStatus();
}

// Returns the most restrictive branch of two branches or neither. This is the
// meet operator of the BranchType lattice.
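// For example:
//   MeetBranch(kThenBranch, kBoth) == kThenBranch
//   MeetBranch(kThenBranch, kNeither) == kThenBranch
//   MeetBranch(kThenBranch, kElseBranch) == kNeither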
BranchType MeetBranch(const BranchType& lhs, const BranchType& rhs) {
  if (lhs == rhs) return lhs;
  if (lhs == BranchType::kNeither) return rhs;
  if (rhs == BranchType::kNeither) return lhs;
  if (lhs == BranchType::kBoth) return rhs;
  if (rhs == BranchType::kBoth) return lhs;
  return BranchType::kNeither;
}

BranchType StateMap::FindBranchOf(CondId id, OutputTensor predicate) const {
  if (IsEmpty(id)) return BranchType::kNeither;
  const CondState& nodes = *id;
  auto it = nodes.find(predicate);
  if (it == nodes.end()) return BranchType::kNeither;
  return it->second;
}

StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesNonMerge(
    StateMap::CondId src, StateMap::CondId dst) {
  VLOG(5) << "Joining src=" << DebugString(src) << " [" << src
          << "] and dst=" << DebugString(dst) << " [" << dst << "]";

  if (state_map_.IsEmpty(dst) || state_map_.IsDead(src)) return src;
  if (state_map_.IsDead(dst) || state_map_.IsEmpty(src)) return dst;

  // Nothing to do if the CondState is the same.
  if (src == dst) return src;

  StateMap::CondState both = *src;
  for (const auto& kv : *dst) {
    auto it = both.find(kv.first);
    if (it == both.end()) {
      both.insert(kv);
    } else {
      if (it->second != kv.second) {
        if (it->second == BranchType::kNeither) {
          // The BranchType for 'src' is kNeither. Use the BranchType in 'dst'.
          it->second = kv.second;
        } else if (kv.second == BranchType::kNeither) {
          // The BranchType for 'dst' is kNeither. Use the BranchType in 'src'.
          // No need to change it->second.
        } else {
          return errors::InvalidArgument(
              "Graph contains node with inputs predicated on incompatible "
              "predicates: ",
              DebugString(src), " and ", DebugString(dst));
        }
      }
    }
  }
  return state_map_.GetCondId(both);
}

StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesMerge(
    Node* merge, StateMap::CondId src, StateMap::CondId dst) {
  // Determine the flow state when joining two states for a merge node.
  // Combining the two states for a merge node is effectively performing a
  // disjunction of the states along the different input edges. For a merge
  // that can be transformed into an If, the two input paths have to have a
  // predicate on which they differ (e.g., along one edge predicate `p` has to
  // hold while on another it should not). This function first determines this
  // predicate, and then the resultant state is the common path between the
  // two inputs followed by s(p, both).
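  // For example, joining {s(p,then)} and {s(p,else)} records `p` as the
  // merge's predicate and yields the empty common state {}.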
  VLOG(4) << "Joining (for merge) " << DebugString(src) << " and "
          << DebugString(dst);
  if (state_map_.IsEmpty(dst)) return src;
  if (state_map_.IsEmpty(src)) {
    return errors::Internal("Merge node ", merge->name(),
                            " has an input that's not in any CondContext.");
  }

  if (state_map_.IsDead(src)) return src;
  if (state_map_.IsDead(dst)) return dst;

  std::vector<StateMap::CondState::value_type> diff;
  StateMap::CondState merged;
  std::set_symmetric_difference(src->begin(), src->end(), dst->begin(),
                                dst->end(), std::back_inserter(diff),
                                CondStateLess());
  std::set_intersection(src->begin(), src->end(), dst->begin(), dst->end(),
                        std::inserter(merged, merged.begin()),
                        CondStateLess());

  // Update the mapping from merge node to predicate.
  if (diff.size() == 2) {
    auto pred = diff[0].first;
    bool different_branches = (diff[0].second != diff[1].second) &&
                              (diff[0].second == BranchType::kThenBranch ||
                               diff[0].second == BranchType::kElseBranch) &&
                              (diff[1].second == BranchType::kThenBranch ||
                               diff[1].second == BranchType::kElseBranch);
    if (!(pred == diff[1].first) || !different_branches)
      return errors::InvalidArgument(
          "Unable to determine predicate for merge node");
    merge_to_predicate_[merge] = pred;
  } else {
    return errors::InvalidArgument(
        "Merge of two inputs that differ on more than one predicate ",
        DebugString(src), " and ", DebugString(dst));
  }

  return state_map_.GetCondId(merged);
}

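// Returns the state propagated along edge `e`: for a data edge out of a
// Switch this is the source's state extended with the switch predicate mapped
// to the branch taken (kNeither for control edges); for any other edge it is
// simply the source's state.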
StateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) {
  Node* src = e->src();
  StateMap::CondId id = state_map_.LookupCondId(e->src());

  // Dead nodes only propagate dead state.
  if (state_map_.IsDead(id)) return id;

  if (IsSwitch(src)) {
    StateMap::CondState state;
    if (id != nullptr) state = *id;
    OutputTensor predicate;
    TF_CHECK_OK(GetSwitchPredicate(*src, &predicate));
    if (e->IsControlEdge()) {
      // In gradients of tf.cond(), each branch has a NoOp node as its control
      // pivot. These NoOp nodes have a control dependency from the Switch
      // node. If we don't record this in the CondState, branches might have
      // an incorrect CondState (e.g. if the branch only has a Const data
      // node). We set it to kNeither because there is no way to tell whether
      // it's for the true branch or the false branch. This node's descendants
      // might have other incoming edges with a defined BranchType, and we
      // correctly handle merging kNeither with other defined BranchTypes in
      // StateAlongEdge().
      state[predicate] = BranchType::kNeither;
    } else {
      state[predicate] = BranchType(e->src_output());
    }
    return state_map_.GetCondId(state);
  }
  return id;
}

Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) {
  // Only Merge nodes with two inputs are supported, but if this is a redundant
  // merge, then the dead edge may already have been removed (if due to a
  // switch) and so the input count would be incorrect.
  if (state_map_.IsDead(state_map_.LookupCondId(dst))) return OkStatus();

  int data_inputs = 0;
  for (auto e : dst->in_edges()) {
    Node* src = e->src();
    VLOG(5) << "Processing forward flow for merge: " << e->DebugString() << " "
            << state_map_.CondStateToString(src);
    if (!src->IsOp()) continue;
    if (!e->IsControlEdge()) ++data_inputs;

    StateMap::CondId prop = StateAlongEdge(e);
    auto id_or = JoinCondStatesMerge(dst, prop, state_map_.LookupCondId(dst));
    TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
                                    FormatNodeForError(*dst));
    state_map_.ResetCondId(dst, id_or.ValueOrDie());
  }

  // Incomplete Merge nodes are not supported.
  if (data_inputs != 2) {
    return errors::Unimplemented(
        dst->name(), " only has ", data_inputs,
        " inputs, while only merge nodes with two inputs are supported.");
  }
  return OkStatus();
}

Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) {
  // Handle a non-merge join.
  for (auto e : dst->in_edges()) {
    VLOG(4) << "Processing forward flow for: " << e->DebugString() << " "
            << state_map_.CondStateToString(dst);
    Node* src = e->src();
    if (!src->IsOp()) continue;

    // Join the current state with the state propagated along the edge.
    StateMap::CondId prop = StateAlongEdge(e);
    auto id_or = JoinCondStatesNonMerge(prop, state_map_.LookupCondId(dst));
    TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
                                    FormatNodeForError(*dst));
    state_map_.ResetCondId(dst, id_or.ValueOrDie());
  }
  return OkStatus();
}

Status FunctionalizeCond::RemoveRedundantMerge(Node* node) {
  // Handle redundant merge nodes. A merge node is considered redundant if
  // one input edge is dead while the other has a value.
  if (!state_map_.IsDead(state_map_.LookupCondId(node))) return OkStatus();

  const Edge* non_dead_edge = nullptr;
  for (auto e : node->in_edges()) {
    if (e->IsControlEdge()) continue;
    Node* src = e->src();

    // Handle merge with dead state.
    const auto& src_id = state_map_.LookupCondId(src);
    if (!state_map_.IsDead(src_id)) {
      non_dead_edge = e;
      break;
    }
  }

  if (non_dead_edge == nullptr) {
    return errors::InvalidArgument("Merge node ", FormatNodeForError(*node),
                                   " has no non-dead inputs.");
  }
  state_map_.MarkDead(node);
  VLOG(5) << "removing redundant merge: " << node->name();
  while (!node->out_edges().empty()) {
    const Edge* oe = *node->out_edges().begin();
    Node* dst_node = oe->dst();
    int dst_port = oe->dst_input();
    graph_->RemoveEdge(oe);
    graph_->AddEdge(non_dead_edge->src(),
                    dst_port == Graph::kControlSlot
                        ? Graph::kControlSlot
                        : non_dead_edge->src_output(),
                    dst_node, dst_port);
  }
  return OkStatus();
}

Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) {
  // Handle redundant switch nodes. A switch node is considered redundant if
  // the predicate of the switch already holds on the current branch. E.g., if
  // p is the predicate of the switch but p is already known to hold on this
  // branch, then the switch can be removed and the dead state propagated
  // along one of its outputs. The checking of the predicate is based on the
  // exact predicate (rather than boolean equivalence) and aimed at redundant
  // switches as currently generated by gradient code.
  StateMap::CondId dst_id = state_map_.LookupCondId(node);
  if (state_map_.IsDead(dst_id)) return OkStatus();

  BranchType b;
  OutputTensor pred;
  TF_RETURN_IF_ERROR(GetSwitchPredicate(*node, &pred));

  // Determine if we are already on a branch where the switch predicate is
  // true/false. Consider both the data and predicate inputs to determine if
  // the node is redundant (skipping over Identity nodes).
  b = state_map_.FindBranchOf(dst_id, pred);
  if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) {
    OutputTensor val;
    const Edge* e;
    TF_RETURN_IF_ERROR(node->input_edge(0, &e));
    val = OutputTensor(e->src(), e->src_output());
    while (IsIdentity(val.node)) {
      TF_RETURN_IF_ERROR(val.node->input_edge(0, &e));
      val = OutputTensor(e->src(), e->src_output());
    }
    b = state_map_.FindBranchOf(dst_id, val);
    if (b != BranchType::kThenBranch && b != BranchType::kElseBranch)
      return OkStatus();
  }

  VLOG(5) << "Redundant switch " << node->name() << " " << Branch_Name(b)
          << " " << DebugString(dst_id);
  const Edge* value_edge;
  TF_RETURN_IF_ERROR(node->input_edge(0, &value_edge));
  Node* val_node = value_edge->src();
  int val_port = value_edge->src_output();
  while (!node->out_edges().empty()) {
    auto e = *node->out_edges().begin();
    Node* dst_node = e->dst();
    int dst_input = e->dst_input();
    int switch_branch = e->src_output();
    graph_->RemoveEdge(e);
    if (switch_branch == Graph::kControlSlot) {
      if (IsMerge(dst_node)) {
        auto id_or = JoinCondStatesMerge(dst_node, dst_id,
                                         state_map_.LookupCondId(dst_node));
        TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
                                        FormatNodeForError(*dst_node));
        state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
      } else {
        auto id_or =
            JoinCondStatesNonMerge(dst_id, state_map_.LookupCondId(dst_node));
        TF_RETURN_IF_ERROR(id_or.status());
        state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
      }
    } else if (BranchType(switch_branch) != b) {
      state_map_.MarkDead(dst_node);
      continue;
    }
    graph_->AddEdge(
        val_node,
        switch_branch == Graph::kControlSlot ? Graph::kControlSlot : val_port,
        dst_node, dst_input);
  }
  return OkStatus();
}

Status FunctionalizeCond::DetermineStates(std::vector<Node*> rev_topo_order) {
  // Determine the cond and ancestor state of each node in topological order.
  for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) {
    Node* dst = *it;
    TF_RETURN_IF_ERROR(DetermineCondState(dst));
    TF_RETURN_IF_ERROR(DetermineAncestorState(dst));
    if (IsSwitch(dst)) TF_RETURN_IF_ERROR(RemoveRedundantSwitch(dst));
    if (IsMerge(dst)) TF_RETURN_IF_ERROR(RemoveRedundantMerge(dst));

    VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst)
            << " @ " << state_map_.AncestorStateToString(dst);
    if (VLOG_IS_ON(10)) DumpGraphWithCondState("it");
  }
  return OkStatus();
}

Status FunctionalizeCond::DetermineAncestorState(Node* dst) {
  StateMap::AncestorId id = nullptr;
  StateMap::AncestorState state;

  auto insert = [&](StateMap::AncestorId id, Node* src) {
    auto other_id = state_map_.LookupAncestorId(src);
    if (other_id != id && other_id != nullptr) {
      state.insert(other_id->begin(), other_id->end());
    }
    if (IsMerge(src)) {
      state.insert({{src, 0}, AncestorNode::AncestorNodeType::kMerge});
    } else if (IsSwitch(src)) {
      OutputTensor pred;
      // For dead switch nodes, GetSwitchPredicate() will fail, and we use
      // the switch node directly as the ancestor.
      if (GetSwitchPredicate(*src, &pred).ok()) {
        state.insert({pred, AncestorNode::AncestorNodeType::kPred});
      } else {
        state.insert({{src, 0}, AncestorNode::AncestorNodeType::kSwitch});
      }
    }
    return state_map_.GetAncestorId(state);
  };

  // Compute the union of all the switch/merge nodes that affect the inputs
  // of dst.
  for (auto e : dst->in_edges()) {
    Node* src = e->src();
    id = insert(id, src);
  }
  state_map_.ResetAncestorId(dst, id);
  return OkStatus();
}

DeleteReachableAndDeadNodes(const std::vector<Node * > & merge_order)1365 void FunctionalizeCond::DeleteReachableAndDeadNodes(
1366 const std::vector<Node*>& merge_order) {
1367 // Delete all nodes that have been extracted or are reachable from
1368 // deleted/dead nodes. The input and outgoing edges should have already been
1369 // removed.
1370 std::deque<int> delete_nodes;
1371 std::vector<bool> deleted(graph_->num_node_ids(), false);
1372 // Don't try to delete source or sink nodes.
1373 deleted[graph_->kSourceId] = true;
1374 deleted[graph_->kSinkId] = true;
1375
1376 // All remaining switch nodes that were not excluded from functionalization
1377 // according to `node_filter_` are not reachable from a merge node and
1378 // removed. This is to account for dead switch nodes.
1379 for (int s_id : switch_ids_) {
1380 Node* s = graph_->FindNodeId(s_id);
1381 if (s == nullptr) continue;
1382 for (const Edge* e : s->out_edges()) {
1383 // Control outputs of switch nodes (which are unconditionally executed if
1384 // the switch is) are not removed as they need not be part of a
1385 // conditional.
1386 if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
1387 }
1388 // Only remove switch node if we have functionalized the corresponding
1389 // condition before (according to `node_filter_`).
1390 if (!node_filter_ || node_filter_(s)) {
1391 VLOG(2) << "Removing obsolete switch node " << s->name();
1392 deleted[s_id] = true;
1393 graph_->RemoveNode(s);
1394 }
1395 }
1396
1397 // All merge nodes that were not excluded from functionalization according to
1398 // `node_filter_` should have been transformed at this point and we remove
1399 // them from the graph here.
1400 for (Node* m : merge_order) {
1401 for (const Edge* e : m->out_edges()) {
1402 // Similar to control outputs of switch nodes don't remove control
1403 // outputs of merge nodes.
1404 // TODO(jpienaar): Check cases where output edges still exist here vs
1405 // being removed in AddOutputEdges.
1406 if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
1407 }
1408 // Only remove merge node if we have functionalized the corresponding
1409 // condition before (according to `node_filter_`).
1410 if (!node_filter_ || node_filter_(m)) {
1411 VLOG(2) << "Removing obsolete merge node " << m->name();
1412 deleted[m->id()] = true;
1413 graph_->RemoveNode(m);
1414 }
1415 }
1416
1417 // Enqueue all the dead nodes.
1418 for (Node* n : graph_->nodes()) {
1419 if (state_map_.IsDead(state_map_.LookupCondId(n))) {
1420 delete_nodes.push_back(n->id());
1421 }
1422 }
1423 // Remove dead nodes and nodes that are reachable from dead nodes.
1424 while (!delete_nodes.empty()) {
1425 int d_id = delete_nodes.front();
1426 delete_nodes.pop_front();
1427 if (deleted[d_id]) continue;
1428 Node* d = graph_->FindNodeId(d_id);
1429 // Switch and Merge nodes could have been deleted already.
1430 if (d == nullptr) continue;
1431 for (const Edge* e : d->out_edges()) {
1432 delete_nodes.push_back(e->dst()->id());
1433 }
1434 VLOG(2) << "Removing obsolete node " << d->name();
1435 deleted[d_id] = true;
1436 graph_->RemoveNode(d);
1437 }
1438 }
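
// Illustrative example for the deletion above (comment only, not executed):
// after a cluster Switch -> Add -> Merge has been extracted into an If op and
// the Merge's consumers rewired to the If, the Add is enqueued via the
// Switch's data output, the Switch and Merge are removed directly, and the
// BFS then deletes the Add and anything only reachable from it.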

void FunctionalizeCond::SortMergeNodes(std::vector<Node*>* merge_order) {
  // Sort merge nodes by nesting depth.
  using sort_pair = std::pair<int, Node*>;
  std::vector<sort_pair> inner_to_outer_merge_order;
  inner_to_outer_merge_order.reserve(merge_order->size());
  for (auto it = merge_order->rbegin(); it != merge_order->rend(); ++it) {
    Node* merge = *it;
    StateMap::CondId id = state_map_.LookupCondId(merge);
    int depth = id != nullptr ? id->size() : 0;
    inner_to_outer_merge_order.emplace_back(depth, merge);
  }
  std::stable_sort(
      inner_to_outer_merge_order.begin(), inner_to_outer_merge_order.end(),
      [](sort_pair lhs, sort_pair rhs) { return lhs.first > rhs.first; });
  merge_order->clear();
  for (sort_pair t : inner_to_outer_merge_order) {
    merge_order->push_back(t.second);
  }
}
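
// Illustrative example for the sort above (comment only, not executed): for
// an input order [A, B, C, D] with nesting depths {A: 0, B: 2, C: 1, D: 2},
// the vector is built from the reversed order [D, C, B, A], and the stable
// sort by descending depth yields [D, B, C, A], i.e. innermost merges first,
// with ties keeping their relative order.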

Status FunctionalizeCond::FunctionalizeInternal() {
  // The general approach for converting a tf.cond (as lowered via switch/merge
  // nodes) to a functional if is as follows (see the sketch after this
  // overview):
  // 1. Determine the topological order and collect all the switch and merge
  //    nodes in the graph;
  // 2. Compute the predicates and dominance structure for all the nodes in
  //    the graph - this includes which predicate must be true for an op to
  //    execute (predicate values are considered directly rather than
  //    attempting to determine deeper equivalence). We shall refer to this
  //    structure as the CondState;
  // 3. Sort the merge nodes by nesting depth;
  // 4. Extract merge nodes together that have the same CondState and
  //    AncestorState, from the innermost to the outermost, into If ops;
  // Note: In the above, only nodes that feed into a merge node will be
  // considered for functionalization.
  // Note: Nodes for which `node_filter_` returns false are excluded.
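
  // Illustrative sketch (comment only, not executed): a Python program like
  //   r = tf.cond(pred, lambda: x + 1, lambda: x - 1)
  // is lowered to switch/merge form roughly as
  //   x_false, x_true = Switch(x, pred)
  //   r = Merge(Add(x_true, 1), Sub(x_false, 1))
  // and this pass rewrites it into a functional form along the lines of
  //   r = If(pred, x, then_branch=..., else_branch=...)
  // with the two branch bodies extracted as functions into the library.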

  // Perform a DFS over the graph and
  // * Determine the reverse topological order of the nodes (there should be
  //   no cycles at this point, so the post-order numbering corresponds to the
  //   reverse topological sorting);
  // * Record the merge and switch nodes in reverse topological order;
  std::vector<Node*> rev_topo_order;
  std::vector<Node*> merge_order;
  DFS(*graph_, nullptr, [&](Node* n) {
    // Only collect switch and merge nodes that are not filtered out; these
    // form the conditions that will be functionalized.
    if (!node_filter_ || node_filter_(n)) {
      if (IsSwitch(n)) {
        AddSwitchId(n->id());
      }
      if (IsMerge(n)) {
        merge_order.push_back(n);
      }
    }
    // Collect all other nodes here, independent of `node_filter_`, because
    // they might belong to a condition that should be functionalized.
    if (n->IsOp()) {
      rev_topo_order.push_back(n);
    }
  });

  // No merges to functionalize.
  if (merge_order.empty()) {
    // No merges means no switch values are consumed (we only consider values
    // fetchable as the output of a merge).
    DeleteReachableAndDeadNodes(merge_order);
    return OkStatus();
  }

  TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order)));
  if (VLOG_IS_ON(4)) DumpGraphWithCondState("id");

  // Determine the shapes of the ops in the graph.
  ShapeRefiner shape_refiner{graph_->versions().producer(),
                             graph_->op_registry()};
  std::vector<Node*> nodes;
  GetReversePostOrder(*graph_, &nodes, NodeComparatorID());
  for (auto node : nodes) {
    if (!shape_refiner.AddNode(node).ok()) {
      LOG(WARNING) << "Couldn't deduce shape for " << node->name();
    }
  }

  // Sort the merge nodes from innermost outwards.
  SortMergeNodes(&merge_order);

  // Cluster merge nodes by (CondId, AncestorId, predicate) in order of
  // nesting. (CondId, AncestorId) is not enough, e.g.
  //   pred1 = array_ops.placeholder(dtypes.bool, name='pred1')
  //   pred2 = array_ops.placeholder(dtypes.bool, name='pred2')
  //   cond1 = control_flow_ops.cond(pred1, ...)
  //   cond2 = control_flow_ops.cond(pred2, ...)
  //   cond3 = control_flow_ops.cond(pred1, use cond1 and cond2)
  //   cond4 = control_flow_ops.cond(pred2, use cond1 and cond2)
  // cond3 and cond4 have the same (CondId, AncestorId), but they should not
  // be merged into one "If" node (because they have different predicates).
  std::deque<std::vector<Node*>> merge_clusters;
  std::map<ClusterTuple, int, ClusterTupleLessThan> merge_cluster_index;
  for (Node* merge : merge_order) {
    auto cond_id = state_map_.LookupCondId(merge);
    if (state_map_.IsDead(cond_id)) continue;

    auto predicate = merge_to_predicate_.find(merge);
    if (predicate == merge_to_predicate_.end()) {
      return errors::Internal("Cannot find predicate for Merge node ",
                              merge->name());
    }

    ClusterTuple key = std::make_tuple(
        cond_id, state_map_.LookupAncestorId(merge), predicate->second);
    auto idx = merge_cluster_index.find(key);
    if (idx == merge_cluster_index.end()) {
      merge_cluster_index[key] = merge_clusters.size();
      merge_clusters.push_back({merge});
    } else {
      merge_clusters[idx->second].emplace_back(merge);
    }
  }

  // Extract the conditionals from innermost to outermost. Extracting from
  // innermost to outermost enables the extraction pass to stop once it
  // encounters a Switch node instead of having to keep track of Switch/Merge
  // nodes seen.
  for (const auto& cluster : merge_clusters) {
    // Construct a Conditional with the predicate of the merge.
    Conditional cond(merge_to_predicate_.at(cluster.front()), this,
                     &state_map_, shape_refiner);
    for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge));
    TF_RETURN_IF_ERROR(
        cond.BuildAndReplace(graph_, library_, &merge_to_replacement_));

    if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract");
  }

  DeleteReachableAndDeadNodes(merge_order);

  return OkStatus();
}

void FunctionalizeCond::DumpGraphWithCondState(const string& name) {
  const char* const kCondGroupDebugAttr = "_XlaFunctionalizeCondGroup";

  for (Node* n : graph_->nodes()) {
    n->ClearAttr(kCondGroupDebugAttr);
    n->AddAttr(kCondGroupDebugAttr,
               absl::StrCat(state_map_.CondStateToString(n), "_",
                            state_map_.AncestorStateToString(n)));
  }
  LOG(INFO) << "FunctionalizeControlFlow (" << name << "): "
            << DumpGraphToFile(absl::StrCat("functionalize_cond_", name),
                               *graph_, library_);
}

void FunctionalizeCond::AddSwitchId(int switch_id) {
  switch_ids_.push_back(switch_id);
}

Status FunctionalizeCond::Functionalize(Graph* graph,
                                        FunctionLibraryDefinition* library,
                                        const NodeFilter& node_filter) {
  VLOG(1) << "FunctionalizeCond::Functionalize";
  FunctionalizeCond fc(graph, library, node_filter);
  return fc.FunctionalizeInternal();
}

}  // namespace functionalize_cond

Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library,
                         const NodeFilter& node_filter) {
  // FunctionalizeControlFlow is invoked for every function, so the loops'
  // bodies and conditionals that were extracted into functions will be
  // handled in successive invocations.
  return functionalize_cond::FunctionalizeCond::Functionalize(graph, library,
                                                              node_filter);
}
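
// Usage sketch (comment only, illustrative): a caller such as
// FunctionalizeControlFlow would typically invoke the free function above
// once per graph, e.g.
//   TF_RETURN_IF_ERROR(FunctionalizeCond(graph, library, node_filter));
// where an empty `node_filter` means that no nodes are excluded.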

}  // namespace tensorflow