xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/functionalize_cond.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/functionalize_cond.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <stack>
21 #include <unordered_set>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/strings/match.h"
26 #include "absl/strings/str_join.h"
27 #include "absl/types/optional.h"
28 #include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
29 #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
30 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
31 #include "tensorflow/compiler/xla/union_find.h"
32 #include "tensorflow/core/common_runtime/function.h"
33 #include "tensorflow/core/common_runtime/shape_refiner.h"
34 #include "tensorflow/core/framework/graph_to_functiondef.h"
35 #include "tensorflow/core/framework/node_def_builder.h"
36 #include "tensorflow/core/framework/versions.pb.h"
37 #include "tensorflow/core/graph/algorithm.h"
38 #include "tensorflow/core/graph/control_flow.h"
39 #include "tensorflow/core/graph/node_builder.h"
40 #include "tensorflow/core/lib/core/errors.h"
41 #include "tensorflow/core/lib/hash/hash.h"
42 #include "tensorflow/core/lib/strings/strcat.h"
43 #include "tensorflow/core/util/dump_graph.h"
44 
45 namespace tensorflow {
46 namespace functionalize_cond {
47 
operator <(const AncestorNode & other) const48 bool AncestorNode::operator<(const AncestorNode& other) const {
49   return (output_tensor.node->id() < other.output_tensor.node->id()) ||
50          (output_tensor.node->id() == other.output_tensor.node->id() &&
51           output_tensor.index < other.output_tensor.index) ||
52          (output_tensor.node->id() == other.output_tensor.node->id() &&
53           output_tensor.index == other.output_tensor.index &&
54           type < other.type);
55 }
56 
operator ==(const AncestorNode & other) const57 bool AncestorNode::operator==(const AncestorNode& other) const {
58   return output_tensor.node->id() == other.output_tensor.node->id() &&
59          output_tensor.index == other.output_tensor.index && type == other.type;
60 }
61 
operator ()(const AncestorNode & ancestor) const62 size_t AncestorNode::Hash::operator()(const AncestorNode& ancestor) const {
63   size_t h = std::hash<int>()(ancestor.output_tensor.node->id());
64   h = Hash64Combine(h, std::hash<int>()(ancestor.output_tensor.index));
65   return Hash64Combine(h, std::hash<int>()(static_cast<int>(ancestor.type)));
66 }
67 
68 typedef std::tuple<StateMap::CondId, StateMap::AncestorId, OutputTensor>
69     ClusterTuple;
70 
71 struct ClusterTupleLessThan {
operator ()tensorflow::functionalize_cond::ClusterTupleLessThan72   bool operator()(const ClusterTuple& a, const ClusterTuple& b) const {
73     if (std::tie(std::get<0>(a), std::get<1>(a)) <
74         std::tie(std::get<0>(b), std::get<1>(b))) {
75       return true;
76     } else if (std::tie(std::get<0>(a), std::get<1>(a)) ==
77                std::tie(std::get<0>(b), std::get<1>(b))) {
78       return StateMap::OutputTensorLess()(std::get<2>(a), std::get<2>(b));
79     } else {
80       return false;
81     }
82   }
83 };
84 
85 // TODO(jpienaar): Move to OutputTensor.
DebugString(const OutputTensor & tensor)86 string DebugString(const OutputTensor& tensor) {
87   return absl::StrCat(tensor.node->name(), ":", tensor.index);
88 }
89 
Branch_Name(BranchType b)90 string Branch_Name(BranchType b) {
91   switch (b) {
92     case BranchType::kElseBranch:
93       return "else";
94     case BranchType::kThenBranch:
95       return "then";
96     case BranchType::kBoth:
97       return "both";
98     case BranchType::kNeither:
99       return "neither";
100   }
101 }
102 
DebugString(StateMap::CondId cond_state)103 string DebugString(StateMap::CondId cond_state) {
104   if (cond_state == nullptr || cond_state->empty()) return "{}";
105   using value_type = StateMap::CondState::value_type;
106   return absl::StrCat(
107       "{",
108       absl::StrJoin(*cond_state, ", ",
109                     [](string* output, const value_type& pred_branch) {
110                       const OutputTensor& pred = pred_branch.first;
111                       const BranchType& branch = pred_branch.second;
112                       if (branch == BranchType::kNeither)
113                         absl::StrAppend(output, "d");
114                       else
115                         absl::StrAppend(output, "s(", DebugString(pred), ",",
116                                         Branch_Name(branch), ")");
117                     }),
118       "}");
119 }
120 
121 // Returns the predicate of a switch.
GetSwitchPredicate(const Node & switch_node,OutputTensor * pred)122 Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) {
123   const Edge* pred_edge;
124   TF_RETURN_IF_ERROR(switch_node.input_edge(1, &pred_edge));
125   // The predicate can be preceded by a identity node. Look through
126   // identity nodes to predicate.
127   while (pred_edge->src()->IsIdentity()) {
128     TF_RETURN_IF_ERROR(pred_edge->src()->input_edge(0, &pred_edge));
129   }
130   *pred = OutputTensor(pred_edge->src(), pred_edge->src_output());
131   return OkStatus();
132 }
133 
GetSwitchValue(const Node & switch_node,OutputTensor * val)134 Status GetSwitchValue(const Node& switch_node, OutputTensor* val) {
135   const Edge* val_edge;
136   TF_RETURN_IF_ERROR(switch_node.input_edge(0, &val_edge));
137   *val = OutputTensor(val_edge->src(), val_edge->src_output());
138   return OkStatus();
139 }
140 
operator ()(const OutputTensor & lhs,const OutputTensor & rhs) const141 bool StateMap::OutputTensorLess::operator()(const OutputTensor& lhs,
142                                             const OutputTensor& rhs) const {
143   return (lhs.node->id() < rhs.node->id()) ||
144          (lhs.node->id() == rhs.node->id() && lhs.index < rhs.index);
145 }
146 
147 struct CondStateLess {
operator ()tensorflow::functionalize_cond::CondStateLess148   bool operator()(const StateMap::CondState::value_type& lhs,
149                   const StateMap::CondState::value_type& rhs) const {
150     if (StateMap::OutputTensorLess().operator()(lhs.first, rhs.first))
151       return true;
152     if (lhs.first.node->id() == rhs.first.node->id() &&
153         lhs.first.index == rhs.first.index)
154       return lhs.second < rhs.second;
155     return false;
156   }
157 };
158 
StateMap(Graph * graph)159 StateMap::StateMap(Graph* graph) {
160   node_to_condid_map_.resize(graph->num_node_ids());
161   node_to_ancestorid_map_.resize(graph->num_node_ids());
162   // Initialize the dead state (empty state is designated with a nullptr).
163   dead_id_ = GetCondId(
164       {std::make_pair(OutputTensor(nullptr, -1), BranchType::kNeither)});
165 }
166 
IsDead(StateMap::CondId id) const167 bool StateMap::IsDead(StateMap::CondId id) const { return id == dead_id_; }
168 
IsEmpty(StateMap::CondId id) const169 bool StateMap::IsEmpty(StateMap::CondId id) const { return id == nullptr; }
170 
operator ()(const StateMap::CondState & map) const171 size_t StateMap::Hash::operator()(const StateMap::CondState& map) const {
172   if (map.empty()) return 0;
173   // Compute hash of the front element.
174   auto it = map.begin();
175   size_t h = Hash64Combine(OutputTensor::Hash()(it->first),
176                            hash<BranchType>()(it->second));
177   for (++it; it != map.end(); ++it) {
178     // Combine the has with the different elements in the map.
179     h = Hash64Combine(h, Hash64Combine(OutputTensor::Hash()(it->first),
180                                        hash<BranchType>()(it->second)));
181   }
182   return h;
183 }
184 
operator ()(const StateMap::AncestorState & map) const185 size_t StateMap::Hash::operator()(const StateMap::AncestorState& map) const {
186   if (map.empty()) return 0;
187   // Compute hash of the front element.
188   auto it = map.begin();
189   size_t h = AncestorNode::Hash()(*it);
190   for (++it; it != map.end(); ++it) {
191     // Combine the has with the different elements in the map.
192     h = Hash64Combine(h, AncestorNode::Hash()(*it));
193   }
194   return h;
195 }
196 
// CondArgNode represents an input to the conditional and its corresponding
// switch nodes.
struct CondArgNode {
  explicit CondArgNode(Node* src, int src_output)
      : src(src), src_output(src_output) {}

  // Human-readable summary used in debug logging.
  string ToString() const {
    return absl::StrCat("src=", src->name(), ":", src_output,
                        " switches=", NodesToString(switches));
  }

  // The tensor (source node and output index) feeding the switches below.
  Node* src;
  int src_output;
  // Per-branch _Arg nodes created in the extracted function bodies; indexed
  // by static_cast<int>(BranchType): 0 = else branch, 1 = then branch.
  std::array<Node*, 2> branch_copy;
  // All switch nodes in the original graph fed by src:src_output.
  std::vector<Node*> switches;
};
213 using CondArgNodes = std::vector<CondArgNode>;
214 
DebugString(const CondArgNodes & nodes)215 string DebugString(const CondArgNodes& nodes) {
216   return absl::StrCat(
217       "[",
218       absl::StrJoin(nodes, ", ",
219                     [](string* output, const CondArgNode& node) {
220                       absl::StrAppend(output, node.ToString());
221                     }),
222       "]");
223 }
224 
LookupCondId(const Node * node) const225 StateMap::CondId StateMap::LookupCondId(const Node* node) const {
226   const int64_t map_size = node_to_condid_map_.size();
227   if (node->id() < map_size) return node_to_condid_map_[node->id()];
228   return added_node_condid_mapping_.at(node->id());
229 }
230 
StateMap::CondId StateMap::GetCondId(const StateMap::CondState& state) {
  // The empty state is interned as nullptr; all others are deduplicated in
  // condstate_set_ and identified by their stable address.
  if (state.empty()) return nullptr;
  auto inserted = condstate_set_.insert(state);
  return &*inserted.first;
}
235 
ResetCondId(const Node * node,StateMap::CondId id)236 void StateMap::ResetCondId(const Node* node, StateMap::CondId id) {
237   const int64_t map_size = node_to_condid_map_.size();
238   if (node->id() < map_size)
239     node_to_condid_map_[node->id()] = id;
240   else
241     added_node_condid_mapping_[node->id()] = id;
242 }
243 
LookupAncestorId(const Node * node) const244 StateMap::AncestorId StateMap::LookupAncestorId(const Node* node) const {
245   const int64_t map_size = node_to_ancestorid_map_.size();
246   if (node->id() < map_size) return node_to_ancestorid_map_[node->id()];
247   return added_node_ancestorid_mapping_.at(node->id());
248 }
249 
StateMap::AncestorId StateMap::GetAncestorId(
    const StateMap::AncestorState& state) {
  // The empty state is interned as nullptr; others are deduplicated in
  // ancestorstate_set_ and identified by their stable address.
  if (state.empty()) return nullptr;
  auto inserted = ancestorstate_set_.insert(state);
  return &*inserted.first;
}
255 
ResetAncestorId(const Node * node,StateMap::AncestorId id)256 void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) {
257   const int64_t map_size = node_to_ancestorid_map_.size();
258   if (node->id() < map_size)
259     node_to_ancestorid_map_[node->id()] = id;
260   else
261     added_node_ancestorid_mapping_[node->id()] = id;
262 }
263 
MarkDead(const Node * node)264 void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); }
265 
CondStateToString(const Node * node) const266 string StateMap::CondStateToString(const Node* node) const {
267   return CondStateToString(LookupCondId(node));
268 }
269 
CondStateToString(StateMap::CondId id) const270 string StateMap::CondStateToString(StateMap::CondId id) const {
271   return DebugString(id);
272 }
273 
AncestorStateToString(const Node * node) const274 string StateMap::AncestorStateToString(const Node* node) const {
275   if (auto id = LookupAncestorId(node)) {
276     return absl::StrCat(
277         "{",
278         absl::StrJoin(*id, ",",
279                       [](string* output, const AncestorNode& ancestor) {
280                         absl::StrAppend(output,
281                                         ancestor.output_tensor.node->name(),
282                                         ":", ancestor.output_tensor.index);
283                       }),
284         "}");
285   }
286   return "{}";
287 }
288 
// Constructs the functionalization pass over 'graph'. Extracted branch
// functions are added to 'library'. 'node_filter' is stored for later use;
// presumably it restricts which nodes are functionalized — confirm against
// its uses elsewhere in this file.
FunctionalizeCond::FunctionalizeCond(Graph* graph,
                                     FunctionLibraryDefinition* library,
                                     const NodeFilter& node_filter)
    : state_map_(graph),
      library_(library),
      graph_(graph),
      node_filter_(node_filter) {}
