/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/conditional_code_motion.h"

#include <algorithm>
#include <deque>
#include <iterator>
#include <stack>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {

namespace conditional_opt {

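// Visits Boundary objects through a FIFO worklist, remembering which
// boundaries have already been visited so that each boundary is popped at
// most once.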
class BoundaryVisitor {
 public:
  // Start with an existing conditional computation.
  explicit BoundaryVisitor(HloInstruction* conditional) {
    Boundary b(Boundary::Position::kInsideBranch);
    b.mutable_operands().push_back(conditional);
    worklist_.push_back(b);
  }
  // Start with an empty work list.
  BoundaryVisitor() {}
  // Get the next boundary to visit.
  Boundary PopNextBoundary() {
    CHECK(!worklist_.empty());
    Boundary b = worklist_.front();
    worklist_.pop_front();
    // If b has already been visited, it must have multiple users and is
    // already in new boundaries. Skip it.
    while (!worklist_.empty() && ContainsKey(visited_, b)) {
      b = worklist_.front();
      worklist_.pop_front();
    }
    visited_.insert(b);
    return b;
  }
  void AddToWorkList(const Boundary& b) {
    CHECK(!b.operands().empty());
    worklist_.push_back(b);
  }

  bool HasNextBoundary() {
    while (!worklist_.empty()) {
      Boundary b = worklist_.front();
      if (!ContainsKey(visited_, b)) {
        break;
      }
      worklist_.pop_front();
    }
    return !worklist_.empty();
  }

 private:
  // worklist_ is the deque of boundaries to be visited.
  std::deque<Boundary> worklist_;
  absl::flat_hash_set<Boundary> visited_;
};

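// Counts the distinct non-constant instructions in `ops`; constants are
// treated as leaves and excluded from the count.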
template <class OpCollection>
int64_t CountNonLeafOps(const OpCollection& ops) {
  absl::flat_hash_set<HloInstruction*> op_set;
  for (auto op : ops) {
    if (!op_set.contains(op) && op->opcode() != HloOpcode::kConstant) {
      op_set.insert(op);
    }
  }
  return op_set.size();
}

// Returns an estimate of the potential reuses carried by a given pair of
// instructions. Uses different integers to classify different levels of
// reuse. This is used as a placeholder only, assuming all instructions
// can be fused to enable data reuse.
int64_t ReusesCarriedBy(HloOpcode op, HloOpcode user) {
  // Reuses in some way work like forces that pull instructions
  // towards each other. We use a number 0-10 to classify how strong the force
  // is between a pair of operations. Given a group of instructions that can be
  // moved together, if the forces inside a conditional are stronger, the group
  // will be moved inside or remain inside the conditional; otherwise, it will
  // be moved outside to or remain outside of the conditional.
  switch (user) {
    case HloOpcode::kGetTupleElement:
      return 0;
    case HloOpcode::kConvert:
      // Because a convert is treated as not movable when it follows a dot or
      // convolution, if op is a dot or convolution here, they must be
      // separated by a conditional boundary. In that case we do not try to
      // pull the convert inside the conditional to be together with the dot
      // or convolution.
      switch (op) {
        case HloOpcode::kConvolution:
        case HloOpcode::kDot:
          return 0;
        default:
          break;
      }
      break;
    default:
      break;
  }
  switch (op) {
      // These instructions do not carry weight of reuse themselves.
    case HloOpcode::kParameter:
    case HloOpcode::kConstant:
    case HloOpcode::kGetTupleElement:
      return 0;
    case HloOpcode::kConditional:
      return 10;
    default:
      return -10;
  }
}

// Returns true if `op` is worth hoisting.
bool WorthHoisting(HloOpcode op, HloOpcode child_op) {
  // TODO(b/169182921): The following cost model is rather incomplete. It will
  // need to be extended to cover most of the element-wise ops.
  switch (op) {
    case HloOpcode::kConvert:
      // If a Convert follows an AllReduce, it is worth moving the AllReduce
      // out of the conditional for the AR/CRS combine. If the Convert follows
      // other ops such as Dot or Convolution, it is better to keep the
      // Convert within the conditional so that it can be fused with the Dot
      // or Convolution.
      switch (child_op) {
        case HloOpcode::kAllReduce:
        case HloOpcode::kReshape:
        case HloOpcode::kGetTupleElement:
          return true;
        default:
          return false;
      }
    case HloOpcode::kGetTupleElement:
      switch (child_op) {
        // Do not move a GTE if its operand is a parameter.
        case HloOpcode::kParameter:
          return false;
        default:
          return true;
      }
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAbs:
    case HloOpcode::kReduce:
    case HloOpcode::kAdd:
    case HloOpcode::kPower:
    case HloOpcode::kCopy:
    case HloOpcode::kConstant:
    case HloOpcode::kSubtract:
    case HloOpcode::kMultiply:
    case HloOpcode::kDivide:
    case HloOpcode::kTuple:
    case HloOpcode::kSqrt:
    case HloOpcode::kRsqrt:
    case HloOpcode::kReshape:
    case HloOpcode::kMinimum:
    case HloOpcode::kMaximum:
      return true;
    default:
      return false;
  }
}

// Compares whether the instructions to be visited in each branch are
// identical.
bool InstructionWithinBranchIdentical(
    const std::vector<HloInstruction*>& instructions,
    bool is_layout_sensitive) {
  // Identical here requires that the shapes of corresponding operands are
  // equal.
  auto eq_operand = [&](const HloInstruction* a, const HloInstruction* b) {
    bool eq_operands = is_layout_sensitive
                           ? ShapeUtil::Equal(a->shape(), b->shape())
                           : ShapeUtil::Compatible(a->shape(), b->shape());
    return eq_operands;
  };

  auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
    return *a == *b;
  };

  if (instructions.empty()) {
    return false;
  }

  if (instructions[0]->IsCrossModuleAllReduce()) {
    return std::all_of(
        instructions.begin(), instructions.end(),
        [&](HloInstruction* instruction) {
          if (!instruction->IsCrossModuleAllReduce()) {
            return false;
          }
          auto old_channel_id = instruction->channel_id();
          instruction->set_channel_id(instructions[0]->channel_id());
          bool eq_instructions = instructions[0]->Identical(
              *instruction, eq_operand, eq_computations, is_layout_sensitive);
          instruction->set_channel_id(old_channel_id);
          return eq_instructions;
        });
  }

  return std::all_of(instructions.begin(), instructions.end(),
                     [&](HloInstruction* instruction) {
                       return instructions[0]->Identical(
                           *instruction, eq_operand, eq_computations,
                           is_layout_sensitive);
                     });
}

// Copy the boundary out of the conditional and update hoisted_boundaries.
void CopyOutOfConditional(
    Boundary& boundary, HloInstruction* conditional,
    absl::flat_hash_map<Boundary, Boundary>& hoisted_boundaries) {
  CHECK(boundary.IsInsideBranch());
  absl::InlinedVector<HloInstruction*, 4> new_operands;
  // All of the branch operands should have the same opcode and shape, so just
  // use branch 0.
  const HloInstruction* branch0_inst = boundary.operands()[0];
  for (int i = 0; i < branch0_inst->operands().size(); ++i) {
    Boundary operand_boundary(boundary.GetPosition());
    for (HloInstruction* operand : boundary.operands()) {
      operand_boundary.mutable_operands().push_back(operand->operands()[i]);
    }
    VLOG(2) << "Looking for: " << operand_boundary.ToString();
    auto hoisted_boundaries_it = hoisted_boundaries.find(operand_boundary);
    CHECK(hoisted_boundaries_it != hoisted_boundaries.end());
    Boundary hoisted_boundary = hoisted_boundaries_it->second;
    CHECK(hoisted_boundary.IsOutsideBranch());
    CHECK_EQ(hoisted_boundary.operands().size(), 1);
    new_operands.push_back(hoisted_boundary.operands()[0]);
  }
  HloInstruction* new_instruction = conditional->parent()->AddInstruction(
      branch0_inst->CloneWithNewOperands(branch0_inst->shape(), new_operands));
  VLOG(2) << "new instruction:" << new_instruction->ToString();
  // Maps the boundary inside the conditional to the hoisted instruction
  // outside of the conditional.
  Boundary hoisted_boundary(Boundary::Position::kOutsideBranch);
  hoisted_boundary.mutable_operands().push_back(new_instruction);
  hoisted_boundaries[boundary] = hoisted_boundary;
}

// Copy the boundary into the conditional and update hoisted_boundaries.
void CopyIntoConditional(
    Boundary& boundary, HloInstruction* conditional,
    absl::flat_hash_map<Boundary, Boundary>& hoisted_boundaries) {
  CHECK(boundary.IsOutsideBranch());
  CHECK_EQ(boundary.operands().size(), 1);
  int num_branches = conditional->branch_count();
  std::vector<absl::InlinedVector<HloInstruction*, 4>> new_operands(
      num_branches);
  HloInstruction* op = boundary.operands()[0];
  for (HloInstruction* operand : op->operands()) {
    Boundary operand_boundary(boundary.GetPosition());
    operand_boundary.mutable_operands().push_back(operand);
    VLOG(2) << "Looking for: " << operand_boundary.ToString();
    auto hoisted_boundaries_it = hoisted_boundaries.find(operand_boundary);
    if (hoisted_boundaries_it != hoisted_boundaries.end()) {
      Boundary hoisted_boundary = hoisted_boundaries_it->second;
      CHECK(hoisted_boundary.IsInsideBranch());
      CHECK_EQ(hoisted_boundary.operands().size(), num_branches);
      for (int j = 0; j < num_branches; ++j) {
        new_operands[j].push_back(hoisted_boundary.operands()[j]);
      }
    } else {
      for (int j = 0; j < num_branches; ++j) {
        switch (operand->opcode()) {
          case HloOpcode::kConstant: {
            auto new_operand =
                conditional->branch_computation(j)->AddInstruction(
                    operand->Clone());
            VLOG(2) << "new instruction:" << new_operand->ToString();
            new_operands[j].push_back(new_operand);
            break;
          }
          case HloOpcode::kGetTupleElement: {
            auto gte = Cast<HloGetTupleElementInstruction>(operand);
            int64_t index = gte->tuple_index();
            HloInstruction* root =
                conditional->branch_computation(j)->root_instruction();
            CHECK(root->opcode() == HloOpcode::kTuple &&
                  index < root->operand_count())
                << root->ToString() << " " << gte->ToString();
            auto new_operand = root->mutable_operand(index);
            VLOG(2) << "new instruction:" << new_operand->ToString();
            new_operands[j].push_back(new_operand);
            break;
          }
          default:
            LOG(FATAL) << "Unexpected out-of-boundary instruction:"
                       << operand->ToString() << "\n";
        }
      }
    }
  }

  Boundary hoisted_boundary(Boundary::Position::kInsideBranch);
  for (int j = 0; j < num_branches; ++j) {
    HloInstruction* new_instruction =
        conditional->branch_computation(j)->AddInstruction(
            op->CloneWithNewOperands(op->shape(), new_operands[j]));
    VLOG(2) << "new instruction:" << new_instruction->ToString();
    hoisted_boundary.mutable_operands().push_back(new_instruction);
  }
  hoisted_boundaries[boundary] = hoisted_boundary;
}

// Identify converts to be hoisted/rematerialized out of the branch
// computations.
absl::flat_hash_set<int64_t> FindSpecialConverts(HloInstruction* old_root,
                                                 int branch_count,
                                                 HloInstruction* conditional,
                                                 bool is_layout_sensitive) {
  absl::flat_hash_set<int64_t> special_convert;

  // TODO(b/216487727): Allow hoisting converts that feed or are fed by other
  // converts by addressing possible duplicates left behind in the tuple
  // output. The conditional code motion pass should handle these duplicates;
  // hence, merging these snippets of code would be one alternative.
  auto convert_invalid =
      [](const HloInstruction* convert_set_candidate) -> bool {
    bool invalid_user = absl::c_any_of(
        convert_set_candidate->users(), [](const HloInstruction* user) -> bool {
          return (user->opcode() == HloOpcode::kConvert);
        });
    bool invalid_producer =
        absl::c_any_of(convert_set_candidate->operands(),
                       [](const HloInstruction* operand) -> bool {
                         return (operand->opcode() == HloOpcode::kConvert);
                       });
    return (invalid_user || invalid_producer);
  };

  for (int64_t operand_num = 0; operand_num < old_root->operand_count();
       ++operand_num) {
    if (old_root->operand(operand_num)->opcode() != HloOpcode::kConvert) {
      continue;
    }
    bool replica = true;
    HloInstruction* special_convert_candidate =
        old_root->mutable_operand(operand_num);
    // TODO(b/216487727): Remove duplicates in tuple outputs while hoisting.
    auto repeated =
        absl::c_count_if(old_root->operands(),
                         [&](const HloInstruction* operand) -> bool {
                           return (special_convert_candidate == operand);
                         }) > 1;
    if (convert_invalid(special_convert_candidate) || repeated) {
      continue;
    }
    // Check whether an identical candidate appears in the other branches.
    for (int others = 1; others < branch_count; ++others) {
      HloInstruction* others_root =
          conditional->branch_computation(others)->root_instruction();
      const HloInstruction* other_convert = others_root->operand(operand_num);
      if (other_convert->opcode() != HloOpcode::kConvert ||
          convert_invalid(other_convert)) {
        replica = false;
        break;
      }
      // Do not move converts if their operands have different shapes in
      // different branches.
      bool eq_shape =
          is_layout_sensitive
              ? ShapeUtil::Equal(other_convert->shape(),
                                 special_convert_candidate->shape()) &&
                    ShapeUtil::Equal(
                        other_convert->operand(0)->shape(),
                        special_convert_candidate->operand(0)->shape())
              : ShapeUtil::Compatible(other_convert->shape(),
                                      special_convert_candidate->shape()) &&
                    ShapeUtil::Compatible(
                        other_convert->operand(0)->shape(),
                        special_convert_candidate->operand(0)->shape());
      if (!eq_shape) {
        replica = false;
        break;
      }
      auto repeated =
          absl::c_count_if(others_root->operands(),
                           [&](const HloInstruction* operand) -> bool {
                             return (special_convert_candidate == operand);
                           }) > 1;
      if (repeated) {
        replica = false;
        break;
      }
    }
    if (replica) {
      special_convert.insert(operand_num);
    }
  }
  return special_convert;
}

// Restructure the conditional instruction as follows:
// %result = conditional() becomes
// x = conditional()
// y.{0..n} = gte(x, {0..n})
// z = tuple(y.0, y.1, ...y.n)
// Doing so ensures that we can accommodate the possible shape change of the
// conditional when the instructions are hoisted.
Status RestructureConditionalInstruction(HloComputation* computation,
                                         HloInstruction* conditional) {
  HloInstruction* old_root = computation->root_instruction();
  std::vector<HloInstruction*> new_operands;
  int cur_index = 0;
  for (; cur_index < ShapeUtil::TupleElementCount(conditional->shape());
       ++cur_index) {
    new_operands.push_back(
        computation->AddInstruction(HloInstruction::CreateGetTupleElement(
            ShapeUtil::GetTupleElementShape(conditional->shape(), cur_index),
            conditional, cur_index)));
  }
  HloInstruction* new_tuple =
      computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
  if (old_root == conditional) {
    computation->set_root_instruction(new_tuple);
  } else {
    std::vector<HloInstruction*> new_tuple_users;
    for (auto conditional_user : conditional->users()) {
      auto is_new_gte = absl::c_find_if(
          new_operands,
          [&](HloInstruction* instr) { return instr == conditional_user; });
      if (is_new_gte == new_operands.end()) {
        new_tuple_users.push_back(conditional_user);
      }
    }
    for (auto new_tuple_user : new_tuple_users) {
      TF_RETURN_IF_ERROR(
          conditional->ReplaceUseWith(new_tuple_user, new_tuple));
    }
  }
  VLOG(2) << "computation after root restructure:\n" << computation->ToString();
  return OkStatus();
}

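// Hoists the special converts identified by FindSpecialConverts out of all
// branches of `conditional`, restructuring the branch roots and replacing
// the conditional itself as needed. Returns true iff the module was changed.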
StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
                                  bool is_layout_sensitive) {
  int branch_count = conditional->branch_count();
  if (branch_count <= 0) {
    return false;
  }

  // Determine whether all branch roots are tuples.
  for (int branch_num = 0; branch_num < branch_count; ++branch_num) {
    HloInstruction* branch_root =
        conditional->branch_computation(branch_num)->root_instruction();
    if (branch_root->opcode() != HloOpcode::kTuple) {
      return false;
    }
  }

  HloInstruction* old_root =
      conditional->branch_computation(0)->root_instruction();
  VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
  // Identify the gte using `index`.
  auto find_gte = [](const HloInstruction* conditional_result,
                     int64_t index) -> HloInstruction* {
    for (HloInstruction* instr : conditional_result->users()) {
      if (instr->opcode() != HloOpcode::kGetTupleElement) {
        return nullptr;
      }
      if (instr->tuple_index() == index) {
        return instr;
      }
    }
    return nullptr;
  };

  // Captures tuple indices referring to converts to be
  // rematerialized/hoisted.
  absl::flat_hash_set<int64_t> special_convert = FindSpecialConverts(
      old_root, branch_count, conditional, is_layout_sensitive);

  // Exit if we cannot find any converts to be hoisted.
  if (special_convert.empty()) {
    return false;
  }

  TF_RETURN_IF_ERROR(
      RestructureConditionalInstruction(conditional->parent(), conditional));

  for (int branch = 0; branch < branch_count; branch++) {
    old_root = conditional->branch_computation(branch)->root_instruction();
    absl::flat_hash_map<HloInstruction*, int64_t> map_inst_to_tuple_index;
    std::vector<HloInstruction*> new_operands(old_root->operand_count());
    absl::flat_hash_set<HloInstruction*> to_hoist_set;

    for (int64_t operand_num = 0; operand_num < old_root->operand_count();
         ++operand_num) {
      map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] =
          operand_num;
    }
    for (int64_t operand_num = 0; operand_num < old_root->operand_count();
         ++operand_num) {
      HloInstruction* hoist = old_root->mutable_operand(operand_num);
      if (!special_convert.contains(operand_num)) {
        new_operands[operand_num] = old_root->mutable_operand(operand_num);
        continue;
      }

      to_hoist_set.insert(hoist);
      int64_t new_tuple_count = old_root->operand_count();

      // Replace the hoisted instr in the tuple with its operand(s).
      // We will replace at least one of the operands of the hoist at the
      // tuple place; the rest will be added at the end.
      bool inplace = true;
      CHECK(!hoist->operands().empty());
      for (HloInstruction* prod : hoist->operands()) {
        if (inplace) {
          map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist];
          new_operands[map_inst_to_tuple_index[hoist]] = prod;
          inplace = false;
        } else {
          map_inst_to_tuple_index[prod] = new_tuple_count++;
          new_operands.push_back(prod);
        }
      }
    }

    // Create the new root instruction.
    HloComputation* cur_branch = conditional->branch_computation(branch);
    HloInstruction* new_branch_root =
        cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands));
    // The shape can vary since the operands of the convert are now
    // being returned through the branches' root.
    cur_branch->set_root_instruction(new_branch_root, true /*new shape*/);
    TF_CHECK_OK(cur_branch->RemoveInstruction(old_root));

    // Only one of the branches needs to change the conditional->parent().
    if (branch != 0) {
      continue;
    }
    HloComputation* conditional_parent = conditional->parent();
    HloInstruction* newconditional =
        conditional_parent->AddInstruction(HloInstruction::CreateConditional(
            cur_branch->root_instruction()->shape(),
            conditional->mutable_operand(0),
            absl::MakeSpan(conditional->branch_computations()),
            absl::MakeSpan(conditional->operands()).subspan(1)));
    // Ensure that all the users of conditional refer to the new one.
    TF_RETURN_IF_ERROR(
        conditional->ReplaceAllUsesWithDifferentShape(newconditional));
    TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional));
    conditional = newconditional;
    // Add the hoisted instructions in the parent.
    for (HloInstruction* hoist : to_hoist_set) {
      VLOG(2) << "Hoisting instruction:" << hoist->ToString();
      int64_t hoist_index = map_inst_to_tuple_index[hoist];
      // Find the gte that captured the hoisted instr's result.
      HloInstruction* gte_hoist = find_gte(conditional, hoist_index);
      CHECK(gte_hoist != nullptr);
      std::vector<HloInstruction*> new_operands;
      for (HloInstruction* op : hoist->operands()) {
        HloInstruction* gte = conditional_parent->AddInstruction(
            HloInstruction::CreateGetTupleElement(op->shape(), conditional,
                                                  map_inst_to_tuple_index[op]));
        new_operands.push_back(gte);
      }
      HloInstruction* hoisted = conditional_parent->AddInstruction(
          hoist->CloneWithNewOperands(hoist->shape(), new_operands));
      VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString();
      TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted));
      TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist));
    }
    // No need to explicitly delete a hoisted instruction since, if it is
    // dead, the subsequent DCE will remove it.
  }
  VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString();
  return true;
}

// Hoist identical ops out of the conditional. Identical means that the
// shapes of the operands are identical and their properties are identical.
// Starts from the root instruction of each branch and collects the identical
// ops to hoist.
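// For example (a sketch, not actual HLO syntax): if every branch ends in
//   ROOT tuple(add(x, y))
// the add can be hoisted so that each branch returns tuple(x, y) and a single
// add(gte(cond, 0), gte(cond, 1)) is emitted in the parent computation.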
StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
    HloInstruction* conditional, std::vector<Boundary>& to_move_out,
    std::vector<Boundary>& new_boundaries) {
  if (to_move_out.empty()) {
    return false;
  }
  VLOG(1) << "Modifying code--number of boundaries to move out:"
          << to_move_out.size() << "\n";
  HloComputation* conditional_parent = conditional->parent();
  // Save the old users before adding new conditional user instructions.
  std::vector<HloInstruction*> old_conditional_users = conditional->users();
  // Maps boundaries in the conditional body to boundaries hoisted outside
  // the conditional that compute the same value.
  absl::flat_hash_map<Boundary, Boundary> hoisted_boundaries;
  // Insert GetTupleElement before the instructions whose operands might still
  // be within the conditional.
  VLOG(1) << "before opt:"
          << conditional_parent->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  int64_t op_index = 0;
  for (const Boundary& b : new_boundaries) {
    HloInstruction* op = b.operands()[0];
    CHECK(op != nullptr);
    VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
    HloInstruction* gtr = conditional_parent->AddInstruction(
        HloInstruction::CreateGetTupleElement(op->shape(), conditional,
                                              op_index++));
    Boundary b2(Boundary::Position::kOutsideBranch);
    b2.mutable_operands().push_back(gtr);
    hoisted_boundaries[b] = b2;
  }
  // Copy boundary instructions out of the conditional.
  // Visit operands before their users and copy them, so that each copied
  // user will point to the correct operand.
  for (int64_t i = to_move_out.size() - 1; i >= 0; i--) {
    CopyOutOfConditional(to_move_out[i], conditional, hoisted_boundaries);
  }
  VLOG(2) << "Done copying branch instructions out\n"
          << conditional_parent->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  // Change original users of the conditional to use the correct operands.
  for (auto user_instr : old_conditional_users) {
    VLOG(2) << "Checking conditional user: " << user_instr->ToString() << "\n";
    CHECK(user_instr->opcode() == HloOpcode::kGetTupleElement);
    auto tuple_opd = static_cast<HloGetTupleElementInstruction*>(user_instr);
    int64_t index = tuple_opd->tuple_index();
    Boundary old_user_boundary(Boundary::Position::kInsideBranch);
    for (const HloComputation* called_computation :
         conditional->called_computations()) {
      HloInstruction* root = called_computation->root_instruction();
      CHECK(root->operands().size() > index);
      old_user_boundary.mutable_operands().push_back(root->operands()[index]);
    }
    CHECK(ContainsKey(hoisted_boundaries, old_user_boundary));
    HloInstruction* new_opd =
        hoisted_boundaries[old_user_boundary].operands()[0];
    CHECK(new_opd != nullptr);
    VLOG(2) << "Try replace all uses of :" << old_user_boundary.ToString()
            << "\n";
    TF_RETURN_IF_ERROR(user_instr->ReplaceAllUsesWith(new_opd));
    TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(user_instr));
  }
  VLOG(2) << "Done changing conditional users\n"
          << conditional_parent->ToString() << "\n";
  // Create a tuple of elements within each branch and set it as root.
  int64_t branch_count = conditional->branch_count();
  for (int i = 0; i < branch_count; i++) {
    auto computation = conditional->branch_computation(i);
    std::vector<HloInstruction*> elements;
    for (const auto& b1 : new_boundaries) {
      HloInstruction* op = b1.operands()[i];
      CHECK(op != nullptr);
      VLOG(2) << "Adding to root " << i << " with " << op->ToString() << "\n";
      elements.push_back(op);
    }
    HloInstruction* tuple =
        computation->AddInstruction(HloInstruction::CreateTuple(elements));
    computation->set_root_instruction(tuple, true);
    VLOG(2) << "computation is :" << computation->ToString() << "\n";
    // Remove hoisted instructions from the branches.
    for (const auto& b2 : to_move_out) {
      auto instr_to_remove = b2.operands()[i];
      // Double-check to make sure it is safe to delete the instruction.
      // Complications may arise due to some operations in the alternative
      // branches (branches 1..n) being placed into the boundaries multiple
      // times.
      if (!computation->IsMarkedAsDead(instr_to_remove) &&
          instr_to_remove->IsDead()) {
        VLOG(2) << "Removing boundary:" << b2.ToString() << "\n";
        TF_RETURN_IF_ERROR(computation->RemoveInstruction(instr_to_remove));
      }
    }
  }
  // Change the conditional instruction shape to the shape of the new root.
  HloInstruction* new_root =
      conditional->branch_computation(0)->root_instruction();
  *conditional->mutable_shape() = new_root->shape();
  // Keep the conditional instruction's sharding consistent with the branches.
  // Note that this sharding could be lost after this pass.
  conditional->set_sharding(new_root->sharding_ptr());
  VLOG(1) << "done moving instructions out of branches\n"
          << conditional_parent->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  return true;
}

// Hoist ops from outside of the conditional into the branches.
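// For example (a sketch, not actual HLO syntax): an op consuming the
// conditional's result, r = op(gte(cond, i)), can be moved in by computing
// op(...) at the end of every branch and redirecting the outside uses of r
// to a get-tuple-element of the reshaped conditional.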
StatusOr<bool> ConditionalCodeMotion::MoveInstructionIn(
    HloInstruction* conditional, std::vector<Boundary>& to_move_in,
    std::vector<Boundary>& new_boundaries) {
  if (to_move_in.empty()) {
    return false;
  }
  VLOG(1) << "Modifying code---number of boundaries to move in:"
          << to_move_in.size() << "\n";
  VLOG(1) << "before opt:"
          << conditional->parent()->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  // Maps boundaries to be moved to their new representations.
  absl::flat_hash_map<Boundary, Boundary> hoisted_boundaries;
  int64_t to_move_in_size = to_move_in.size();
  int64_t branch_count = conditional->branch_count();
  HloGetTupleElementInstruction* tuple_use =
      DynCast<HloGetTupleElementInstruction>(to_move_in[0].operands()[0]);
  // If use_index is -1, the old conditional root entry used by the to_move_in
  // instructions still needs to be included as an entry of the modified
  // conditional root, and the new result of the to_move_in instructions
  // needs to be added as an extra entry of the modified root; otherwise, the
  // old root entry will be replaced with the new result in the modified root.
  // The entry replacement should be allowed only if tuple_use has <= 1 users.
  int64_t use_index = (tuple_use != nullptr && tuple_use->user_count() == 1)
                          ? tuple_use->tuple_index()
                          : -1;
  VLOG(2) << "Tuple use index = " << use_index << "\n";
  // Number of old conditional entries still to be used outside.
  // If the conditional shape is not a tuple, a tuple will be created, with
  // subscript 0 holding the old operand being used.
  int64_t op_index =
      conditional->shape().IsTuple()
          ? ((use_index >= 0) ? conditional->shape().tuple_shapes_size() - 1
                              : conditional->shape().tuple_shapes_size())
          : 0;
  // Used to map the tuple_use instruction to its operand.
  Boundary b_opd_use(Boundary::Position::kInsideBranch);
  Boundary b_old_root(Boundary::Position::kInsideBranch);
  // Create a new root instruction in each branch.
  for (int i = 0; i < branch_count; i++) {
    auto computation = conditional->branch_computation(i);
    auto old_root = computation->root_instruction();
    b_old_root.mutable_operands().push_back(old_root);
    std::vector<HloInstruction*> operands;
    if (old_root->opcode() == HloOpcode::kTuple) {
      // Use the operands of old_root directly, so old_root can be removed
      // later.
      for (int i = 0; i < old_root->operand_count(); ++i) {
        if (i != use_index) {
          operands.push_back(old_root->operands()[i]);
        } else {  // Map the conditional use to the tuple operand.
          b_opd_use.mutable_operands().push_back(old_root->operands()[i]);
        }
      }
    } else if (old_root->shape().IsTuple()) {
      // If old_root is not a kTuple but has a tuple shape, elements within the
      // tuple must be extracted first to be used by the new instructions.
      const Shape& old_shape = old_root->shape();
      for (int i = 0; i < old_shape.tuple_shapes_size(); ++i) {
        auto element =
            computation->AddInstruction(HloInstruction::CreateGetTupleElement(
                old_shape.tuple_shapes(i), old_root, i));
        if (i != use_index) {
          operands.push_back(element);
        } else {
          b_opd_use.mutable_operands().push_back(element);
        }
      }
    } else {
      // If old_root is not a tuple and does not have a tuple shape, use it
      // to replace the conditional directly in the new computation.
      b_opd_use.mutable_operands().push_back(conditional);
    }

    HloInstruction* new_root =
        computation->AddInstruction(HloInstruction::CreateTuple(operands));
    VLOG(2) << "setting new root: " << new_root->ToString() << "\n";
    computation->set_root_instruction(new_root,
                                      /*accept_different_shape*/ true);
    if (old_root->opcode() == HloOpcode::kTuple) {
      TF_RETURN_IF_ERROR(computation->RemoveInstruction(old_root));
    }
    VLOG(2) << "new branch computation: " << computation->ToString() << "\n";
  }
  // Update the get-tuple-element indices of the conditional's users.
  if (use_index != -1) {
    for (auto* user : conditional->users()) {
      if (user->opcode() == HloOpcode::kGetTupleElement &&
          user->tuple_index() > use_index) {
        user->set_tuple_index(user->tuple_index() - 1);
      }
    }
  }
  Boundary conditional_boundary(Boundary::Position::kOutsideBranch);
  conditional_boundary.mutable_operands().push_back(conditional);
  hoisted_boundaries[conditional_boundary] = b_old_root;
  int64_t cp_start = 0;
  if (use_index >= 0) {
    VLOG(2) << "Mapping GTE: " << tuple_use->ToString() << "\n";
    Boundary tuple_use_boundary(Boundary::Position::kOutsideBranch);
    tuple_use_boundary.mutable_operands().push_back(tuple_use);
    hoisted_boundaries[tuple_use_boundary] = b_opd_use;
  }
  cp_start = (tuple_use != nullptr) ? 1 : 0;
  for (int64_t to_move_index = cp_start; to_move_index < to_move_in_size;
       to_move_index++) {
    Boundary b_to_move = to_move_in[to_move_index];
    HloInstruction* op = b_to_move.operands()[0];
    CHECK(op != nullptr);
    bool to_be_used_outside = true;
    VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
    if (to_move_index < to_move_in_size - 1 && op->user_count() == 1 &&
        op->users()[0] == to_move_in[to_move_index + 1].operands()[0]) {
      to_be_used_outside = false;
      VLOG(2) << "Instruction is not to be used outside the branch\n";
    }
    Boundary b(Boundary::Position::kInsideBranch);
    CopyIntoConditional(b_to_move, conditional, hoisted_boundaries);
    if (to_be_used_outside) {
      for (int i = 0; i < branch_count; ++i) {
        auto computation = conditional->branch_computation(i);
        auto new_op = hoisted_boundaries[b_to_move].operands()[i];
        auto new_root = computation->root_instruction();
        new_root->AppendOperand(new_op);
        *new_root->mutable_shape()->add_tuple_shapes() = new_op->shape();
        VLOG(2) << "Extending conditional root " << i << " : "
                << new_root->ToString() << "\n";
      }
      // Modify uses of instructions outside of the conditionals.
      HloInstruction* gtr = conditional->parent()->AddInstruction(
          HloInstruction::CreateGetTupleElement(op->shape(), conditional,
                                                op_index++));
      TF_RETURN_IF_ERROR(op->ReplaceAllUsesWith(gtr));
      if (conditional->parent()->root_instruction() == op) {
        conditional->parent()->set_root_instruction(gtr);
      }
    }
  }
  VLOG(2) << "Done copying instructions inside branch: "
          << conditional->ToString(HloPrintOptions::Fingerprint()) << "\n";
  // Change the conditional instruction shape to the shape of the new root.
  HloInstruction* new_root =
      conditional->branch_computation(0)->root_instruction();
  *conditional->mutable_shape() = new_root->shape();
  // Keep the conditional instruction's sharding consistent with the branches.
  // Note that this sharding could be lost after this pass.
  conditional->set_sharding(new_root->sharding_ptr());
  VLOG(2) << "Before removing instructions:"
          << conditional->parent()->ToString() << "\n";
  // Remove hoisted instructions from the branches.
  for (int64_t i = to_move_in_size - 1; i >= 0; i--) {
    Boundary boundary_to_move_in = to_move_in[i];
    HloInstruction* op = boundary_to_move_in.operands()[0];
    if (op->user_count() == 0) {
      VLOG(2) << "Removing boundary:" << boundary_to_move_in.ToString() << "\n";
      TF_RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op));
      VLOG(2) << "Done removing boundary.\n";
    }
  }

  // Reset the shapes of user gtes to the new shape.
  if (use_index != -1) {
    for (auto* user : conditional->users()) {
      if (user->opcode() == HloOpcode::kGetTupleElement) {
        VLOG(2) << "Resetting shape of user: " << user->ToString() << "\n";
        *user->mutable_shape() =
            conditional->shape().tuple_shapes(user->tuple_index());
      }
    }
  }
  VLOG(1) << "Done moving instructions inside branches\n"
          << conditional->parent()->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  return true;
}

897 
898 // Group single chains of operands or uses of boundaries into new boundaries
899 class GroupConnectedBoundaries {
900  private:
901   std::vector<Boundary> connected_boundaries_, new_boundaries_;
902   HloInstruction* conditional_;
903   HloComputation* conditional_parent_;
904   bool is_layout_sensitive_;
905   // Instructions that have been visited but are not going to be moved.
906   absl::flat_hash_map<HloInstruction*, int>& visited_count_;
907   // The following four lines are configurations of the cost model, which will
908   // be used to determine whether to move an instruction (move_config_) and how
909   // strongly preferred it is to keep a pair of ops together (reuse_config_).
910   // The search_config_ is used to control how to navigate the search space of
911   // the cost model in the context of auto/manual tuning. The flipped array is
912   // used to save which entries in the configuration have been changed in the
913   // search/tuning process.
914   std::vector<std::vector<int64_t>>& move_config_;
915   std::vector<std::vector<int64_t>>& reuse_config_;
916   absl::Span<int64_t> search_config_vec_;
917   int64_t& search_config_;
918   int64_t search_subscript_;
919   absl::flat_hash_map<const int64_t*, int64_t> flipped_;
920 
921   // The FlipMutation function serves to implement the search of alternative
922   // cost models by deciding whether to flip a given configuration, saved in
923   // the loc parameter. The non_zero parameter provides the new value to use
924   // to flip a zero. The msg parameter is only used for debugging purpposes.
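  // As the comments below indicate, search_config_ appears to pack three
  // fields, decoded by ConditionalCodeMotion::flip_start, DecrementMaxFlip,
  // and flip_stride: the delay before the first flip, the maximum number of
  // flips, and the stride between flips, respectively.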
  int64_t FlipMutation(int64_t* loc, const int64_t non_zero,
                       const std::string& msg) {
    if (search_config_ == 0 || ContainsKey(flipped_, loc)) {
      VLOG(2) << "Configured not to search or loc is already flipped.";
      return *loc;
    }
    // The last 8 digits control when to start the first flip.
    int c = ConditionalCodeMotion::flip_start(search_config_);
    VLOG(2) << "flip start index = " << c << "\n";
    // Only flip the decision if c reaches 0.
    if (c > 0) {
      search_config_--;
      return *loc;
    }
    // Digits 8-16 control the maximum number of times to flip a config.
    auto flip_count = ConditionalCodeMotion::DecrementMaxFlip(&search_config_);
    VLOG(2) << "max flip count = " << flip_count << "\n";
    VLOG(2) << "Updating max Flipping configuration = " << search_config_
            << "\n";
    if (flip_count == 0) {
      VLOG(2) << "Maximum flip count has been reached. ";
      if (search_subscript_ + 1 < search_config_vec_.size()) {
        VLOG(2) << "search_subscript_ = " << search_subscript_;
        VLOG(2) << "search config vec size = " << search_config_vec_.size();
        search_config_ = search_config_vec_[++search_subscript_];
      } else {
        return *loc;
      }
    }
    // Reload digits 16-23 of the configuration, which control how
    // frequently a configuration should be flipped.
    auto flip_stride = ConditionalCodeMotion::flip_stride(search_config_);
    search_config_ += flip_stride;
    VLOG(2) << "flip stride = " << flip_stride << "\n";
    VLOG(2) << "Updating Flipping Stride = " << search_config_ << "\n";

    flipped_[loc] = *loc;
    // Copy the last 8 bits back to the first 8 bits of the configuration.
    switch (*loc) {
      case 0:
        *loc = non_zero;
        break;
      default:
        *loc = 0;
        break;
    }
    VLOG(2) << "Flipping decision for: " << msg << ": from " << flipped_[loc]
            << " to " << *loc << "\n";
    return *loc;
  }

  static std::vector<int64_t>& EnsureSearchConfig(
      std::vector<int64_t>& search_config) {
    if (search_config.empty()) {
      search_config.push_back(0);
    }
    return search_config;
  }

 public:
  explicit GroupConnectedBoundaries(
      HloInstruction* conditional, bool is_layout_sensitive,
      absl::flat_hash_map<HloInstruction*, int>& visited_count,
      std::vector<std::vector<int64_t>>* move_config,
      std::vector<std::vector<int64_t>>* reuse_config,
      std::vector<int64_t>& search_config)
      : conditional_(conditional),
        conditional_parent_(conditional->parent()),
        is_layout_sensitive_(is_layout_sensitive),
        visited_count_(visited_count),
        move_config_(*move_config),
        reuse_config_(*reuse_config),
        search_config_vec_(EnsureSearchConfig(search_config)),
        search_config_(search_config_vec_.front()),
        search_subscript_(0) {
    VLOG(2) << "Initializing Group Connected Boundaries\n";
  }
  // Returns an estimate of the potential reuses carried by a given pair of
  // instructions. Uses different integers to classify different levels
  // of reuse. Assumes all instructions can be fused to enable data reuse.
  int64_t ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
    std::vector<int64_t>& curconfig =
        reuse_config_[static_cast<uint32_t>(op->opcode())];
    // Flip the reuse configuration if tuning the cost model.
    // When flipping, use -10 if flipping to the default reuse model. Other
    // values can be specified if needed to fine-control the decision making.
    int64_t config =
        (search_config_ < 0)
            ? FlipMutation(&curconfig[static_cast<uint32_t>(user->opcode())],
                           -10,
                           HloOpcodeString(op->opcode()) + "->" +
                               HloOpcodeString(user->opcode()))
            : curconfig[static_cast<uint32_t>(user->opcode())];
    VLOG(2) << "ConditionalCodeMotion: Add reuses carried by instr: "
            << op->ToString() << "=>" << user->ToString() << " : " << config
            << "\n";
    if (config < 0) {
      // Assume the reuse decreases with increasing user count.
      int count1 = CountNonLeafOps(op->users());
      int count2 = CountNonLeafOps(user->operands());
      return (-config) / count1 / count2;
    }
    return config;
  }
  void clear_recently_visited() {
    for (const auto& boundary : new_boundaries_) {
      visited_count_.erase(boundary.operands()[0]);
    }
  }
  // Returns true if `instruction` is worth hoisting.
  bool WorthHoisting(HloInstruction* instruction, bool is_inside_branch) {
    // This is needed for the "moving-in" transformation, to prevent the root
    // of the parent computation (which contains the conditional) from being
    // moved inside the conditional.
    HloOpcode opcode = instruction->opcode();
    if (opcode == HloOpcode::kTuple &&
        instruction == conditional_parent_->root_instruction()) {
      return false;
    }
    // It is not safe to move collective ops from outside to inside
    // conditional branches, as it may cause synchronization problems when
    // different layouts are assigned to different branches.
    if (DynCast<HloCollectiveInstruction>(instruction) && !is_inside_branch) {
      return false;
    }

    // It is not legal to move the parameter instructions.
    if (opcode == HloOpcode::kParameter) {
      return false;
    }

    // Use the configuration given from outside (e.g., by the autotuner).
    std::vector<int64_t>& curconfig =
        move_config_[static_cast<uint32_t>(opcode)];
    auto col = (curconfig.size() == 1) ? 0
               : (instruction->operand_count() > 0)
                   ? static_cast<uint32_t>(instruction->operand(0)->opcode())
                   : 0;
    VLOG(2) << "column = " << col << "\n";
    VLOG(2) << "config size = " << curconfig.size() << "\n";
    VLOG(2) << "search_config = " << search_config_ << "\n";
    CHECK(col < curconfig.size());
    uint32_t config = (search_config_ > 0)
                          ? FlipMutation(&curconfig[col], 1,
                                         "Move-" + HloOpcodeString(opcode))
                          : curconfig[col];
    VLOG(2) << "Checking instruction is worth moving: " << config << "\n";
    VLOG(2) << "after checking search_config = " << search_config_ << "\n";
    return (config != 0);
  }

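  // Sums the estimated reuses between `user` and each of its operands that
  // either stay behind the boundary (i.e., remain in visited_count_) or are
  // the conditional itself.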
  int64_t ReusesBeforeBoundary(HloInstruction* user) {
    int64_t reuses = 0;
    for (auto op : user->operands()) {
      // The operand must be an instruction that is not going to be moved (if
      // user is inside the conditional); otherwise it must be the conditional
      // itself and its user must be outside of the conditional.
      if (!ContainsKey(visited_count_, op) && op != conditional_) {
        continue;
      }
      if (auto tuple_gte = DynCast<HloGetTupleElementInstruction>(user)) {
        if (op->opcode() == HloOpcode::kConditional) {
          auto tuple = op->branch_computation(0)->root_instruction();
          if (tuple->opcode() == HloOpcode::kTuple) {
            auto index = tuple_gte->tuple_index();
            CHECK(index < tuple->operand_count());
            op = tuple->mutable_operand(index);
          }
        }
        reuses += ReusesCarriedBy(op, user->users()[0]);
      } else {
        reuses += ReusesCarriedBy(op, user);
      }
    }
    VLOG(2) << "Reuses before instruction " << user->ToString() << ":" << reuses
            << "\n";
    return reuses;
  }

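  // Estimates the reuses between `user` and its unique consumer, following
  // the use through the conditional's root tuple and the matching
  // get-tuple-element user when necessary; returns 0 if `user` has multiple
  // consumers.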
  int64_t ReusesAfterBoundary(HloInstruction* user) {
    CHECK(user != nullptr);
    auto all_users = user->users();
    // For now, assume that if an instruction has multiple consumers, it
    // will not be reused, as the reuse may require duplication in
    // fusion and so is expensive. If the situation changes in the future,
    // some aspects of the overall algorithm need to be redesigned to
    // accommodate the change.
    if (all_users.size() > 1) {
      VLOG(2) << "Having multiple users from: " << user->ToString() << "\n";
      return 0;
    }
    if (!all_users.empty()) {
      auto op = all_users[0];
      int64_t reuses = 0;
      // Only count reuses that run through the conditional root.
      if (op == conditional_->branch_computation(0)->root_instruction()) {
        int64_t index = op->operand_index(user);
        for (auto op2 : conditional_->users()) {
          // If the use is not a get-tuple-element, do not consider it for
          // now.
          if (op2->opcode() == HloOpcode::kGetTupleElement) {
            auto tuple_opd = static_cast<HloGetTupleElementInstruction*>(op2);
            if (index == tuple_opd->tuple_index()) {
              all_users = op2->users();
              if (!all_users.empty()) {
                reuses += ReusesCarriedBy(user, all_users[0]);
                break;
              }
            }
          }
        }
      } else if (ContainsKey(visited_count_, op)) {
        reuses += ReusesCarriedBy(user, op);
      }
      VLOG(2) << "reuses after instruction " << user->ToString() << ":"
              << reuses << "\n";
      return reuses;
    }
    return 0;
  }

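  // Estimates the net benefit of moving the given group of boundaries:
  // positive if moving is expected to increase data reuse or fold copies,
  // negative if the group should stay where it is.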
BenefitForMovingBoundaries(const std::vector<Boundary> & boundaries,bool perform_reuse_analysis=true)1145   int64_t BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries,
1146                                      bool perform_reuse_analysis = true) {
1147     int64_t reuses_before = 0, reuses_after = 0;
1148     if (boundaries.size() == 1) {
1149       if (boundaries[0].IsOutsideBranch() &&
1150           boundaries[0].operands()[0]->opcode() ==
1151               HloOpcode::kGetTupleElement) {
1152         // The only boundary of moving-in is the get_tuple_element op.
1153         return -1;
1154       }
1155       if (boundaries[0].IsInsideBranch() &&
1156           boundaries[0].operands()[0]->opcode() == HloOpcode::kTuple) {
1157         // The only boundary of moving-out is the tuple op inside branches.
1158         return -1;
1159       }
1160     }
1161     // If trying alternative moving configurations, turn off reuse analysis.
1162     if (!perform_reuse_analysis) {
1163       return 1;
1164     }
1165     // For cases like :
1166     // branch0 {
1167     //   ROOT copy
1168     // }
1169     // branch1 {
1170     //   ...
1171     // }
1172     // cond = conditional(branch0, branch1)
1173     // copy = copy(cond)
1174     //
1175     // We can fold the two copies thus reducing computation.
1176     auto get_copy_folding_benefit = [&](HloInstruction* hlo) -> int64_t {
1177       if (hlo->opcode() != HloOpcode::kCopy) {
1178         return 0;
1179       }
1180       const HloGetTupleElementInstruction* gte =
1181           DynCast<HloGetTupleElementInstruction>(hlo->operand(0));
1182       if (gte == nullptr) {
1183         return 0;
1184       }
1185       const HloInstruction* conditional = gte->operand(0);
1186       if (conditional != conditional_) {
1187         return 0;
1188       }
1189       int64_t benefit = 0;
1190       for (auto* branch : conditional->called_computations()) {
1191         HloInstruction* root = branch->root_instruction();
1192         if (root->opcode() == HloOpcode::kTuple) {
1193           const auto* tuple_operand = root->operand(gte->tuple_index());
1194           if (tuple_operand->opcode() == HloOpcode::kCopy) {
1195             if (Shape::Equal()(tuple_operand->operand(0)->shape(),
1196                                hlo->shape())) {
1197               benefit += 10;
1198             }
1199           }
1200         }
1201       }
1202       return benefit;
1203     };
1204     for (const Boundary& b : boundaries) {
1205       auto op = b.operands()[0];
1206       if (op == conditional_->branch_computation(0)->root_instruction()) {
1207         continue;
1208       }
1209       VLOG(2) << "Benefit for " << op->ToString();
1210       reuses_before += ReusesBeforeBoundary(op);
1211       VLOG(2) << "Reuses before boundary so far: " << reuses_before << "\n";
1212       reuses_after += ReusesAfterBoundary(op);
1213       VLOG(2) << "Reuese after boundary so far : " << reuses_after << "\n";
1214     }
1215 
1216     int64_t copy_folding_benefit = 0;
1217     if (boundaries[0].IsOutsideBranch()) {
1218       for (const Boundary& b : boundaries) {
1219         auto op = b.operands()[0];
1220         copy_folding_benefit += get_copy_folding_benefit(op);
1221       }
1222     }
1223     VLOG(2) << "Copy folding benefit: " << copy_folding_benefit;
1224 
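    // Moving out of a branch pays off when the reuses after the boundary
    // outweigh those before it; moving into a branch is the mirror image,
    // biased by -1 to prefer staying put on ties, plus any copy-folding
    // benefit.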
1225     if (reuses_after == 0 && reuses_before == 0 && copy_folding_benefit == 0) {
1226       return -1;
1227     } else if (boundaries[0].IsInsideBranch()) {
1228       return reuses_after - reuses_before;
1229     } else {
1230       return reuses_before - reuses_after - 1 + copy_folding_benefit;
1231     }
1232   }
1233 
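  // Returns the boundary reached by stepping from each instruction in b to
  // its op_index-th operand (when inside a branch) or user (when outside).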
1234   Boundary GetNextBoundary(const Boundary& b, int64_t op_index) {
1235     Boundary b2(b.GetPosition());
1236     for (int j = 0; j < b.operands().size(); ++j) {
1237       HloInstruction* inst = b.operands()[j];
1238       CHECK(inst != nullptr);
1239       HloInstruction* op = (b.IsInsideBranch()) ? inst->operands()[op_index]
1240                                                 : inst->users()[op_index];
1241       CHECK(op != nullptr);
1242       b2.mutable_operands().push_back(op);
1243     }
1244     return b2;
1245   }
1246 
1247   // Checks whether it is safe to move a boundary when it is visited through
1248   // a dependent that has already been considered for moving.
1249   bool IsSafeToMoveBoundary(const Boundary& next_boundary) {
1250     int64_t next_boundary_count =
1251         (next_boundary.IsInsideBranch())
1252             ? next_boundary.operands()[0]->user_count()
1253             : CountNonLeafOps(next_boundary.operands()[0]->operands());
1254     if (next_boundary_count <= 1) {
1255       // If the boundary has at most one dependent, it is safe to move.
1256       return true;
1257     } else {
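      // A boundary with multiple dependents is parked on first visit while a
      // visit count is maintained; it becomes movable only once all of its
      // dependents have been visited through movable paths.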
1258       if (!ContainsKey(visited_count_, next_boundary.operands()[0])) {
1259         VLOG(2) << "Skip next boundary " << next_boundary.ToString() << "\n"
1260                 << " because it has multiple dependents: "
1261                 << next_boundary_count << "\n";
1262         visited_count_[next_boundary.operands()[0]] = 1;
1263         new_boundaries_.push_back(next_boundary);
1264       } else {
1265         auto pos = std::find(new_boundaries_.begin(), new_boundaries_.end(),
1266                              next_boundary);
1267         if (pos != new_boundaries_.end() ||
1268             next_boundary.operands().size() == 1) {
1269           int count = ++visited_count_[next_boundary.operands()[0]];
1270           if (count == next_boundary_count) {
1271             VLOG(2) << "Recovering next boundary " << next_boundary.ToString()
1272                     << "\n"
1273                     << " because all of its dependents have been visited: "
1274                     << next_boundary_count << "\n";
1275             visited_count_.erase(next_boundary.operands()[0]);
1276             if (pos != new_boundaries_.end()) {
1277               new_boundaries_.erase(pos);
1278             }
1279             return true;
1280           }
1281         } else {
1282           VLOG(2) << "Skip incompatible multi-dependent boundary: "
1283                   << next_boundary.ToString() << ":" << next_boundary_count
1284                   << "\n";
1285         }
1286       }
1287     }
1288     return false;
1289   }
1290   // This function is reused both for moving boundaries out of and into a
1291   // conditional. As a result, its readability is somewhat compromised.
1292   // It might be nice to refactor it so that the outside/inside
1293   // considerations are factored into separate function-pointer parameters,
1294   // improving readability.
1295   void AddBoundaries(const Boundary& boundary) {
1296     BoundaryVisitor visitor;
1297     visitor.AddToWorkList(boundary);
1298     while (visitor.HasNextBoundary()) {
1299       Boundary b = visitor.PopNextBoundary();
1300       VLOG(2) << "visiting boundary " << b.ToString() << "\n";
1301       if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical(
1302                                       b.operands(), is_layout_sensitive_)) &&
1303           IsSafeToMoveBoundary(b) &&
1304           WorthHoisting(b.operands()[0], b.IsInsideBranch())) {
1305         connected_boundaries_.push_back(b);
1306         VLOG(2) << "boundary can be moved\n";
1307         int64_t operand_count = (b.IsInsideBranch())
1308                                     ? b.operands()[0]->operand_count()
1309                                     : b.operands()[0]->users().size();
1310         for (int i = 0; i < operand_count; i++) {
1311           Boundary next_boundary = GetNextBoundary(b, i);
1312           VLOG(2) << "Add operand/user " << i << " to visit later\n";
1313           visitor.AddToWorkList(next_boundary);
1314         }
1315       } else {
1316         VLOG(2) << "boundary cannot be moved\n";
1317         visited_count_[b.operands()[0]] = 1;
1318         new_boundaries_.push_back(b);
1319       }
1320     }
1321   }
1322   std::vector<Boundary> BoundariesToMoveInOrOut(HloInstruction* conditional,
1323                                                 const Boundary& b) {
1324     // At the beginning of optimization, the conditional itself is added to a
1325     // worklist. Here the conditional is expanded into two sets of boundaries:
1326     // the first set contains a single inside-branch boundary holding the
1327     // roots of all branches; the second set contains one outside-branch
1328     // boundary per user of the conditional.
1329     HloInstruction* inst = b.operands()[0];
1330     if (inst == conditional) {
1331       int branch_count = inst->branch_count();
1332       // Add conditional roots as a new boundary to visit.
1333       Boundary boundary_in(Boundary::Position::kInsideBranch);
1334       for (int i = 0; i < branch_count; i++) {
1335         HloComputation* branch_computation = inst->branch_computation(i);
1336         HloInstruction* root_inst = branch_computation->root_instruction();
1337         CHECK(root_inst != nullptr);
1338         boundary_in.mutable_operands().push_back(root_inst);
1339       }
1340       new_boundaries_.push_back(boundary_in);
1341       // Add conditional users as new boundaries to visit.
1342       for (auto u : inst->users()) {
1343         Boundary boundary_out(Boundary::Position::kOutsideBranch);
1344         boundary_out.mutable_operands().push_back(u);
1345         new_boundaries_.push_back(boundary_out);
1346       }
1347     } else {
1348       AddBoundaries(b);
1349     }
1350     return connected_boundaries_;
1351   }
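  // Appends the boundaries recorded as new (i.e., not moved) during the
  // analysis to b.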
1352   void AddNewBoundaries(std::vector<Boundary>& b) {
1353     b.insert(b.end(), new_boundaries_.begin(), new_boundaries_.end());
1354   }
1355 };
1356 
1357 ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
1358     HloInstruction* conditional, const Boundary& cur_boundary,
1359     std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries,
1360     absl::flat_hash_map<HloInstruction*, int>& visited_count) {
1361   GroupConnectedBoundaries connect(conditional, is_layout_sensitive_,
1362                                    visited_count, &move_config_, &reuse_config_,
1363                                    search_config_);
1364   auto move_in_or_out =
1365       connect.BoundariesToMoveInOrOut(conditional, cur_boundary);
1366   if (!move_in_or_out.empty()) {
1367     auto benefit = connect.BenefitForMovingBoundaries(
1368         move_in_or_out, search_config_map_.empty());
1369     VLOG(2) << "benefit of moving in or out "
1370             << cur_boundary.operands()[0]->ToString() << ":" << benefit << "\n";
1371     if (benefit >= 0) {
1372       new_boundaries.clear();
1373       connect.AddNewBoundaries(new_boundaries);
1374       // The whole sequence in move_in_or_out is either all moving into a
1375       // conditional, or all moving out of a conditional. So looking only at
1376       // the first entry of the sequence is sufficient to determine the
1377       // intended direction of the move.
1378       to_move = move_in_or_out;
1379       return Decision(to_move[0].IsInsideBranch()
1380                           ? Decision::Direction::kMoveOutOfBranch
1381                           : Decision::Direction::kMoveIntoBranch,
1382                       benefit);
1383     } else {
1384       connect.clear_recently_visited();
1385     }
1386   } else {
1387     connect.AddNewBoundaries(new_boundaries);
1388   }
1389   return Decision(Decision::Direction::kNoChange, 0);
1390 }
1391 
1392 StatusOr<bool> ConditionalCodeMotion::Run(
1393     HloModule* module,
1394     const absl::flat_hash_set<absl::string_view>& execution_threads) {
1395   VLOG(2) << "Begin a new pass of conditional code motion optimization.\n";
1396   // Used to support debugging of the optimization, by disabling it after a
1397   // pre-determined number of applications (to isolate transformation impact).
1398   if (!ConsumeFuel("conditional_code_motion", [&] {
1399         return "Skipping conditional opt after allowed limit reaching 0.\n";
1400       })) {
1401     return false;
1402   }
1403   bool changed = false;
1404   bool cleanup_changed = false;
1405   {
1406     HloPassPipeline subpipeline("before_conditional_code_motion");
1407     subpipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/is_layout_sensitive_);
1408     subpipeline.AddPass<HloDCE>();
1409     TF_ASSIGN_OR_RETURN(auto cleanup_changed_now,
1410                         subpipeline.Run(module, execution_threads));
1411     cleanup_changed |= cleanup_changed_now;
1412   }
1413   // Gather all the conditional ops in the module ahead of time, to avoid
1414   // complications from modifying the code in ways that affect traversal.
1415   std::vector<HloInstruction*> conditional_ops;
1416   // Track how many times each branch computation is shared.
1417   absl::flat_hash_map<HloComputation*, int> conditional_computations;
1418   for (auto* comp : module->MakeComputationPostOrder(execution_threads)) {
1419     for (auto* instr : comp->MakeInstructionPostOrder()) {
1420       if (instr->opcode() == HloOpcode::kConditional) {
1421         int branch_count = instr->branch_count();
1422         for (int i = 0; i < branch_count; ++i) {
1423           HloComputation* branch_i = instr->branch_computation(i);
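          // A count of 0 means the computation is used by a single branch;
          // a positive count means it is shared by multiple branches or
          // conditionals.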
1424           if (ContainsKey(conditional_computations, branch_i)) {
1425             conditional_computations[branch_i]++;
1426           } else {
1427             conditional_computations[branch_i] = 0;
1428           }
1429         }
1430         if (instr->shape().IsTuple()) {
1431           bool can_change_tuple_shape = true;
1432           for (auto user : instr->users()) {
1433             VLOG(2) << "user is : " << user->ToString() << "\n";
1434             if (user->opcode() != HloOpcode::kGetTupleElement) {
1435               can_change_tuple_shape = false;
1436             }
1437           }
1438           if (can_change_tuple_shape) {
1439             conditional_ops.push_back(instr);
1440           }
1441         } else {
1442           conditional_ops.push_back(instr);
1443         }
1444       }
1445     }
1446   }
1447 
1448   int64_t conditional_index = 0;
1449   // Used to collect mappings between cloned instructions.
1450   HloCloneContext clone_context(module);
1451   for (HloInstruction* conditional : conditional_ops) {
1452     if (conditional_index == 0 || !search_config_map_.empty()) {
1453       auto config_entry = search_config_map_.find(conditional_index);
1454       if (config_entry != search_config_map_.end()) {
1455         search_config_ = config_entry->second;
1456         VLOG(2) << "config entry value extracted:" << search_config_.size();
1457         search_config_index_ = 0;
1458       }
1459       VLOG(2) << "Obtaining default configuration for conditional "
1460               << conditional_index << "\n";
1461       SetDefaultMoveConfig();
1462       VLOG(2) << "Done obtaining default configuration\n";
1463       conditional_index++;
1464     }
1465     int branch_count = conditional->branch_count();
1466     // Check for shared conditional computations.
1467     bool conditional_is_shared = false;
1468     for (int i = 0; i < branch_count; ++i) {
1469       HloComputation* branch_i = conditional->branch_computation(i);
1470       if (conditional_computations[branch_i] > 0) {
1471         conditional_is_shared = true;
1472         break;
1473       }
1474     }
1475 
1476     // Boundaries to move out or to move into the branches.
1477     std::vector<std::vector<Boundary>> to_move_out, to_move_in;
1478     std::vector<std::vector<Boundary>> new_boundaries_for_moveout;
1479     std::vector<std::vector<Boundary>> new_boundaries_for_movein;
1480     // Number of times each instruction has been visited for moving.
1481     absl::flat_hash_map<HloInstruction*, int> visited_count;
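    // Accumulated benefits of the move-out and move-in decisions seen so far;
    // whichever sum is larger determines the final direction.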
1482     int benefit_move_out = 0, benefit_move_in = 0;
1483     Decision::Direction final_d = Decision::Direction::kNoChange;
1484     // The conditional is placed on a worklist as the seed (starting point).
1485     // When visited by GroupConnectedBoundaries, it is expanded into multiple
1486     // seeds: its branch roots and its users. A kNoChange decision is always
1487     // returned for the conditional itself, so that the other seeding
1488     // boundaries can be visited in turn.
1489     BoundaryVisitor visitor(conditional);
1490     VLOG(2) << "Analyzing conditional:" << conditional->ToString() << "\n";
1491     // Try to visit all the boundaries, collect the analysis results, and
1492     // save all the beneficial non-conflicting decisions. If two decisions
1493     // conflict with each other, save the more beneficial one.
1494     while (visitor.HasNextBoundary()) {
1495       std::vector<Boundary> to_move, next_boundary;
1496       Boundary boundary = visitor.PopNextBoundary();
1497       VLOG(2) << "Analyzing boundary:" << boundary.ToString() << "\n";
1498       auto d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary,
1499                                   visited_count);
1500       switch (d.GetDirection()) {
1501         case Decision::Direction::kMoveOutOfBranch:
1502           VLOG(2) << "Local Decision is move out of branch\n";
1503           to_move_out.push_back(to_move);
1504           new_boundaries_for_moveout.push_back(next_boundary);
1505           benefit_move_out += d.GetBenefit();
1506           if (benefit_move_out >= benefit_move_in) {
1507             final_d = Decision::Direction::kMoveOutOfBranch;
1508             VLOG(2) << "Current Decision is move out of branch ("
1509                     << to_move_out.size() << ")\n";
1510           } else {
1511             VLOG(2) << "Current Decision remains move into branch\n";
1512           }
1513           break;
1514         case Decision::Direction::kMoveIntoBranch:
1515           VLOG(2) << "Decision is move into branch\n";
1516           to_move_in.push_back(to_move);
1517           new_boundaries_for_movein.push_back(next_boundary);
1518           benefit_move_in += d.GetBenefit();
1519           if (benefit_move_out >= benefit_move_in) {
1520             VLOG(2) << "Current Decision remains move out of branch\n";
1521           } else {
1522             final_d = Decision::Direction::kMoveIntoBranch;
1523             VLOG(2) << "Current Decision is move into branch ("
1524                     << to_move_in.size() << ")\n";
1525           }
1526           break;
1527         case Decision::Direction::kNoChange:
1528           VLOG(2) << "Decision is no change\n";
1529           for (const Boundary& b : next_boundary) {
1530             visitor.AddToWorkList(b);
1531             VLOG(2) << "Adding new boundary to worklist:" << b.ToString()
1532                     << "\n";
1533           }
1534           break;
1535       }
1536     }
1537     // If a modification is to be made, the shared branches must be cloned.
1538     if (final_d != Decision::Direction::kNoChange && conditional_is_shared) {
1539       for (int i = 0; i < branch_count; ++i) {
1540         HloComputation* branch_i = conditional->branch_computation(i);
1541         if (conditional_computations[branch_i] > 0) {
1542           // Cloning is absolutely needed if the computation is shared by
1543           // different branches, but cloning could potentially be avoided
1544           // if the sharing is only among branches of the same conditional.
1545           // If cloning these branches causes a problem due to space issues,
1546           // a fix could be to pass a vector of unique branches to the actual
1547           // transformations, as an alternative representation of the
1548           // conditional branches to be modified. Right now we assume the
1549           // overhead of cloning is minimal since later stages of the compiler
1550           // inline all the computations anyway.
1551           HloComputation* clone_i =
1552               conditional->parent()->parent()->AddEmbeddedComputation(
1553                   branch_i->Clone("clone", &clone_context));
1554           conditional->set_branch_computation(i, clone_i);
1555           conditional_computations[branch_i]--;
1556           // Need to translate the analysis results to refer to the cloned
          // instructions.
1557           auto update_boundary = [&](Boundary& boundary) {
1558             auto cloned_instr =
1559                 clone_context.FindInstruction(boundary.operands()[i]);
1560             CHECK(cloned_instr != nullptr);
1561             VLOG(2) << "boundary before cloning:" << boundary.operands()[i]
1562                     << "\n";
1563             boundary.mutable_operands()[i] = cloned_instr;
1564             VLOG(2) << "boundary after cloning:" << boundary.operands()[i]
1565                     << "\n";
1566           };
1567           // Only boundaries to move out need to be updated.
1568           if (final_d == Decision::Direction::kMoveOutOfBranch) {
1569             for (int j = 0; j < to_move_out.size(); ++j) {
1570               std::vector<Boundary>& m = to_move_out[j];
1571               std::for_each(m.begin(), m.end(), update_boundary);
1572             }
1573             for (int j = 0; j < new_boundaries_for_moveout.size(); ++j) {
1574               std::vector<Boundary>& m = new_boundaries_for_moveout[j];
1575               std::for_each(m.begin(), m.end(), update_boundary);
1576             }
1577           }
1578         }
1579       }
1580       VLOG(2) << "Cloned branches as needed: " << conditional->ToString()
1581               << "\n";
1582     }
1583     // At most one of to_move_out or to_move_in can be non-empty, since there is
1584     // only one optimization decision.
1585     if (final_d == Decision::Direction::kMoveOutOfBranch) {
1586       CHECK(to_move_out.size() == new_boundaries_for_moveout.size());
1587       for (int i = 0; i < to_move_out.size(); ++i) {
1588         TF_ASSIGN_OR_RETURN(bool result,
1589                             MoveInstructionOut(conditional, to_move_out[i],
1590                                                new_boundaries_for_moveout[i]));
1591         changed |= result;
1592       }
1593       VLOG(2) << "Done moving out of branches " << to_move_out.size()
1594               << " times. \n";
1595       if (!ConsumeFuel("conditional_code_motion", [&] {
1596             return "Skipping conditional opt after allowed limit reaching 0.\n";
1597           })) {
1598         break;
1599       }
1600     } else if (final_d == Decision::Direction::kMoveIntoBranch) {
1601       CHECK(to_move_in.size() == new_boundaries_for_movein.size());
1602       for (int i = 0; i < to_move_in.size(); ++i) {
1603         TF_ASSIGN_OR_RETURN(bool result,
1604                             MoveInstructionIn(conditional, to_move_in[i],
1605                                               new_boundaries_for_movein[i]));
1606         changed |= result;
1607       }
1608       VLOG(2) << "Done moving into branches " << to_move_in.size()
1609               << " times. \n";
1610       if (!ConsumeFuel("conditional_code_motion", [&] {
1611             return "Skipping conditional opt after allowed limit reaching 0.\n";
1612           })) {
1613         break;
1614       }
1615     } else if (pursue_full_conditional_code_motion_ && !conditional_is_shared) {
1616       // Invoke special handling for convert rematerialization/hoisting.
1617       // We need to make sure no sharing is present in the branches because no
1618       // cloning has been done by the earlier analysis.
1619       // TODO(b/165848866): extend solution to handle cloning for special move.
1620       TF_ASSIGN_OR_RETURN(
1621           bool convert_result,
1622           ConvertSpecialMove(conditional, is_layout_sensitive_));
1623       if (convert_result) {
1624         VLOG(2) << "Done special moving of convert\n";
1625         if (!ConsumeFuel("conditional_code_motion", [&] {
1626               return "Skipping conditional opt after allowed limit reaching "
1627                      "0.\n";
1628             })) {
1629           break;
1630         }
1631       }
1632       changed |= convert_result;
1633     }
1634   }
1635   if (changed) {
1636     HloPassPipeline subpipeline(
1637         "after_conditional_code_motion_after_convert_hoisting");
1638     VLOG(2) << "starting after motion passes: DCE\n";
1639     subpipeline.AddPass<HloDCE>();
1640     subpipeline.AddPass<TupleSimplifier>();
1641     subpipeline.AddPass<HloDCE>();
1642     TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module));
1643     cleanup_changed |= cleanup_changed_now;
1644   }
1645   if (cleanup_changed) {
1646     VLOG(2) << "subpipeline cleanup have modified code\n";
1647   }
1648   return changed;
1649 }
1650 
1651 void ConditionalCodeMotion::SetDefaultMoveConfig() {
1652   VLOG(2) << "search_config_index = " << search_config_index_ << "\n";
1653   VLOG(2) << "search_config_ size = " << search_config_.size() << "\n";
1654   int64_t cur_search_config = (search_config_index_ < 0 ||
1655                                search_config_index_ >= search_config_.size())
1656                                   ? 0
1657                                   : search_config_[search_config_index_];
1658   enum class TuningOption {
1659     kDoNotTune = 0,
1660     kTuneTransformationDecision = 1,
1661     kTuneReuseModel = 2,
1662   };
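  // The sign of the current search config selects the mode: zero means no
  // tuning, positive tunes the transformation decisions, and negative tunes
  // the reuse model.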
1663   TuningOption tuning_option =
1664       (cur_search_config == 0)  ? TuningOption::kDoNotTune
1665       : (cur_search_config > 0) ? TuningOption::kTuneTransformationDecision
1666                                 : TuningOption::kTuneReuseModel;
1667 
1668   auto row = HloOpcodeCount();
1669   auto col = row;
1670   VLOG(2) << "Start setting default configuration\n";
1671   reuse_config_.clear();
1672   move_config_.clear();
1673   reuse_config_.reserve(row);
1674   move_config_.reserve(row);
1675   for (int64_t opcode = 0; opcode < row; ++opcode) {
1676     // Record the reuse scores between this opcode and each user opcode.
1677     std::vector<int64_t> reuse_vec(col, 0);
1678     for (uint32_t j = 0; j < col; ++j) {
1679       reuse_vec[j] = ReusesCarriedBy(static_cast<HloOpcode>(opcode),
1680                                      static_cast<HloOpcode>(j));
1681     }
1682     reuse_config_.push_back(reuse_vec);
1683     std::vector<int64_t> move_vec;
1684     switch (tuning_option) {
1685       case TuningOption::kTuneTransformationDecision:
1686         // Tuning transformation decision --- start with all yes.
1687         // Only a single entry is needed if we don't consider operands of an op
1688         // when searching/tuning transformation decisions.
1689         move_vec.push_back(1);
1690         break;
1691       // Tune the ReusesCarriedBy results only.
1692       case TuningOption::kTuneReuseModel:
1693       case TuningOption::kDoNotTune:
1694         // No tuning --- use the default configuration.
1695         // Use the opcode of the first operand to configure the default.
1696         move_vec.reserve(col);
1697         for (uint32_t j = 0; j < col; ++j) {
1698           move_vec.push_back(WorthHoisting(static_cast<HloOpcode>(opcode),
1699                                            static_cast<HloOpcode>(j)));
1700         }
1701         break;
1702     }
1703     move_config_.push_back(move_vec);
1704   }
1705 }
1706 
1707 // The search configuration is specified using a string in the format of
1708 // 'config1;config2; ...;config_n', where each config_i is in the format of
1709 // 'index,start,max,stride' (four integers separated by commas), which specify
1710 // the index number of the conditional being configured, the index of the first
1711 // transformation decision to flip for the conditional, the max number of
1712 // decisions to flip, and how many decisions to skip in between the flips.
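// For example, a (hypothetical) configuration string such as "0,1,4,2;1,0,2,1"
// would ask conditional 0 to flip up to 4 decisions starting at decision 1
// with a stride of 2, and conditional 1 to flip up to 2 decisions starting at
// decision 0 with a stride of 1.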
1713 void ConditionalCodeMotion::ParseSearchConfiguration(
1714     const std::string& search_config) {
1715   if (search_config.empty()) {
1716     return;
1717   }
1718   search_config_index_ = 0;
1719   std::vector<std::string> configs = absl::StrSplit(search_config, ';');
1720   for (const std::string& config : configs) {
1721     std::vector<std::string> specs = absl::StrSplit(config, ',');
1722     CHECK_EQ(specs.size(), 4);
1723     int64_t condition_index;
1724     CHECK(absl::SimpleAtoi(specs[0], &condition_index));
1725     auto& cur_config_entry = search_config_map_[condition_index];
1726     int64_t flip_start, max_flip, flip_stride;
1727     CHECK(absl::SimpleAtoi(specs[1], &flip_start));
1728     CHECK(absl::SimpleAtoi(specs[2], &max_flip));
1729     CHECK(absl::SimpleAtoi(specs[3], &flip_stride));
1730     int64_t cur_config = MakeSearchConfig(flip_start, max_flip, flip_stride);
1731     cur_config_entry.push_back(cur_config);
1732     VLOG(2) << "Setting search config " << condition_index << "->" << cur_config
1733             << "\n";
1734   }
1735 }
1736 
1737 }  // namespace conditional_opt
1738 
1739 }  // namespace xla
1740