296 
// Class representing the merge/switch nodes that will become a conditional
// (a single If node with extracted then/else function bodies).
class Conditional {
 public:
  Conditional(OutputTensor predicate, FunctionalizeCond* parent,
              StateMap* cond_state_map, const ShapeRefiner& refiner);

  // Adds merge node that is part of this conditional.
  Status AddMerge(Node* m);

  // Constructs an If node from the merge nodes.
  // On success, 'merge_to_replacement' records the If output tensor that
  // replaces each merge node.
  Status BuildAndReplace(
      Graph* graph, FunctionLibraryDefinition* library,
      std::unordered_map<Node*, OutputTensor>* merge_to_replacement);

 private:
  // Extracts the then/else bodies: creates new graphs with the nodes
  // corresponding to the nodes in the then/else branches as of this conditional
  // as function bodies.
  Status ExtractBodies(Graph* graph);

  // Builds the arguments that are the input to the If.
  Status BuildArgumentNodes();

  // Builds the If node for the extracted bodies with the given predicate.
  Status BuildIfNode(Graph* graph, FunctionLibraryDefinition* library);

  // Adds input edges to If node.
  Status AddInputEdges(
      Graph* graph,
      const std::unordered_map<Node*, OutputTensor>& merge_to_replacement);

  // Adds output edges from If node.
  // Record new output tensor for all Merge nodes in 'merge_to_replacement'.
  Status AddOutputEdges(
      Graph* graph,
      std::unordered_map<Node*, OutputTensor>* merge_to_replacement);

  // Adds switch node that is part of this conditional.
  Status AddSwitch(Node* s);

  // Adds a switch node along the edge and rewire the edge to go via the switch.
  Status AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
                                Graph* graph);

  // Internal name of conditional. The name is based on the first merge node
  // added.
  string name() const;

  // The FunctionalizeCond instance that created this.
  FunctionalizeCond* parent_;

  // Mapping between nodes and their cond state.
  StateMap* state_map_;

  // The predicate of the conditional.
  OutputTensor predicate_;

  // Shape refiner of ops in the graph.
  const ShapeRefiner& refiner_;

  // The predicate of the switches of the conditional. This may be different
  // than predicate (which is initialized from the original graph) as the
  // predicate could be the output of a newly created If node.
  OutputTensor switch_predicate_;

  // Switch nodes in graph that are part of this conditional, in a
  // deterministic order (NodeCmpByNameResourcesLast).
  std::set<Node*, NodeCmpByNameResourcesLast> switches_;

  // Merge nodes in graph that are part of this conditional, in a
  // deterministic order (NodeCmpByNameResourcesLast).
  std::set<Node*, NodeCmpByNameResourcesLast> merges_;

  // Vector of control inputs from outside the conditional to a node inside.
  std::vector<Node*> external_control_inputs_;
  std::vector<Node*> external_control_outputs_;

  // Graphs corresponding to the then and else branch, indexed by
  // static_cast<int>(BranchType): 0 = else, 1 = then.
  std::array<std::unique_ptr<Graph>, 2> bodies_;

  // Maps node ids in graph_ to the copied node in the corresponding branch
  // body's graph (nullptr when the node was not copied into that branch).
  std::array<std::vector<Node*>, 2> node_maps_;

  // The argument nodes created for the switches.
  CondArgNodes cond_arg_nodes_;

  // The constructed If node.
  Node* if_node_ = nullptr;

  // Whether the merge nodes of this conditional have been replaced.
  bool replaced_ = false;
};
387 
// Constructs a Conditional keyed on 'predicate'. 'parent' is the
// FunctionalizeCond pass that owns this conditional, 'cond_state_map' maps
// nodes to their cond state, and 'refiner' supplies inferred shapes for ops
// in the graph.
Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent,
                         StateMap* cond_state_map, const ShapeRefiner& refiner)
    : parent_(parent),
      state_map_(cond_state_map),
      predicate_(predicate),
      refiner_(refiner) {}
394 
AddMerge(Node * m)395 Status Conditional::AddMerge(Node* m) {
396   merges_.insert(m);
397   return OkStatus();
398 }
399 
AddSwitch(Node * s)400 Status Conditional::AddSwitch(Node* s) {
401   VLOG(5) << "Adding switch " << s->DebugString();
402   OutputTensor predicate;
403   TF_RETURN_IF_ERROR(GetSwitchPredicate(*s, &predicate));
404   if (switch_predicate_.node == nullptr) switch_predicate_ = predicate;
405   if (!(switch_predicate_ == predicate)) {
406     return errors::InvalidArgument(
407         "Merge nodes ", NodesToString(merges_),
408         " directly dominated by switch nodes with different predicates (",
409         DebugString(switch_predicate_), " vs ", DebugString(predicate), ").");
410   }
411   switches_.insert(s);
412   parent_->AddSwitchId(s->id());
413   return OkStatus();
414 }
415 
BuildArgumentNodes()416 Status Conditional::BuildArgumentNodes() {
417   VLOG(1) << "Build function arguments";
418   struct Hash {
419     size_t operator()(const std::pair<Node*, int>& item) const {
420       return Hash64Combine(hash<Node*>()(item.first),
421                            std::hash<int>()(item.second));
422     }
423   };
424 
425   std::unordered_map<std::pair<Node*, int>, int, Hash> input_index;
426   for (Node* switch_node : switches_) {
427     const Edge* e;
428     TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e));
429     std::pair<Node*, int> key = std::make_pair(e->src(), e->src_output());
430     if (input_index.find(key) == input_index.end()) {
431       input_index[key] = cond_arg_nodes_.size();
432       cond_arg_nodes_.emplace_back(key.first, key.second);
433     }
434     cond_arg_nodes_.at(input_index.at(key)).switches.push_back(switch_node);
435   }
436   VLOG(5) << "CondArg nodes created: " << DebugString(cond_arg_nodes_);
437 
438   int arg_count = 0;
439   for (CondArgNode& cond_arg_node : cond_arg_nodes_) {
440     DataType dtype = cond_arg_node.src->output_type(cond_arg_node.src_output);
441     for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
442       int branch_index = static_cast<int>(branch);
443       TF_RETURN_IF_ERROR(
444           NodeBuilder(absl::StrCat("_Arg", arg_count),
445                       FunctionLibraryDefinition::kArgOp)
446               .Attr("T", dtype)
447               .Attr("index", arg_count)
448               .Finalize(bodies_[branch_index].get(),
449                         &cond_arg_node.branch_copy[branch_index]));
450     }
451     for (Node* node : cond_arg_node.switches) {
452       for (const Edge* e : node->out_edges()) {
453         if (e->IsControlEdge()) continue;
454         int branch_index = e->src_output();
455         Node* src_copy = cond_arg_node.branch_copy[branch_index];
456         Node* dst_copy = node_maps_[branch_index][e->dst()->id()];
457 
458         // The graph may contain dead switch nodes,
459         if (dst_copy == nullptr) continue;
460 
461         TF_RET_CHECK(dst_copy != nullptr)
462             << "Unable to find copied node for " << e->dst()->DebugString()
463             << " on branch " << Branch_Name(BranchType(branch_index));
464         // If the input goes directly to a merge then the merge has
465         // been replaced by a retval so the dst input is 0 instead of
466         // dst_input.
467         int dst_input = IsMerge(e->dst()) ? 0 : e->dst_input();
468         bodies_[branch_index]->AddEdge(src_copy, 0, dst_copy, dst_input);
469       }
470     }
471     ++arg_count;
472   }
473 
474   // Verify that all retvals have an input.
475   // TODO(jpienaar): One could add a ZerosLike in the branch that doesn't have
476   // input.
477   for (Node* m : merges_) {
478     for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
479       bool has_input = false;
480       for (auto e : node_maps_[static_cast<int>(branch)][m->id()]->in_edges()) {
481         if (!e->IsControlEdge()) {
482           has_input = true;
483           break;
484         }
485       }
486       if (!has_input) {
487         return errors::Internal(
488             "Failed to functionalize control flow with merge ",
489             FormatNodeForError(*m), " that doesn't have input on ",
490             Branch_Name(branch), " branch.");
491       }
492     }
493   }
494 
495   return OkStatus();
496 }
497 
// Inserts a new Switch node keyed on predicate_ between edge->src() and
// edge->dst(), so the value entering the conditional flows through a switch,
// and registers it via AddSwitch.
Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
                                           Graph* graph) {
  // Previously we had edge:
  //   src:src_output ---- edge ----> dst:dst_input
  // post this we have (in graph)
  //   src:src_output --> switch<pred> --- new_edge --> dst:dst_input

  // TODO(jpienaar): One could keep a map caching the extra switch nodes added
  // to avoid adding another switch to feed a value for which a switch was
  // already added.
  Node* switch_node;
  Node* src = edge->src();
  int src_output = edge->src_output();
  TF_RETURN_IF_ERROR(
      NodeBuilder(graph->NewName(absl::StrCat(src->name(), "_added_switch")),
                  "Switch")
          .Input(src, src_output)
          .Input(const_cast<Node*>(predicate_.node), predicate_.index)
          .Finalize(graph, &switch_node));
  // The new switch inherits the cond/ancestor state of its data input.
  state_map_->ResetCondId(switch_node, state_map_->LookupCondId(src));
  state_map_->ResetAncestorId(switch_node, state_map_->LookupAncestorId(src));

  // Capture the destination before RemoveEdge invalidates 'edge'.
  Node* dst = edge->dst();
  int dst_input = edge->dst_input();
  graph->RemoveEdge(edge);
  // The switch output index selects the branch the consumer lives on.
  graph->AddEdge(switch_node, static_cast<int>(branch), dst, dst_input);
  return AddSwitch(switch_node);
}
526 
// Extracts the then/else branch bodies of this conditional into two new
// graphs (bodies_), copying every node reachable backwards from the merge
// nodes up to (but excluding) the switches, and building _Retval nodes for
// the merges. Records external control inputs/outputs encountered on the way.
Status Conditional::ExtractBodies(Graph* graph) {
  VLOG(2) << "Extracting bodies for " << name();
  for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) {
    bodies_[static_cast<int>(b)] =
        std::make_unique<Graph>(graph->op_registry());
  }

  // Determines which branch an edge's source lies on: switch outputs encode
  // the branch in the output index; other nodes are classified by their cond
  // state relative to this conditional's predicate.
  auto find_branch = [&](const Edge* e) {
    const auto& id = state_map_->LookupCondId(e->src());
    return IsSwitch(e->src()) ? BranchType(e->src_output())
                              : state_map_->FindBranchOf(id, predicate_);
  };

  // Per-branch DFS worklists, seeded with the data inputs of the merges.
  std::array<std::vector<Node*>, 2> stacks;
  VLOG(5) << "Merges: " << NodesToString(merges_);
  for (Node* m : merges_) {
    VLOG(5) << "For merge: " << m->DebugString() << " "
            << state_map_->CondStateToString(m);
    for (auto e : m->in_edges()) {
      if (e->IsControlEdge()) continue;
      BranchType branch = find_branch(e);
      TF_RET_CHECK(branch == BranchType::kThenBranch ||
                   branch == BranchType::kElseBranch)
          << "Error: " << e->src()->name()
          << " is not on either then or else branch (" << Branch_Name(branch)
          << ") for predicate " << DebugString(predicate_) << " ["
          << DebugString(state_map_->LookupCondId(e->src())) << "].";
      Node* src = e->src();
      if (IsSwitch(src)) {
        // Switch node outputs and dependencies are handled separately.
        TF_RETURN_IF_ERROR(AddSwitch(src));
      } else {
        stacks[static_cast<int>(branch)].push_back(src);
      }
    }
  }

  // Backwards DFS per branch: copy each visited node into the branch body and
  // reconnect its (copied) input edges.
  for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
    int branch_index = static_cast<int>(branch);
    auto output = bodies_[branch_index].get();
    auto& stack = stacks[branch_index];
    VLOG(5) << "In branch: " << Branch_Name(branch) << " "
            << NodesToString(stack);
    std::vector<bool> visited(graph->num_node_ids(), false);
    node_maps_[branch_index].resize(graph->num_node_ids(), nullptr);
    auto& node_map = node_maps_[branch_index];

    while (!stack.empty()) {
      Node* n = stack.back();
      stack.pop_back();

      if (visited.at(n->id())) continue;
      visited[n->id()] = true;

      // Verify output edges and record control edges exiting scope.
      for (const Edge* e : n->out_edges()) {
        Node* dst = e->dst();
        if (IsMerge(dst)) continue;
        Node* src = e->src();

        auto dst_id = state_map_->LookupCondId(dst);
        auto src_id = state_map_->LookupCondId(src);
        if (dst_id != src_id) {
          if (e->IsControlEdge()) {
            // Note: the node recorded is the edge's source (the node inside
            // this conditional whose control output leaves it).
            external_control_outputs_.push_back(e->src());
          } else {
            // Constants are treated specially to workaround the case of
            // non-dominated constant nodes.
            if (!IsConstant(src)) {
              // TODO(b/78882471): A node that feeds into two different
              // CondState is not necessarily an error so log a warning for now
              // but revisit to improve the testing to enable making this an
              // error.
              LOG(WARNING) << errors::InvalidArgument(
                  "Graph contains node ", FormatNodeForError(*src),
                  " that feeds into node ", FormatNodeForError(*dst),
                  " but these nodes are in different control contexts (",
                  DebugString(src_id), " vs ", DebugString(dst_id),
                  " (detected during out edge testing)");
            }
          }
        }
      }

      // Copying incoming edges to dst node. Iterate over a copy of the edges
      // as they could be mutated during iteration.
      std::vector<const Edge*> in_edges(n->in_edges().begin(),
                                        n->in_edges().end());
      // Sort in_edges to make sure nodes are copied in a deterministic order.
      std::sort(
          in_edges.begin(), in_edges.end(), [](const Edge* a, const Edge* b) {
            int a_src_output = a->src_output(), b_src_output = b->src_output();
            StringPiece a_name(a->src()->name()), b_name(b->src()->name());
            return std::tie(a_src_output, a_name) <
                   std::tie(b_src_output, b_name);
          });
      for (const Edge* e : in_edges) {
        Node* src = e->src();
        // Skip src/dst node.
        if (!src->IsOp()) continue;

        Node* dst = e->dst();
        if (IsSwitch(src)) {
          // Switch node outputs and dependencies are handled separately.
          TF_RETURN_IF_ERROR(AddSwitch(src));
          continue;
        }

        // Verify input is from the same context.
        auto src_id = state_map_->LookupCondId(src);
        auto dst_id = state_map_->LookupCondId(dst);
        if (IsMerge(dst) || src_id == dst_id) {
          // TODO(jpienaar): The merge case can be more strict.
          if (node_map.at(src->id()) == nullptr) {
            node_map.at(src->id()) = output->CopyNode(src);
            stack.push_back(src);
          }
        } else if (e->IsControlEdge()) {
          // Here we have a control flow edge between src and dst that are not
          // in the same context. This is an external control dependency except
          // for one case: where the only difference between CondId of e->src()
          // and CondId of e->dst() is that e->src() has {PRED, kNeither} and
          // e->dst() has {PRED, kThenBranch/kElseBranch}. This happens in
          // gradients code for tf.cond(), where e->src() is a control pivot
          // node for a branch and e->dst() is a data node in that branch.
          bool is_external_control_input = true;
          if (!state_map_->IsEmpty(src_id) && !state_map_->IsEmpty(dst_id)) {
            std::vector<StateMap::CondState::value_type> diff;
            std::set_symmetric_difference(
                src_id->begin(), src_id->end(), dst_id->begin(), dst_id->end(),
                std::back_inserter(diff), CondStateLess());
            if (diff.size() == 2 && diff[0].first == diff[1].first &&
                (diff[0].second == BranchType::kNeither ||
                 diff[1].second == BranchType::kNeither)) {
              auto src_branch = src_id->find(diff[0].first);
              if (src_branch != src_id->end() &&
                  src_branch->second == BranchType::kNeither) {
                is_external_control_input = false;
              }
            }
          }
          if (is_external_control_input) {
            external_control_inputs_.push_back(src);
          }
        } else {
          // This shouldn't happen, this means we have an external data input
          // not entering via a switch node. Work around this by for
          // * constant nodes copy them;
          // * non-constant nodes, insert a switch along the edge;
          if (IsConstant(src)) {
            // Check if constant node was added already. It is possible to have
            // multiple uses of a constant node.
            if (node_map.at(src->id()) == nullptr) {
              node_map.at(src->id()) = output->CopyNode(src);
            }
          } else {
            StateMap::CondState state = *dst_id;
            state.erase(predicate_);
            if (state_map_->GetCondId(state) == src_id) {
              TF_RETURN_IF_ERROR(AddSwitchNodeAlongEdge(e, branch, graph));
              continue;
            } else {
              return errors::InvalidArgument(
                  "Graph contains node ", FormatNodeForError(*src),
                  " that feeds into node ", FormatNodeForError(*dst),
                  " but these nodes are in different control contexts (",
                  DebugString(src_id), " vs ", DebugString(dst_id),
                  " (detected during in edge testing)");
            }
          }
        }

        // Connect the copied endpoints inside the branch body. For an
        // external control input, src_copy stays nullptr and the edge is
        // dropped here (it is restored on the If node later).
        Node* src_copy = node_map.at(e->src()->id());
        int src_output = e->src_output();
        if (node_map.at(dst->id()) == nullptr) {
          node_map.at(dst->id()) = output->CopyNode(dst);
        }
        Node* dst_copy = node_map.at(e->dst()->id());
        if (e->IsControlEdge()) {
          // Skip control inputs from external context.
          if (src_copy != nullptr) output->AddControlEdge(src_copy, dst_copy);
        } else {
          output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
        }
      }
    }
  }

  // Build return values from the merge nodes.
  int index = 0;
  for (Node* m : merges_) {
    for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
      int branch_index = static_cast<int>(branch);
      auto& node_map = node_maps_[branch_index];
      auto output = bodies_[branch_index].get();
      TF_ASSIGN_OR_RETURN(node_map[m->id()],
                          BuildRetvalNode(output, m->output_type(0), index));
    }
    ++index;

    // Connect the input to the merge_ with the retval, except if it is a
    // Switch node, which is handled separately.
    for (auto e : m->in_edges()) {
      if (e->IsControlEdge()) continue;
      int branch_index = static_cast<int>(find_branch(e));
      auto& node_map = node_maps_[branch_index];
      auto output = bodies_[branch_index].get();
      Node* in = e->src();
      if (!IsSwitch(in)) {
        if (node_map.at(in->id()) == nullptr) {
          node_map[in->id()] = output->CopyNode(in);
        }
        output->AddEdge(node_map[in->id()], e->src_output(),
                        node_map.at(m->id()), 0);
      }
    }
  }
  return OkStatus();
}
746 
BuildIfNode(Graph * graph,FunctionLibraryDefinition * library)747 Status Conditional::BuildIfNode(Graph* graph,
748                                 FunctionLibraryDefinition* library) {
749   VLOG(2) << "Build cond function for " << name();
750   NodeDebugInfo debug_info((*merges_.begin())->def());
751   NodeDefBuilder builder(name(), "If", library, &debug_info);
752   const string branch_name[] = {"else_branch", "then_branch"};
753   for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
754     int branch_index = static_cast<int>(branch);
755 
756     NameAttrList body_name;
757     body_name.set_name(library->UniqueFunctionName(
758         absl::StrCat("_functionalize_if_", branch_name[branch_index], "_")));
759 
760     VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index]
761             << "): "
762             << DumpGraphToFile(
763                    "functionalize_cond_body_" + branch_name[branch_index],
764                    *bodies_[branch_index], nullptr);
765 
766     FunctionDef body_fdef;
767     TF_RETURN_IF_ERROR(GraphToFunctionDef(*bodies_[branch_index],
768                                           body_name.name(), &body_fdef));
769     TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
770     builder.Attr(branch_name[branch_index], body_name);
771   }
772 
773   VLOG(3) << "Build input type";
774   std::vector<NodeDefBuilder::NodeOut> inputs;
775   DataTypeVector in_arg_types;
776   for (auto& kv : cond_arg_nodes_) {
777     bool inserted = false;
778     for (const Node* arg : kv.switches) {
779       const Edge* in_edge;
780       TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
781       if (in_edge->IsControlEdge()) {
782         builder.ControlInput(in_edge->src()->name());
783       } else {
784         if (!inserted) {
785           DataType dtype = arg->input_type(0);
786           inputs.emplace_back(NodeDefBuilder::NodeOut(
787               in_edge->src()->name(), in_edge->src_output(), dtype));
788           in_arg_types.push_back(dtype);
789           inserted = true;
790         }
791       }
792     }
793   }
794   builder.Attr("Tin", in_arg_types);
795 
796   DataTypeVector out_type;
797   std::vector<PartialTensorShape> output_shapes;
798   output_shapes.reserve(merges_.size());
799   for (const Node* merge : merges_) {
800     DataType dtype = merge->output_type(0);
801     TensorShapeProto shape;
802     if (auto* shape_ctx = refiner_.GetContext(merge)) {
803       shape_inference::ShapeHandle handle;
804       shape_ctx->ShapeHandleToProto(shape_ctx->output(0), &shape);
805     }
806     out_type.push_back(dtype);
807     output_shapes.push_back(shape);
808   }
809   builder.Attr("Tout", out_type);
810   VLOG(3) << "Build output type: " << DataTypeVectorString(out_type);
811   builder.Attr("output_shapes", output_shapes);
812   VLOG(3) << "Build output shapes: "
813           << PartialTensorShapeUtils::PartialShapeListString(output_shapes);
814 
815   builder.Attr("Tcond", DT_BOOL);
816   // Add some internal attributes which need to be propagated.
817   for (absl::string_view attr_name : kAttrsToPropagate) {
818     string attr_val;
819     if (GetNodeAttr(predicate_.node->def(), attr_name, &attr_val).ok()) {
820       builder.Attr(attr_name, attr_val);
821     }
822   }
823   builder.Device(predicate_.node->assigned_device_name());
824   // Conditional should be the first input ...
825   builder.Input(
826       NodeDefBuilder::NodeOut(predicate_.node->name(), predicate_.index,
827                               predicate_.node->output_type(predicate_.index)));
828   // ... followed by the other inputs.
829   builder.Input(inputs);
830 
831   VLOG(3) << "Build If node";
832   NodeDef if_def;
833   TF_RETURN_IF_ERROR(builder.Finalize(&if_def));
834   TF_ASSIGN_OR_RETURN(if_node_,
835                       parent_->AddIfNode(if_def, *merges_.begin(), predicate_));
836 
837   return OkStatus();
838 }
839 
AddInputEdges(Graph * graph,const std::unordered_map<Node *,OutputTensor> & merge_to_replacement)840 Status Conditional::AddInputEdges(
841     Graph* graph,
842     const std::unordered_map<Node*, OutputTensor>& merge_to_replacement) {
843   VLOG(2) << "AddInputEdges for " << if_node_->name();
844   int index = 0;
845   // Add predicate input.
846   if (predicate_.node->IsMerge()) {
847     // If the predicate is a Merge node, we should not use Merge output as
848     // predicate. Instead, we should use the corresponding If output in
849     // 'merge_to_replacement'. Otherwise, this Conditional's If node is still
850     // connected to the predicate Merge node; and when we call
851     // DeleteReachableAndDeadNodes(), the predicate Merge node and this
852     // Conditional's If node will be removed.
853     auto iter = merge_to_replacement.find(predicate_.node);
854     if (iter == merge_to_replacement.end()) {
855       return errors::Internal("Cannot find replacement for Merge node ",
856                               predicate_.node->name());
857     }
858     graph->AddEdge(iter->second.node, iter->second.index, if_node_, index++);
859   } else {
860     graph->AddEdge(const_cast<Node*>(predicate_.node), predicate_.index,
861                    if_node_, index++);
862   }
863   // Add function body inputs.
864   for (auto& arg : cond_arg_nodes_) {
865     if (arg.src_output == Graph::kControlSlot) {
866       graph->AddControlEdge(arg.src, if_node_);
867     } else {
868       graph->AddEdge(arg.src, arg.src_output, if_node_, index++);
869     }
870   }
871   for (Node* n : external_control_inputs_) {
872     graph->AddControlEdge(n, if_node_);
873   }
874   return OkStatus();
875 }
876 
// Reroutes all consumers of each merge node to the corresponding output of
// the If node (output i for the i-th merge), records the replacement output
// tensor in 'merge_to_replacement', and finally adds control edges to the
// external control outputs of the conditional.
Status Conditional::AddOutputEdges(
    Graph* graph,
    std::unordered_map<Node*, OutputTensor>* merge_to_replacement) {
  VLOG(2) << "AddOutputEdges for " << if_node_->name();
  int i = 0;
  for (Node* node : merges_) {
    // Add an Identity node carrying the merge's original name and forwarding
    // If output i.
    TF_RETURN_IF_ERROR(parent_->AddIdentityNode(node, if_node_, i));
    // Snapshot the out edges first: RemoveEdge below mutates
    // node->out_edges() while we iterate.
    std::vector<const Edge*> edges(node->out_edges().begin(),
                                   node->out_edges().end());
    for (const Edge* edge : edges) {
      Node* dst = edge->dst();
      int dst_input = edge->dst_input();
      // Only output 0 of the merge (or its control slot) can be rerouted;
      // consumers of a higher output index are not supported.
      if (edge->src_output() > 0) {
        return errors::Unimplemented("Output of index (", edge->src_output(),
                                     ") of merge node ",
                                     FormatNodeForError(*node));
      }

      // Read the edge's properties before removing it; RemoveEdge
      // invalidates `edge`.
      bool control_edge = edge->IsControlEdge();
      graph->RemoveEdge(edge);
      if (control_edge) {
        graph->AddControlEdge(if_node_, dst);
      } else {
        graph->AddEdge(if_node_, i, dst, dst_input);
      }
    }

    // Record corresponding output tensor in 'merge_to_replacement'.
    (*merge_to_replacement)[node] = OutputTensor{if_node_, i};

    ++i;
  }
  for (Node* n : external_control_outputs_) {
    graph->AddControlEdge(if_node_, n);
  }

  return OkStatus();
}
915 
// Driver for rewriting one conditional: extracts the branch bodies, builds
// the argument nodes and the If node, rewires inputs/outputs, propagates the
// updated cond states downstream, and verifies that the rewrite did not
// introduce a cycle. Idempotent: a second call is a no-op once `replaced_`
// is set.
Status Conditional::BuildAndReplace(
    Graph* graph, FunctionLibraryDefinition* library,
    std::unordered_map<Node*, OutputTensor>* merge_to_replacement) {
  VLOG(1) << "Build If and replace merge nodes "
          << NodesToString(this->merges_);
  if (replaced_) return OkStatus();

  TF_RETURN_IF_ERROR(ExtractBodies(graph));
  TF_RETURN_IF_ERROR(BuildArgumentNodes());

  if (VLOG_IS_ON(3)) {
    LOG(INFO) << "Extracted bodies:";
    for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
      int branch_index = static_cast<int>(branch);
      auto output = bodies_[branch_index].get();
      LOG(INFO) << Branch_Name(branch) << ": "
                << DebugString(output->ToGraphDefDebug());
    }
  }

  TF_RETURN_IF_ERROR(BuildIfNode(graph, library));
  TF_RETURN_IF_ERROR(AddInputEdges(graph, *merge_to_replacement));
  TF_RETURN_IF_ERROR(AddOutputEdges(graph, merge_to_replacement));
  TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_));

  // Check that the if_node doesn't feed into itself.
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      CheckNodeNotInCycle(if_node_, graph->num_node_ids()),
      "Converting to If failed.");

  replaced_ = true;
  return OkStatus();
}
949 
name() const950 string Conditional::name() const {
951   CHECK(!merges_.empty());
952   return absl::StrCat((*merges_.begin())->name(), "_if");
953 }
954 
AddIdentityNode(const Node * replacee,Node * if_node,int port)955 Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node,
956                                           int port) {
957   NodeBuilder id_builder(replacee->name(), "Identity");
958   id_builder.Input(if_node, port);
959   string outside_compilation;
960   if (GetNodeAttr(if_node->def(), kXlaOutsideCompilationAttr,
961                   &outside_compilation)
962           .ok()) {
963     id_builder.Attr(kXlaOutsideCompilationAttr, outside_compilation);
964   }
965   Node* id;
966   TF_RETURN_IF_ERROR(id_builder.Finalize(graph_, &id));
967   state_map_.ResetCondId(id, state_map_.LookupCondId(if_node));
968   state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node));
969   return OkStatus();
970 }
971 
AddIfNode(const NodeDef & def,const Node * replacee,const OutputTensor & predicate)972 StatusOr<Node*> FunctionalizeCond::AddIfNode(const NodeDef& def,
973                                              const Node* replacee,
974                                              const OutputTensor& predicate) {
975   TF_ASSIGN_OR_RETURN(Node * ret, graph_->AddNode(def));
976   VLOG(1) << "Adding If for " << replacee->name();
977   StateMap::CondId id = state_map_.LookupCondId(replacee);
978   if (id) {
979     StateMap::CondState state = *id;
980     state.erase(predicate);
981     state_map_.ResetCondId(ret, state_map_.GetCondId(state));
982   } else {
983     state_map_.ResetCondId(ret, nullptr);
984   }
985 
986   state_map_.ResetAncestorId(ret, state_map_.LookupAncestorId(replacee));
987 
988   return ret;
989 }
990 
// Recomputes the CondState of every node downstream of `replacee` whose
// state may have changed, sweeping in topological order until no more
// changes are observed.
Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) {
  VLOG(2) << "Propagating update state for " << replacee->name() << " "
          << state_map_.CondStateToString(replacee);
  // Redo topological sort as the order could have changed.
  // TODO(jpienaar): The original topological order could also be updated
  // dynamically if needed.
  std::vector<Node*> rev_topo_order;
  GetPostOrder(*graph_, &rev_topo_order, NodeComparatorID());

  // All the outputs of the new node could potentially be updated.
  std::unordered_set<Node*> changed;
  for (auto n : replacee->out_nodes())
    if (n->IsOp()) changed.insert(n);

  // Iterate through the changed/possible changed nodes in topological order.
  // Iterating the post order in reverse yields topological order; the sweep
  // stops early once the changed set drains.
  for (auto it = rev_topo_order.rbegin();
       it != rev_topo_order.rend() && !changed.empty(); ++it) {
    if (changed.find(*it) != changed.end()) {
      // Update the node state.
      Node* n = *it;
      StateMap::CondId old_state = state_map_.LookupCondId(n);
      state_map_.ResetCondId(n, nullptr);
      TF_RETURN_IF_ERROR(DetermineCondState(n));
      // Only if the state actually changed do the node's consumers need to
      // be revisited.
      if (state_map_.LookupCondId(n) != old_state) {
        for (auto out : n->out_nodes())
          if (out->IsOp()) changed.insert(out);
      }
      changed.erase(n);
    }
  }
  return OkStatus();
}
1023 
1024 // Returns the most restrictive branch of two branches or neither. This is the
1025 // meet operator of the BranchType lattice.
// Meet operator of the BranchType lattice: returns the most restrictive of
// the two branches. kNeither and kBoth act as non-restrictive elements that
// defer to the other operand (checked in that order, lhs before rhs); two
// distinct concrete branches meet at kNeither.
BranchType MeetBranch(const BranchType& lhs, const BranchType& rhs) {
  if (lhs == rhs) return lhs;
  for (BranchType non_restrictive : {BranchType::kNeither, BranchType::kBoth}) {
    if (lhs == non_restrictive) return rhs;
    if (rhs == non_restrictive) return lhs;
  }
  return BranchType::kNeither;
}
1034 
FindBranchOf(CondId id,OutputTensor predicate) const1035 BranchType StateMap::FindBranchOf(CondId id, OutputTensor predicate) const {
1036   if (IsEmpty(id)) return BranchType::kNeither;
1037   const CondState& nodes = *id;
1038   auto it = nodes.find(predicate);
1039   if (it == nodes.end()) return BranchType::kNeither;
1040   return it->second;
1041 }
1042 
// Joins the CondStates along two inputs of a non-merge node. Empty states
// act as identities and dead states are absorbing; otherwise the two
// predicate maps are unioned, with kNeither treated as an "unknown branch"
// that is compatible with (and refined by) any concrete branch.
StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesNonMerge(
    StateMap::CondId src, StateMap::CondId dst) {
  VLOG(5) << "Joining src=" << DebugString(src) << " [" << src
          << "] and dst=" << DebugString(dst) << " [" << dst << "]";

  if (state_map_.IsEmpty(dst) || state_map_.IsDead(src)) return src;
  if (state_map_.IsDead(dst) || state_map_.IsEmpty(src)) return dst;

  // Nothing to do if the CondState is the same.
  if (src == dst) return src;

  // Union the predicate maps: predicates present on only one side are kept
  // as-is; predicates present on both sides must agree on the branch
  // (modulo kNeither).
  StateMap::CondState both = *src;
  for (const auto& kv : *dst) {
    auto it = both.find(kv.first);
    if (it == both.end()) {
      both.insert(kv);
    } else {
      if (it->second != kv.second) {
        if (it->second == BranchType::kNeither) {
          // BranchType for 'src' is kNeither. Use the BranchType in 'dst'.
          it->second = kv.second;
        } else if (kv.second == BranchType::kNeither) {
          // BranchType for 'dst' is kNeither. Use the BranchType in 'src'.
          // No need to change it->second.
        } else {
          // Two concrete but different branches for the same predicate:
          // the node's inputs live in incompatible contexts.
          return errors::InvalidArgument(
              "Graph contains node with inputs predicated on incompatible "
              "predicates: ",
              DebugString(src), " and ", DebugString(dst));
        }
      }
    }
  }
  return state_map_.GetCondId(both);
}
1078 
StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesMerge(
    Node* merge, StateMap::CondId src, StateMap::CondId dst) {
  // Determine the flow state when joining two states for a merge
  // node. Combining the two states for a merge node is effectively performing a
  // disjunction of the states along the different input edges. For a merge that
  // can be transformed into an If the two inputs paths have to have a predicate
  // on which they differ (e.g., along one edge predicate `p` has to hold while
  // on another it should not). This function first determines this predicate
  // and then the resultant state is the common path between the two inputs
  // followed by s(p, both).
  VLOG(4) << "Joining (for merge) " << DebugString(src) << " and "
          << DebugString(dst);
  if (state_map_.IsEmpty(dst)) return src;
  if (state_map_.IsEmpty(src)) {
    return errors::Internal("Merge node ", merge->name(),
                            " has input that's not in any CondContext.");
  }

  // Dead states propagate unchanged.
  if (state_map_.IsDead(src)) return src;
  if (state_map_.IsDead(dst)) return dst;

  // The symmetric difference isolates the predicate entries the two inputs
  // disagree on; the intersection is the common context of the merge.
  std::vector<StateMap::CondState::value_type> diff;
  StateMap::CondState merged;
  std::set_symmetric_difference(src->begin(), src->end(), dst->begin(),
                                dst->end(), std::back_inserter(diff),
                                CondStateLess());
  std::set_intersection(src->begin(), src->end(), dst->begin(), dst->end(),
                        std::inserter(merged, merged.begin()), CondStateLess());

  // Update mapping from merge node to predicate.
  // A convertible merge differs on exactly one predicate (two diff entries,
  // one from each side), with the two inputs on opposite concrete
  // (then/else) branches of that predicate.
  if (diff.size() == 2) {
    auto pred = diff[0].first;
    bool different_branches = (diff[0].second != diff[1].second) &&
                              (diff[0].second == BranchType::kThenBranch ||
                               diff[0].second == BranchType::kElseBranch) &&
                              (diff[1].second == BranchType::kThenBranch ||
                               diff[1].second == BranchType::kElseBranch);
    if (!(pred == diff[1].first) || !different_branches)
      return errors::InvalidArgument(
          "Unable to determine predicate for merge node");
    merge_to_predicate_[merge] = pred;
  } else {
    return errors::InvalidArgument(
        "Merge of two inputs that differ on more than one predicate ",
        DebugString(src), " and ", DebugString(dst));
  }

  return state_map_.GetCondId(merged);
}
1128 
// Returns the CondState that holds along edge `e`. For most edges this is
// simply the source node's state; an out-edge of a Switch additionally
// constrains the switch predicate to the branch selected by the output port
// (or to kNeither for control edges).
StateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) {
  Node* src = e->src();
  StateMap::CondId id = state_map_.LookupCondId(e->src());

  // Dead nodes only propagate dead state.
  if (state_map_.IsDead(id)) return id;

  if (IsSwitch(src)) {
    StateMap::CondState state;
    if (id != nullptr) state = *id;
    OutputTensor predicate;
    TF_CHECK_OK(GetSwitchPredicate(*src, &predicate));
    if (e->IsControlEdge()) {
      // In gradients of tf.cond(), in each branch, we have a NoOp node as
      // control pivot. These NoOp nodes have control dependency from Switch
      // node. If we don't record this into CondState, branches might have
      // incorrect CondState (e.g. if the branch only has a Const data node).
      // We set it to kNeither because there is no way to tell whether it's
      // for true branch or false branch. This node's descendents might have
      // other incoming edges with defined BranchType, and we correctly handle
      // merging kNeither with other defined BranchType in StateAlongEdge().
      state[predicate] = BranchType::kNeither;
    } else {
      // Data edges leave a switch on port 0 (else) or 1 (then); the port
      // number maps directly onto the BranchType.
      state[predicate] = BranchType(e->src_output());
    }
    return state_map_.GetCondId(state);
  }
  return id;
}
1158 
// Computes and records the CondState of Merge node `dst` by joining the
// states propagated along each of its input edges, and verifies that the
// merge has exactly two data inputs.
Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) {
  // Only Merge nodes with two inputs are supported, but if this is a redundant
  // merge, then the dead edge may already have been removed (if due to a
  // switch) and so the input count would be incorrect.
  if (state_map_.IsDead(state_map_.LookupCondId(dst))) return OkStatus();

  int data_inputs = 0;
  for (auto e : dst->in_edges()) {
    Node* src = e->src();
    VLOG(5) << "Processing forward flow for merge: " << e->DebugString() << " "
            << state_map_.CondStateToString(src);
    if (!src->IsOp()) continue;
    if (!e->IsControlEdge()) ++data_inputs;

    // Join the state flowing along this edge into dst's accumulated state.
    StateMap::CondId prop = StateAlongEdge(e);
    auto id_or = JoinCondStatesMerge(dst, prop, state_map_.LookupCondId(dst));
    TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
                                    FormatNodeForError(*dst));
    state_map_.ResetCondId(dst, id_or.ValueOrDie());
  }

  // Incomplete Merge nodes are not supported.
  if (data_inputs != 2) {
    return errors::Unimplemented(
        dst->name(), " only has ", data_inputs,
        " inputs, while only merge nodes with two inputs supported.");
  }
  return OkStatus();
}
1188 
DetermineCondStateNonMerge(Node * dst)1189 Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) {
1190   // Handle non-merge join.
1191   for (auto e : dst->in_edges()) {
1192     VLOG(4) << "Processing forward flow for: " << e->DebugString() << " "
1193             << state_map_.CondStateToString(dst);
1194     Node* src = e->src();
1195     if (!src->IsOp()) continue;
1196 
1197     // Joining the state between the current and propagated state.
1198     StateMap::CondId prop = StateAlongEdge(e);
1199     auto id_or = JoinCondStatesNonMerge(prop, state_map_.LookupCondId(dst));
1200     TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
1201                                     FormatNodeForError(*dst));
1202     state_map_.ResetCondId(dst, id_or.ValueOrDie());
1203   }
1204   return OkStatus();
1205 }
1206 
// Bypasses a merge whose state is dead: finds its single non-dead input and
// rewires all consumers of the merge directly to that input, marking the
// merge itself dead.
Status FunctionalizeCond::RemoveRedundantMerge(Node* node) {
  // Handle redundant merge nodes. A merge node is considered redundant if
  // one input edge is dead while the other has a value.
  if (!state_map_.IsDead(state_map_.LookupCondId(node))) return OkStatus();

  const Edge* non_dead_edge = nullptr;
  for (auto e : node->in_edges()) {
    if (e->IsControlEdge()) continue;
    Node* src = e->src();

    // Handle merge with dead state.
    const auto& src_id = state_map_.LookupCondId(src);
    if (!state_map_.IsDead(src_id)) {
      non_dead_edge = e;
      break;
    }
  }

  if (non_dead_edge == nullptr) {
    return errors::InvalidArgument("Merge node ", FormatNodeForError(*node),
                                   " has no non-dead inputs.");
  }
  state_map_.MarkDead(node);
  VLOG(5) << "removing redundant merge: " << node->name();
  // Re-point each out edge of the merge at the live input. Removing an edge
  // mutates out_edges(), so always take the current first edge until the set
  // drains.
  while (!node->out_edges().empty()) {
    const Edge* oe = *node->out_edges().begin();
    Node* dst_node = oe->dst();
    int dst_port = oe->dst_input();
    graph_->RemoveEdge(oe);
    // Control consumers get a control edge from the live source; data
    // consumers get the live source's data output.
    graph_->AddEdge(non_dead_edge->src(),
                    dst_port == Graph::kControlSlot
                        ? Graph::kControlSlot
                        : non_dead_edge->src_output(),
                    dst_node, dst_port);
  }
  return OkStatus();
}
1244 
// Bypasses a switch whose predicate is already decided by the current
// CondState: consumers on the taken branch are wired straight to the
// switch's data input, consumers on the untaken branch are marked dead.
Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) {
  // Handle redundant switch nodes. A switch node is considered redundant if
  // the predicate of the switch already holds on the current branch. E.g., if
  // p is the predicate of the switch but p is already known to hold on this
  // branch, then the switch can be removed and the dead state propagated
  // along one. The checking of predicate is based on the exact predicate
  // (rather than boolean equivalence) and aimed at redundant switches as
  // currently generated by gradient code.
  StateMap::CondId dst_id = state_map_.LookupCondId(node);
  if (state_map_.IsDead(dst_id)) return OkStatus();

  BranchType b;
  OutputTensor pred;
  TF_RETURN_IF_ERROR(GetSwitchPredicate(*node, &pred));

  // Determine if we are already on a branch where the switch predicate is
  // true/false. Consider both the data and predicate to determine if the
  // node is redundant (skipping over identity node).
  b = state_map_.FindBranchOf(dst_id, pred);
  if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) {
    // The predicate itself is undecided; check whether the data input
    // (followed through any chain of Identity nodes) is a tensor the state
    // already constrains to a branch.
    OutputTensor val;
    const Edge* e;
    TF_RETURN_IF_ERROR(node->input_edge(0, &e));
    val = OutputTensor(e->src(), e->src_output());
    while (IsIdentity(val.node)) {
      TF_RETURN_IF_ERROR(val.node->input_edge(0, &e));
      val = OutputTensor(e->src(), e->src_output());
    }
    b = state_map_.FindBranchOf(dst_id, val);
    if (b != BranchType::kThenBranch && b != BranchType::kElseBranch)
      return OkStatus();
  }

  VLOG(5) << "Redundant switch " << node->name() << " " << Branch_Name(b) << " "
          << DebugString(dst_id);
  const Edge* value_edge;
  TF_RETURN_IF_ERROR(node->input_edge(0, &value_edge));
  Node* val_node = value_edge->src();
  int val_port = value_edge->src_output();
  // Detach each out edge in turn (RemoveEdge mutates out_edges(), hence the
  // drain-the-first-edge loop) and reroute it past the switch.
  while (!node->out_edges().empty()) {
    auto e = *node->out_edges().begin();
    Node* dst_node = e->dst();
    int dst_input = e->dst_input();
    int switch_branch = e->src_output();
    graph_->RemoveEdge(e);
    if (switch_branch == Graph::kControlSlot) {
      // Control consumers keep executing unconditionally; their state must
      // be re-joined with the switch's state.
      if (IsMerge(dst_node)) {
        auto id_or = JoinCondStatesMerge(dst_node, dst_id,
                                         state_map_.LookupCondId(dst_node));
        TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
                                        FormatNodeForError(*dst_node));
        state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
      } else {
        auto id_or =
            JoinCondStatesNonMerge(dst_id, state_map_.LookupCondId(dst_node));
        TF_RETURN_IF_ERROR(id_or.status());
        state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
      }
    } else if (BranchType(switch_branch) != b) {
      // Consumer hangs off the branch that can never be taken: mark dead and
      // do not reconnect it.
      state_map_.MarkDead(dst_node);
      continue;
    }
    graph_->AddEdge(
        val_node,
        switch_branch == Graph::kControlSlot ? Graph::kControlSlot : val_port,
        dst_node, dst_input);
  }
  return OkStatus();
}
1314 
DetermineStates(std::vector<Node * > rev_topo_order)1315 Status FunctionalizeCond::DetermineStates(std::vector<Node*> rev_topo_order) {
1316   // The state that is propagated along the given edge.
1317   for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) {
1318     Node* dst = *it;
1319     TF_RETURN_IF_ERROR(DetermineCondState(dst));
1320     TF_RETURN_IF_ERROR(DetermineAncestorState(dst));
1321     if (IsSwitch(dst)) TF_RETURN_IF_ERROR(RemoveRedundantSwitch(dst));
1322     if (IsMerge(dst)) TF_RETURN_IF_ERROR(RemoveRedundantMerge(dst));
1323 
1324     VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst)
1325             << " @ " << state_map_.AncestorStateToString(dst);
1326     if (VLOG_IS_ON(10)) DumpGraphWithCondState("it");
1327   }
1328   return OkStatus();
1329 }
1330 
// Computes the ancestor state of `dst`: the union, over all of its input
// edges, of the source nodes' ancestor states plus the sources themselves
// when they are switch or merge nodes.
Status FunctionalizeCond::DetermineAncestorState(Node* dst) {
  StateMap::AncestorId id = nullptr;
  StateMap::AncestorState state;

  // Folds `src`'s ancestors (and `src` itself, if it is a switch or merge)
  // into the captured `state`, which accumulates across all in-edges; the
  // `id` parameter is the accumulated id so far and is used to skip
  // re-inserting an already-absorbed state.
  auto insert = [&](StateMap::AncestorId id, Node* src) {
    auto other_id = state_map_.LookupAncestorId(src);
    if (other_id != id && other_id != nullptr) {
      state.insert(other_id->begin(), other_id->end());
    }
    if (IsMerge(src)) {
      state.insert({{src, 0}, AncestorNode::AncestorNodeType::kMerge});
    } else if (IsSwitch(src)) {
      OutputTensor pred;
      // For dead switch nodes, GetSwitchPredicate() will fail, and we use
      // the switch node directly as ancestor.
      if (GetSwitchPredicate(*src, &pred).ok()) {
        state.insert({pred, AncestorNode::AncestorNodeType::kPred});
      } else {
        state.insert({{src, 0}, AncestorNode::AncestorNodeType::kSwitch});
      }
    }
    return state_map_.GetAncestorId(state);
  };

  // Compute the union of all the switch/merge nodes that affects the input of
  // dst.
  for (auto e : dst->in_edges()) {
    Node* src = e->src();
    id = insert(id, src);
  }
  state_map_.ResetAncestorId(dst, id);
  return OkStatus();
}
1364 
// Removes the now-replaced switch and merge nodes (subject to `node_filter_`)
// and then breadth-first deletes every node that is dead or reachable from a
// deleted node through data edges.
void FunctionalizeCond::DeleteReachableAndDeadNodes(
    const std::vector<Node*>& merge_order) {
  // Delete all nodes that have been extracted or are reachable from
  // deleted/dead nodes. The input and outgoing edges should have already been
  // removed.
  std::deque<int> delete_nodes;
  std::vector<bool> deleted(graph_->num_node_ids(), false);
  // Don't try to delete source or sink nodes.
  deleted[graph_->kSourceId] = true;
  deleted[graph_->kSinkId] = true;

  // All remaining switch nodes that were not excluded from functionalization
  // according to `node_filter_` are not reachable from a merge node and
  // removed. This is to account for dead switch nodes.
  for (int s_id : switch_ids_) {
    Node* s = graph_->FindNodeId(s_id);
    if (s == nullptr) continue;
    for (const Edge* e : s->out_edges()) {
      // Control outputs of switch nodes (which are unconditionally executed if
      // the switch is) are not removed as they need not be part of a
      // conditional.
      if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
    }
    // Only remove switch node if we have functionalized the corresponding
    // condition before (according to `node_filter_`).
    if (!node_filter_ || node_filter_(s)) {
      VLOG(2) << "Removing obsolete switch node " << s->name();
      deleted[s_id] = true;
      graph_->RemoveNode(s);
    }
  }

  // All merge nodes that were not excluded from functionalization according to
  // `node_filter_` should have been transformed at this point and we remove
  // them from the graph here.
  for (Node* m : merge_order) {
    for (const Edge* e : m->out_edges()) {
      // Similar to control outputs of switch nodes don't remove control
      // outputs of merge nodes.
      // TODO(jpienaar): Check cases where output edges still exist here vs
      // being removed in AddOutputEdges.
      if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
    }
    // Only remove merge node if we have functionalized the corresponding
    // condition before (according to `node_filter_`).
    if (!node_filter_ || node_filter_(m)) {
      VLOG(2) << "Removing obsolete merge node " << m->name();
      deleted[m->id()] = true;
      graph_->RemoveNode(m);
    }
  }

  // Enqueue all the dead nodes.
  for (Node* n : graph_->nodes()) {
    if (state_map_.IsDead(state_map_.LookupCondId(n))) {
      delete_nodes.push_back(n->id());
    }
  }
  // Remove dead nodes and nodes that are reachable from dead nodes.
  // Standard BFS over node ids; `deleted` doubles as the visited set.
  while (!delete_nodes.empty()) {
    int d_id = delete_nodes.front();
    delete_nodes.pop_front();
    if (deleted[d_id]) continue;
    Node* d = graph_->FindNodeId(d_id);
    // Switch and Merge nodes could have been deleted already.
    if (d == nullptr) continue;
    for (const Edge* e : d->out_edges()) {
      delete_nodes.push_back(e->dst()->id());
    }
    VLOG(2) << "Removing obsolete node " << d->name();
    deleted[d_id] = true;
    graph_->RemoveNode(d);
  }
}
1439 
SortMergeNodes(std::vector<Node * > * merge_order)1440 void FunctionalizeCond::SortMergeNodes(std::vector<Node*>* merge_order) {
1441   // Sort merge nodes by nesting depth.
1442   using sort_pair = std::pair<int, Node*>;
1443   std::vector<sort_pair> inner_to_outer_merge_order;
1444   inner_to_outer_merge_order.reserve(merge_order->size());
1445   for (auto it = merge_order->rbegin(); it != merge_order->rend(); ++it) {
1446     Node* merge = *it;
1447     StateMap::CondId id = state_map_.LookupCondId(merge);
1448     int depth = id != nullptr ? id->size() : 0;
1449     inner_to_outer_merge_order.emplace_back(depth, merge);
1450   }
1451   std::stable_sort(
1452       inner_to_outer_merge_order.begin(), inner_to_outer_merge_order.end(),
1453       [](sort_pair lhs, sort_pair rhs) { return lhs.first > rhs.first; });
1454   merge_order->clear();
1455   for (sort_pair t : inner_to_outer_merge_order) {
1456     merge_order->push_back(t.second);
1457   }
1458 }
1459 
Status FunctionalizeCond::FunctionalizeInternal() {
  // The general approach for converting a tf.cond (as lowered via switch/merge
  // nodes) to a functional if is as follows:
  // 1. Determine the topological order and collect all the switch and merge
  // nodes in the graph;
  // 2. Compute the predicates and dominance structure for all the nodes in the
  // graph - this includes which predicate must be true for a op to execute
  // (predicate values are considered directly rather than attempting to
  // determine deeper equivalence). We shall refer to this structure as the
  // CondState;
  // 3. Sort the merge nodes by nesting depth;
  // 4. Extract merge nodes together that have the same CondState and
  // AncestorState from the innermost to the outermost into IfOps;
  // Note: In the above only nodes that feed into a merge node will be
  // considered for functionalization.
  // Note: Nodes for which `node_filter_` returns false are excluded.

  // Perform a DFS over the graph and
  // * Determine the reverse topological order of the nodes (there should be no
  //   cycles at this point so the post-order numbering corresponds to the
  //   reverse topological sorting);
  // * Record reverse topological for merge and switch nodes;
  std::vector<Node*> rev_topo_order;
  std::vector<Node*> merge_order;
  DFS(*graph_, nullptr, [&](Node* n) {
    // Only collect switch and merge nodes that are not filtered out, those form
    // the conditions that will be functionalized.
    if (!node_filter_ || node_filter_(n)) {
      if (IsSwitch(n)) {
        AddSwitchId(n->id());
      }
      if (IsMerge(n)) {
        merge_order.push_back(n);
      }
    }
    // Collect all other nodes here, independent of `node_filter_`, because they
    // might belong to a condition that should be functionalized.
    if (n->IsOp()) {
      rev_topo_order.push_back(n);
    }
  });

  // No merges to functionalize.
  if (merge_order.empty()) {
    // No merges mean no switch values consumed (as only considering values
    // fetchable as output of merge);
    DeleteReachableAndDeadNodes(merge_order);
    return OkStatus();
  }

  // Populates the CondState/AncestorState maps for every node; consumes
  // `rev_topo_order` (it must not be used after this point).
  TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order)));
  if (VLOG_IS_ON(4)) DumpGraphWithCondState("id");

  // Determine the shapes of the ops in the graph. Shape inference is
  // best-effort here: failures only produce a warning, since the shapes are
  // an aid for constructing the If ops, not a correctness requirement.
  ShapeRefiner shape_refiner{graph_->versions().producer(),
                             graph_->op_registry()};
  std::vector<Node*> nodes;
  GetReversePostOrder(*graph_, &nodes, NodeComparatorID());
  for (auto node : nodes) {
    if (!shape_refiner.AddNode(node).ok()) {
      LOG(WARNING) << "Couldn't deduce shape for " << node->name();
    }
  }

  // Sort the merge nodes from innermost outwards.
  SortMergeNodes(&merge_order);

  // Cluster merge nodes by (CondId, AncestorId, predicate) in order of
  // nesting. (CondId, AncestorId) is not enough, e.g.
  //   pred1 = array_ops.placeholder(dtypes.bool, name='pred1')
  //   pred2 = array_ops.placeholder(dtypes.bool, name='pred2')
  //   cond1 = control_flow_ops.cond(pred1, ...)
  //   cond2 = control_flow_ops.cond(pred2, ...)
  //   cond3 = control_flow_ops.cond(pred1, use cond1 and cond2)
  //   cond4 = control_flow_ops.cond(pred2, use cond1 and cond2)
  // cond3 and cond4 have the same (CondId, AncestorId), but they should not
  // be merged into one "If" node (because they have different predicates).
  //
  // `merge_clusters` preserves insertion order, so iterating it later keeps
  // the innermost-to-outermost ordering established by SortMergeNodes.
  std::deque<std::vector<Node*>> merge_clusters;
  std::map<ClusterTuple, int, ClusterTupleLessThan> merge_cluster_index;
  for (Node* merge : merge_order) {
    auto cond_id = state_map_.LookupCondId(merge);
    // Merges in a dead branch are skipped; they are cleaned up below.
    if (state_map_.IsDead(cond_id)) continue;

    auto predicate = merge_to_predicate_.find(merge);
    if (predicate == merge_to_predicate_.end()) {
      return errors::Internal("Cannot find predicate for Merge node ",
                              merge->name());
    }

    ClusterTuple key = std::make_tuple(
        cond_id, state_map_.LookupAncestorId(merge), predicate->second);
    auto idx = merge_cluster_index.find(key);
    if (idx == merge_cluster_index.end()) {
      merge_cluster_index[key] = merge_clusters.size();
      merge_clusters.push_back({merge});
    } else {
      merge_clusters[idx->second].emplace_back(merge);
    }
  }

  // Extract the conditionals from inner most to outer most. Extracting from
  // innermost to outermost enables the extraction pass to stop once it
  // encounters a Switch node instead of having to keep track of Switch/Merge
  // nodes seen.
  for (const auto& cluster : merge_clusters) {
    // Construct a Conditional with the predicate of the merge.
    Conditional cond(merge_to_predicate_.at(cluster.front()), this, &state_map_,
                     shape_refiner);
    for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge));
    TF_RETURN_IF_ERROR(
        cond.BuildAndReplace(graph_, library_, &merge_to_replacement_));

    if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract");
  }

  // Remove the now-obsolete switch/merge nodes and anything reachable only
  // through them, as well as nodes with a dead CondState.
  DeleteReachableAndDeadNodes(merge_order);

  return OkStatus();
}
1579 
DumpGraphWithCondState(const string & name)1580 void FunctionalizeCond::DumpGraphWithCondState(const string& name) {
1581   const char* const kCondGroupDebugAttr = "_XlaFunctionalizeCondGroup";
1582 
1583   for (Node* n : graph_->nodes()) {
1584     n->ClearAttr(kCondGroupDebugAttr);
1585     n->AddAttr(kCondGroupDebugAttr,
1586                absl::StrCat(state_map_.CondStateToString(n), "_",
1587                             state_map_.AncestorStateToString(n)));
1588   }
1589   LOG(INFO) << "FunctionalizeControlFlow (" << name << "): "
1590             << DumpGraphToFile(absl::StrCat("functionalize_cond_", name),
1591                                *graph_, library_);
1592 }
1593 
AddSwitchId(int switch_id)1594 void FunctionalizeCond::AddSwitchId(int switch_id) {
1595   switch_ids_.push_back(switch_id);
1596 }
1597 
Functionalize(Graph * graph,FunctionLibraryDefinition * library,const NodeFilter & node_filter)1598 Status FunctionalizeCond::Functionalize(Graph* graph,
1599                                         FunctionLibraryDefinition* library,
1600                                         const NodeFilter& node_filter) {
1601   VLOG(1) << "FunctionalizeCond::Functionalize";
1602   FunctionalizeCond fc(graph, library, node_filter);
1603   return fc.FunctionalizeInternal();
1604 }
1605 
1606 }  // namespace functionalize_cond
1607 
FunctionalizeCond(Graph * graph,FunctionLibraryDefinition * library,const NodeFilter & node_filter)1608 Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library,
1609                          const NodeFilter& node_filter) {
1610   // FunctionalizeControlFlow is invoked for every function, so the loops's
1611   // bodies and conditionals that were extracted into functions will be handled
1612   // in successive invocations.
1613   return functionalize_cond::FunctionalizeCond::Functionalize(graph, library,
1614                                                               node_filter);
1615 }
1616 
1617 }  // namespace tensorflow
1618