/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/conditional_code_motion.h"

#include <algorithm>
#include <iterator>
#include <stack>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {

namespace conditional_opt {

class BoundaryVisitor {
 public:
  // Start with an existing conditional computation.
  explicit BoundaryVisitor(HloInstruction* conditional) {
    Boundary b(Boundary::Position::kInsideBranch);
    b.mutable_operands().push_back(conditional);
    worklist_.push_back(b);
  }
  // Start with an empty work list.
  BoundaryVisitor() {}
  // Get the next boundary to visit.
  Boundary PopNextBoundary() {
    CHECK(!worklist_.empty());
    Boundary b = worklist_.front();
    worklist_.pop_front();
    // If b has already been visited, it must have multiple users and is
    // already in new boundaries. Skip it.
    while (!worklist_.empty() && ContainsKey(visited_, b)) {
      b = worklist_.front();
      worklist_.pop_front();
    }
    visited_.insert(b);
    return b;
  }
  void AddToWorkList(const Boundary& b) {
    CHECK(!b.operands().empty());
    worklist_.push_back(b);
  }

  bool HasNextBoundary() {
    while (!worklist_.empty()) {
      Boundary b = worklist_.front();
      if (!ContainsKey(visited_, b)) {
        break;
      }
      worklist_.pop_front();
    }
    return !worklist_.empty();
  }

 private:
  // worklist_ is the deque that contains instructions to be visited.
  std::deque<Boundary> worklist_;
  absl::flat_hash_set<Boundary> visited_;
};

template <class OpCollection>
int64_t CountNonLeafOps(const OpCollection& ops) {
  absl::flat_hash_set<HloInstruction*> op_set;
  for (auto op : ops) {
    if (!op_set.contains(op) && op->opcode() != HloOpcode::kConstant) {
      op_set.insert(op);
    }
  }
  return op_set.size();
}

// Returns an estimate of the potential reuses carried by a given pair of
// instructions. Different integers are used to classify different levels
// of reuse. This is used as a placeholder only, assuming all instructions
// can be fused to enable data reuse.
int64_t ReusesCarriedBy(HloOpcode op, HloOpcode user) {
  // Reuses in some way work like forces that pull instructions
  // towards each other. We use a number 0-10 to classify how strong the force
  // is between a pair of operations. Given a group of instructions that can be
  // moved together, if the forces inside a conditional are stronger, the group
  // will be moved inside or remain inside the conditional; otherwise, it will
  // be moved outside to or remain outside of the conditional.
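  // An illustrative reading of the scores below (a placeholder model, not a
  // definitive one): a kConditional op carries a strong pull (10) toward its
  // users; leaf-like ops (parameters, constants, get-tuple-elements) carry
  // none; everything else returns -10, a negative encoding that the cost
  // model later interprets as a reuse of 10 scaled down by operand/user
  // fan-out.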
  switch (user) {
    case HloOpcode::kGetTupleElement:
      return 0;
    case HloOpcode::kConvert:
      // Because a convert is treated as not movable when it follows a dot or
      // convolution, if op is a dot or convolution here, they must be
      // separated by a conditional boundary, and we do not try to pull the
      // convert inside the conditional to be together with the dot or
      // convolution.
      switch (op) {
        case HloOpcode::kConvolution:
        case HloOpcode::kDot:
          return 0;
        default:
          break;
      }
      break;
    default:
      break;
  }
  switch (op) {
    // These instructions do not carry weight of reuse themselves.
    case HloOpcode::kParameter:
    case HloOpcode::kConstant:
    case HloOpcode::kGetTupleElement:
      return 0;
    case HloOpcode::kConditional:
      return 10;
    default:
      return -10;
  }
}

// Returns true if `op` is worth hoisting.
bool WorthHoisting(HloOpcode op, HloOpcode child_op) {
  // TODO(b/169182921): The following cost model is rather incomplete and will
  // need to be extended to cover most of the element-wise ops.
  switch (op) {
    case HloOpcode::kConvert:
      // If the convert follows an all-reduce, it is worth moving it out of
      // the conditional for AR/CRS combining. If the convert follows other
      // ops such as dot or convolution, it is better to keep it within the
      // conditional so that it can be fused with the dot or convolution.
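      // A sketch in HLO-like notation (assumed example, for exposition only):
      //   branch0 { ar = all-reduce(p0), ROOT cv = convert(ar) }
      // Hoisting cv (and subsequently ar) out of the conditional exposes the
      // all-reduce to cross-replica-sum combining in the parent computation,
      // which cannot happen while it stays inside a branch.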
      switch (child_op) {
        case HloOpcode::kAllReduce:
        case HloOpcode::kReshape:
        case HloOpcode::kGetTupleElement:
          return true;
        default:
          return false;
      }
    case HloOpcode::kGetTupleElement:
      switch (child_op) {
        // Do not move GTE if its operand is a parameter.
        case HloOpcode::kParameter:
          return false;
        default:
          return true;
      }
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAbs:
    case HloOpcode::kReduce:
    case HloOpcode::kAdd:
    case HloOpcode::kPower:
    case HloOpcode::kCopy:
    case HloOpcode::kConstant:
    case HloOpcode::kSubtract:
    case HloOpcode::kMultiply:
    case HloOpcode::kDivide:
    case HloOpcode::kTuple:
    case HloOpcode::kSqrt:
    case HloOpcode::kRsqrt:
    case HloOpcode::kReshape:
    case HloOpcode::kMinimum:
    case HloOpcode::kMaximum:
      return true;
    default:
      return false;
  }
}

// Compares whether the instructions to be visited in each branch are
// identical.
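// Cross-module all-reduces get special handling below: their channel ids are
// temporarily unified before comparison, since differing channel ids alone
// should not make otherwise identical all-reduces non-hoistable.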
bool InstructionWithinBranchIdentical(
    const std::vector<HloInstruction*>& instructions,
    bool is_layout_sensitive) {
  // "Identical" includes that the shapes of the operands are equal.
  auto eq_operand = [&](const HloInstruction* a, const HloInstruction* b) {
    bool eq_operands = is_layout_sensitive
                           ? ShapeUtil::Equal(a->shape(), b->shape())
                           : ShapeUtil::Compatible(a->shape(), b->shape());
    return eq_operands;
  };

  auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
    return *a == *b;
  };

  if (instructions.empty()) {
    return false;
  }

  if (instructions[0]->IsCrossModuleAllReduce()) {
    return std::all_of(
        instructions.begin(), instructions.end(),
        [&](HloInstruction* instruction) {
          if (!instruction->IsCrossModuleAllReduce()) {
            return false;
          }
          auto old_channel_id = instruction->channel_id();
          instruction->set_channel_id(instructions[0]->channel_id());
          bool eq_instructions = instructions[0]->Identical(
              *instruction, eq_operand, eq_computations, is_layout_sensitive);
          instruction->set_channel_id(old_channel_id);
          return eq_instructions;
        });
  }

  return std::all_of(instructions.begin(), instructions.end(),
                     [&](HloInstruction* instruction) {
                       return instructions[0]->Identical(
                           *instruction, eq_operand, eq_computations,
                           is_layout_sensitive);
                     });
}

// Copy the boundary out of the conditional and update hoisted_boundaries.
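// For example (illustrative): if every branch root computes add(a, b) for
// some per-branch a and b, the hoisted boundary for the add is a single add
// in the parent computation whose operands are the previously hoisted
// boundaries of a and b, looked up in hoisted_boundaries.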
void CopyOutOfConditional(
    Boundary& boundary, HloInstruction* conditional,
    absl::flat_hash_map<Boundary, Boundary>& hoisted_boundaries) {
  CHECK(boundary.IsInsideBranch());
  absl::InlinedVector<HloInstruction*, 4> new_operands;
  // All of the branch operands should have the same opcode and shape, so just
  // use branch 0.
  const HloInstruction* branch0_inst = boundary.operands()[0];
  for (int i = 0; i < branch0_inst->operands().size(); ++i) {
    Boundary operand_boundary(boundary.GetPosition());
    for (HloInstruction* operand : boundary.operands()) {
      operand_boundary.mutable_operands().push_back(operand->operands()[i]);
    }
    VLOG(2) << "Looking for: " << operand_boundary.ToString();
    auto hoisted_boundaries_it = hoisted_boundaries.find(operand_boundary);
    CHECK(hoisted_boundaries_it != hoisted_boundaries.end());
    Boundary hoisted_boundary = hoisted_boundaries_it->second;
    CHECK(hoisted_boundary.IsOutsideBranch());
    CHECK_EQ(hoisted_boundary.operands().size(), 1);
    new_operands.push_back(hoisted_boundary.operands()[0]);
  }
  HloInstruction* new_instruction = conditional->parent()->AddInstruction(
      branch0_inst->CloneWithNewOperands(branch0_inst->shape(), new_operands));
  VLOG(2) << "new instruction:" << new_instruction->ToString();
  // Maps the instruction inside of the conditional to the instruction
  // hoisted outside of the conditional.
  Boundary hoisted_boundary(Boundary::Position::kOutsideBranch);
  hoisted_boundary.mutable_operands().push_back(new_instruction);
  hoisted_boundaries[boundary] = hoisted_boundary;
}

// Copy the boundary into the conditional and update hoisted_boundaries.
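// Operands that were not themselves hoisted are rematerialized per branch:
// constants are cloned into each branch computation, and get-tuple-elements
// of the conditional are rewired to the corresponding operand of that
// branch's root tuple.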
void CopyIntoConditional(
    Boundary& boundary, HloInstruction* conditional,
    absl::flat_hash_map<Boundary, Boundary>& hoisted_boundaries) {
  CHECK(boundary.IsOutsideBranch());
  CHECK_EQ(boundary.operands().size(), 1);
  int num_branches = conditional->branch_count();
  std::vector<absl::InlinedVector<HloInstruction*, 4>> new_operands(
      num_branches);
  HloInstruction* op = boundary.operands()[0];
  for (HloInstruction* operand : op->operands()) {
    Boundary operand_boundary(boundary.GetPosition());
    operand_boundary.mutable_operands().push_back(operand);
    VLOG(2) << "Looking for: " << operand_boundary.ToString();
    auto hoisted_boundaries_it = hoisted_boundaries.find(operand_boundary);
    if (hoisted_boundaries_it != hoisted_boundaries.end()) {
      Boundary hoisted_boundary = hoisted_boundaries_it->second;
      CHECK(hoisted_boundary.IsInsideBranch());
      CHECK_EQ(hoisted_boundary.operands().size(), num_branches);
      for (int j = 0; j < num_branches; ++j) {
        new_operands[j].push_back(hoisted_boundary.operands()[j]);
      }
    } else {
      for (int j = 0; j < num_branches; ++j) {
        switch (operand->opcode()) {
          case HloOpcode::kConstant: {
            auto new_operand =
                conditional->branch_computation(j)->AddInstruction(
                    operand->Clone());
            VLOG(2) << "new instruction:" << new_operand->ToString();
            new_operands[j].push_back(new_operand);
            break;
          }
          case HloOpcode::kGetTupleElement: {
            auto gte = Cast<HloGetTupleElementInstruction>(operand);
            int64_t index = gte->tuple_index();
            HloInstruction* root =
                conditional->branch_computation(j)->root_instruction();
            CHECK(root->opcode() == HloOpcode::kTuple &&
                  index < root->operand_count())
                << root->ToString() << " " << gte->ToString();
            auto new_operand = root->mutable_operand(index);
            VLOG(2) << "new instruction:" << new_operand->ToString();
            new_operands[j].push_back(new_operand);
            break;
          }
          default:
            LOG(FATAL) << "Unexpected out-of-boundary instruction:"
                       << operand->ToString() << "\n";
        }
      }
    }
  }

  Boundary hoisted_boundary(Boundary::Position::kInsideBranch);
  for (int j = 0; j < num_branches; ++j) {
    HloInstruction* new_instruction =
        conditional->branch_computation(j)->AddInstruction(
            op->CloneWithNewOperands(op->shape(), new_operands[j]));
    VLOG(2) << "new instruction:" << new_instruction->ToString();
    hoisted_boundary.mutable_operands().push_back(new_instruction);
  }
  hoisted_boundaries[boundary] = hoisted_boundary;
}

// Identify converts to be hoisted/rematerialized out of the branch
// computations.
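// Here a "special" convert is one that feeds the root tuple at the same
// operand index in every branch, has matching (or compatible) shapes across
// branches, appears only once in each root tuple, and neither feeds nor is
// fed by another convert.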
absl::flat_hash_set<int64_t> FindSpecialConverts(HloInstruction* old_root,
                                                 int branch_count,
                                                 HloInstruction* conditional,
                                                 bool is_layout_sensitive) {
  absl::flat_hash_set<int64_t> special_convert;

  // TODO(b/216487727): Allow hoisting converts that feed or are fed by other
  // converts by addressing possible duplicates left behind in the tuple
  // output. The conditional code motion pass should handle these duplicates,
  // so merging these snippets of code would be one alternative.
  auto convert_invalid =
      [](const HloInstruction* convert_set_candidate) -> bool {
    bool invalid_user = absl::c_any_of(
        convert_set_candidate->users(), [](const HloInstruction* user) -> bool {
          return (user->opcode() == HloOpcode::kConvert);
        });
    bool invalid_producer =
        absl::c_any_of(convert_set_candidate->operands(),
                       [](const HloInstruction* operand) -> bool {
                         return (operand->opcode() == HloOpcode::kConvert);
                       });
    return (invalid_user || invalid_producer);
  };

  for (int64_t operand_num = 0; operand_num < old_root->operand_count();
       ++operand_num) {
    if (old_root->operand(operand_num)->opcode() != HloOpcode::kConvert) {
      continue;
    }
    bool replica = true;
    HloInstruction* special_convert_candidate =
        old_root->mutable_operand(operand_num);
    // TODO(b/216487727): Remove duplicates in tuple outputs while hoisting.
    auto repeated =
        absl::c_count_if(old_root->operands(),
                         [&](const HloInstruction* operand) -> bool {
                           return (special_convert_candidate == operand);
                         }) > 1;
    if (convert_invalid(special_convert_candidate) || repeated) {
      continue;
    }
    // Check whether an identical candidate appears in the other branches.
    for (int others = 1; others < branch_count; ++others) {
      HloInstruction* others_root =
          conditional->branch_computation(others)->root_instruction();
      const HloInstruction* other_convert = others_root->operand(operand_num);
      if (other_convert->opcode() != HloOpcode::kConvert ||
          convert_invalid(other_convert)) {
        replica = false;
        break;
      }
      // Do not move converts if their operands have different shapes in
      // different branches.
      bool eq_shape =
          is_layout_sensitive
              ? ShapeUtil::Equal(other_convert->shape(),
                                 special_convert_candidate->shape()) &&
                    ShapeUtil::Equal(
                        other_convert->operand(0)->shape(),
                        special_convert_candidate->operand(0)->shape())
              : ShapeUtil::Compatible(other_convert->shape(),
                                      special_convert_candidate->shape()) &&
                    ShapeUtil::Compatible(
                        other_convert->operand(0)->shape(),
                        special_convert_candidate->operand(0)->shape());
      if (!eq_shape) {
        replica = false;
        break;
      }
      auto repeated =
          absl::c_count_if(others_root->operands(),
                           [&](const HloInstruction* operand) -> bool {
                             return (special_convert_candidate == operand);
                           }) > 1;
      if (repeated) {
        replica = false;
        break;
      }
    }
    if (replica) {
      special_convert.insert(operand_num);
    }
  }
  return special_convert;
}

// Restructures the conditional instruction as follows:
// i.e., %result = conditional() becomes
// x = conditional()
// y.{0..n} = gte(x, {0..n})
// z = tuple(y.0, y.1, ...y.n)
// Doing so ensures that we can accommodate the possible shape change of the
// conditional when the instructions are hoisted.
Status RestructureConditionalInstruction(HloComputation* computation,
                                         HloInstruction* conditional) {
  HloInstruction* old_root = computation->root_instruction();
  std::vector<HloInstruction*> new_operands;
  int cur_index = 0;
  for (; cur_index < ShapeUtil::TupleElementCount(conditional->shape());
       ++cur_index) {
    new_operands.push_back(
        computation->AddInstruction(HloInstruction::CreateGetTupleElement(
            ShapeUtil::GetTupleElementShape(conditional->shape(), cur_index),
            conditional, cur_index)));
  }
  HloInstruction* new_tuple =
      computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
  if (old_root == conditional) {
    computation->set_root_instruction(new_tuple);
  } else {
    std::vector<HloInstruction*> new_tuple_users;
    for (auto conditional_user : conditional->users()) {
      auto is_new_gte = absl::c_find_if(
          new_operands,
          [&](HloInstruction* instr) { return instr == conditional_user; });
      if (is_new_gte == new_operands.end()) {
        new_tuple_users.push_back(conditional_user);
      }
    }
    for (auto new_tuple_user : new_tuple_users) {
      TF_RETURN_IF_ERROR(
          conditional->ReplaceUseWith(new_tuple_user, new_tuple));
    }
  }
  VLOG(2) << "computation after root restructure:\n" << computation->ToString();
  return OkStatus();
}

StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
                                  bool is_layout_sensitive) {
  int branch_count = conditional->branch_count();
  if (branch_count <= 0) {
    return false;
  }

  // Determine whether all branch roots are tuples.
  for (int branch_num = 0; branch_num < branch_count; ++branch_num) {
    HloInstruction* branch_root =
        conditional->branch_computation(branch_num)->root_instruction();
    if (branch_root->opcode() != HloOpcode::kTuple) {
      return false;
    }
  }

  HloInstruction* old_root =
      conditional->branch_computation(0)->root_instruction();
  VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
  // Identify the gte using `index'.
  auto find_gte = [](const HloInstruction* conditional_result,
                     int64_t index) -> HloInstruction* {
    for (HloInstruction* instr : conditional_result->users()) {
      if (instr->opcode() != HloOpcode::kGetTupleElement) {
        return nullptr;
      }
      if (instr->tuple_index() == index) {
        return instr;
      }
    }
    return nullptr;
  };

  // Captures tuple indices referring to converts to be
  // rematerialized/hoisted.
  absl::flat_hash_set<int64_t> special_convert = FindSpecialConverts(
      old_root, branch_count, conditional, is_layout_sensitive);

  // Exit if we cannot find any converts to be hoisted.
  if (special_convert.empty()) {
    return false;
  }

  TF_RETURN_IF_ERROR(
      RestructureConditionalInstruction(conditional->parent(), conditional));

  for (int branch = 0; branch < branch_count; branch++) {
    old_root = conditional->branch_computation(branch)->root_instruction();
    absl::flat_hash_map<HloInstruction*, int64_t> map_inst_to_tuple_index;
    std::vector<HloInstruction*> new_operands(old_root->operand_count());
    absl::flat_hash_set<HloInstruction*> to_hoist_set;

    for (int64_t operand_num = 0; operand_num < old_root->operand_count();
         ++operand_num) {
      map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] =
          operand_num;
    }
    for (int64_t operand_num = 0; operand_num < old_root->operand_count();
         ++operand_num) {
      HloInstruction* hoist = old_root->mutable_operand(operand_num);
      if (!special_convert.contains(operand_num)) {
        new_operands[operand_num] = old_root->mutable_operand(operand_num);
        continue;
      }

      to_hoist_set.insert(hoist);
      int64_t new_tuple_count = old_root->operand_count();

      // Replace the hoisted instr in the tuple with the operand/operands.
      // We will replace at least one of the operands of the hoist at the
      // tuple place; the rest will be added at the end.
      bool inplace = true;
      CHECK(!hoist->operands().empty());
      for (HloInstruction* prod : hoist->operands()) {
        if (inplace) {
          map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist];
          new_operands[map_inst_to_tuple_index[hoist]] = prod;
          inplace = false;
        } else {
          map_inst_to_tuple_index[prod] = new_tuple_count++;
          new_operands.push_back(prod);
        }
      }
    }

    // Create the new root instruction.
    HloComputation* cur_branch = conditional->branch_computation(branch);
    HloInstruction* new_branch_root =
        cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands));
    // The shape can vary since the operands to convert are now
    // being returned through the branches' root.
    cur_branch->set_root_instruction(new_branch_root, true /*new shape*/);
    TF_CHECK_OK(cur_branch->RemoveInstruction(old_root));

    // Only one of the branches needs to change the conditional->parent().
    if (branch != 0) {
      continue;
    }
    HloComputation* conditional_parent = conditional->parent();
    HloInstruction* newconditional =
        conditional_parent->AddInstruction(HloInstruction::CreateConditional(
            cur_branch->root_instruction()->shape(),
            conditional->mutable_operand(0),
            absl::MakeSpan(conditional->branch_computations()),
            absl::MakeSpan(conditional->operands()).subspan(1)));
    // Ensure that all the users of conditional refer to the new one.
    TF_RETURN_IF_ERROR(
        conditional->ReplaceAllUsesWithDifferentShape(newconditional));
    TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional));
    conditional = newconditional;
    // Add the hoisted instructions in the parent.
    for (HloInstruction* hoist : to_hoist_set) {
      VLOG(2) << "Hoisting instruction:" << hoist->ToString();
      int64_t hoist_index = map_inst_to_tuple_index[hoist];
      // Find out the gte that captured the hoisted instr result.
      HloInstruction* gte_hoist = find_gte(conditional, hoist_index);
      CHECK(gte_hoist != nullptr);
      std::vector<HloInstruction*> new_operands;
      for (HloInstruction* op : hoist->operands()) {
        HloInstruction* gte = conditional_parent->AddInstruction(
            HloInstruction::CreateGetTupleElement(op->shape(), conditional,
                                                  map_inst_to_tuple_index[op]));
        new_operands.push_back(gte);
      }
      HloInstruction* hoisted = conditional_parent->AddInstruction(
          hoist->CloneWithNewOperands(hoist->shape(), new_operands));
      VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString();
      TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted));
      TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist));
    }
    // No need to explicitly delete a hoisted instruction, since if it is
    // dead, the subsequent DCE will remove it.
  }
  VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString();
  return true;
}

// Hoist identical ops out of the conditional. Ops are considered identical
// if the shapes of their operands are identical and their properties are
// identical. Starts from the root instruction of each branch and collects
// the identical ops to hoist.
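// A sketch of the transformation, in illustrative HLO (assumed example):
//   c = conditional(p), branch_i { ... ROOT tuple(convert(x_i)) }
//   g = get-tuple-element(c), index=0
// becomes
//   c' = conditional(p), branch_i { ... ROOT tuple(x_i) }
//   g' = get-tuple-element(c'), index=0
//   cv = convert(g')   // hoisted once, replacing the uses of g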
StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
    HloInstruction* conditional, std::vector<Boundary>& to_move_out,
    std::vector<Boundary>& new_boundaries) {
  if (to_move_out.empty()) {
    return false;
  }
  VLOG(1) << "Modifying code--number of boundaries to move out:"
          << to_move_out.size() << "\n";
  HloComputation* conditional_parent = conditional->parent();
  // Save the old users before adding new conditional user instructions.
  std::vector<HloInstruction*> old_conditional_users = conditional->users();
  // Maps boundaries in the conditional body to boundaries hoisted outside
  // the conditional that compute the same value.
  absl::flat_hash_map<Boundary, Boundary> hoisted_boundaries;
  // Insert GetTupleElement before the instructions whose operands might still
  // be within the conditional.
  VLOG(1) << "before opt:"
          << conditional_parent->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  int64_t op_index = 0;
  for (const Boundary& b : new_boundaries) {
    HloInstruction* op = b.operands()[0];
    CHECK(op != nullptr);
    VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
    HloInstruction* gtr = conditional_parent->AddInstruction(
        HloInstruction::CreateGetTupleElement(op->shape(), conditional,
                                              op_index++));
    Boundary b2(Boundary::Position::kOutsideBranch);
    b2.mutable_operands().push_back(gtr);
    hoisted_boundaries[b] = b2;
  }
  // Copy boundary instructions out of the conditional.
  // Visit the operands before their users and copy them, so that the copied
  // users will point to the correct operands.
  for (int64_t i = to_move_out.size() - 1; i >= 0; i--) {
    CopyOutOfConditional(to_move_out[i], conditional, hoisted_boundaries);
  }
  VLOG(2) << "Done copy branch instructions out\n"
          << conditional_parent->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  // Change original users of the conditional to use the correct operands.
  for (auto user_instr : old_conditional_users) {
    VLOG(2) << "Checking conditional user: " << user_instr->ToString() << "\n";
    CHECK(user_instr->opcode() == HloOpcode::kGetTupleElement);
    auto tuple_opd = static_cast<HloGetTupleElementInstruction*>(user_instr);
    int64_t index = tuple_opd->tuple_index();
    Boundary old_user_boundary(Boundary::Position::kInsideBranch);
    for (const HloComputation* called_computation :
         conditional->called_computations()) {
      HloInstruction* root = called_computation->root_instruction();
      CHECK(root->operands().size() > index);
      old_user_boundary.mutable_operands().push_back(root->operands()[index]);
    }
    CHECK(ContainsKey(hoisted_boundaries, old_user_boundary));
    HloInstruction* new_opd =
        hoisted_boundaries[old_user_boundary].operands()[0];
    CHECK(new_opd != nullptr);
    VLOG(2) << "Try replace all uses of :" << old_user_boundary.ToString()
            << "\n";
    TF_RETURN_IF_ERROR(user_instr->ReplaceAllUsesWith(new_opd));
    TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(user_instr));
  }
  VLOG(2) << "Done changing conditional users\n"
          << conditional_parent->ToString() << "\n";
  // Create tuple element within each branch and set it as root.
  int64_t branch_count = conditional->branch_count();
  for (int i = 0; i < branch_count; i++) {
    auto computation = conditional->branch_computation(i);
    std::vector<HloInstruction*> elements;
    for (const auto& b1 : new_boundaries) {
      HloInstruction* op = b1.operands()[i];
      CHECK(op != nullptr);
      VLOG(2) << "Adding to root " << i << " with " << op->ToString() << "\n";
      elements.push_back(op);
    }
    HloInstruction* tuple =
        computation->AddInstruction(HloInstruction::CreateTuple(elements));
    computation->set_root_instruction(tuple, true);
    VLOG(2) << "computation is :" << computation->ToString() << "\n";
    // Remove hoisted instructions from the branches.
    for (const auto& b2 : to_move_out) {
      auto instr_to_remove = b2.operands()[i];
      // Double check to make sure it is safe to delete the instruction.
      // Complications may arise due to some operations in the alternative
      // branches (branches 1..n) being placed into the boundaries multiple
      // times.
      if (!computation->IsMarkedAsDead(instr_to_remove) &&
          instr_to_remove->IsDead()) {
        VLOG(2) << "Removing boundary:" << b2.ToString() << "\n";
        TF_RETURN_IF_ERROR(computation->RemoveInstruction(instr_to_remove));
      }
    }
  }
  // Change conditional instruction shape to the shape of the new root.
  HloInstruction* new_root =
      conditional->branch_computation(0)->root_instruction();
  *conditional->mutable_shape() = new_root->shape();
  // Keep conditional instruction sharding consistent with the branches. Note
  // that this sharding could be lost after this pass.
  conditional->set_sharding(new_root->sharding_ptr());
  VLOG(1) << "done moving instructions out of branches\n"
          << conditional_parent->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  return true;
}

// Hoist ops from outside of the conditional to inside the branches.
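// A sketch of the transformation, in illustrative HLO (assumed example,
// where the get-tuple-element has a single user so its root entry can be
// replaced in place):
//   c = conditional(p), branch_i { ... ROOT tuple(t_i) }
//   g = get-tuple-element(c), index=0
//   u = add(g, g)
// becomes
//   c' = conditional(p), branch_i { ... ROOT tuple(add(t_i, t_i)) }
//   u' = get-tuple-element(c'), index=0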
StatusOr<bool> ConditionalCodeMotion::MoveInstructionIn(
    HloInstruction* conditional, std::vector<Boundary>& to_move_in,
    std::vector<Boundary>& new_boundaries) {
  if (to_move_in.empty()) {
    return false;
  }
  VLOG(1) << "Modifying code---number of boundaries to move in:"
          << to_move_in.size() << "\n";
  VLOG(1) << "before opt:"
          << conditional->parent()->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  // Mapping of boundaries to be moved to their new representations.
  absl::flat_hash_map<Boundary, Boundary> hoisted_boundaries;
  int64_t to_move_in_size = to_move_in.size();
  int64_t branch_count = conditional->branch_count();
  HloGetTupleElementInstruction* tuple_use =
      DynCast<HloGetTupleElementInstruction>(to_move_in[0].operands()[0]);
  // If use_index is -1, the old conditional root entry used by the to_move_in
  // instructions still needs to be included as an entry of the modified
  // conditional root, and the new result of the to_move_in instructions
  // needs to be added as an extra entry of the modified root; otherwise, the
  // old root entry will be replaced with the new result in the modified root.
  // The entry replacement should be allowed only if tuple_use has <= 1 users.
  int64_t use_index = (tuple_use != nullptr && tuple_use->user_count() == 1)
                          ? tuple_use->tuple_index()
                          : -1;
  VLOG(2) << "Tuple use index = " << use_index << "\n";
  // Number of old conditional entries still to be used outside.
  // If the conditional's shape is not a tuple, a tuple will be created and
  // subscript 0 used to save the old operand being used.
  int64_t op_index =
      conditional->shape().IsTuple()
          ? ((use_index >= 0) ? conditional->shape().tuple_shapes_size() - 1
                              : conditional->shape().tuple_shapes_size())
          : 0;
  // Used to map the tuple_use instruction to its operand.
  Boundary b_opd_use(Boundary::Position::kInsideBranch);
  Boundary b_old_root(Boundary::Position::kInsideBranch);
  // Create a new root instruction in each branch.
  for (int i = 0; i < branch_count; i++) {
    auto computation = conditional->branch_computation(i);
    auto old_root = computation->root_instruction();
    b_old_root.mutable_operands().push_back(old_root);
    std::vector<HloInstruction*> operands;
    if (old_root->opcode() == HloOpcode::kTuple) {
      // Use operands of old_root directly, so old_root can be removed later.
      for (int i = 0; i < old_root->operand_count(); ++i) {
        if (i != use_index) {
          operands.push_back(old_root->operands()[i]);
        } else {  // Map conditional use to the tuple operand.
          b_opd_use.mutable_operands().push_back(old_root->operands()[i]);
        }
      }
    } else if (old_root->shape().IsTuple()) {
      // If old_root is not a kTuple but has tuple shape, elements within the
      // tuple must be extracted first to be used by the new instructions.
      const Shape& old_shape = old_root->shape();
      for (int i = 0; i < old_shape.tuple_shapes_size(); ++i) {
        auto element =
            computation->AddInstruction(HloInstruction::CreateGetTupleElement(
                old_shape.tuple_shapes(i), old_root, i));
        if (i != use_index) {
          operands.push_back(element);
        } else {
          b_opd_use.mutable_operands().push_back(element);
        }
      }
    } else {
      // If old_root is not a tuple and does not have tuple shape, use it
      // to replace the conditional directly in the new computation.
      b_opd_use.mutable_operands().push_back(conditional);
    }

    HloInstruction* new_root =
        computation->AddInstruction(HloInstruction::CreateTuple(operands));
    VLOG(2) << "setting new root: " << new_root->ToString() << "\n";
    computation->set_root_instruction(new_root,
                                      /*accept_different_shape=*/true);
    if (old_root->opcode() == HloOpcode::kTuple) {
      TF_RETURN_IF_ERROR(computation->RemoveInstruction(old_root));
    }
    VLOG(2) << "new branch computation: " << computation->ToString() << "\n";
  }
  // Update the get-tuple-element indices of the conditional's users.
  if (use_index != -1) {
    for (auto* user : conditional->users()) {
      if (user->opcode() == HloOpcode::kGetTupleElement &&
          user->tuple_index() > use_index) {
        user->set_tuple_index(user->tuple_index() - 1);
      }
    }
  }
  Boundary conditional_boundary(Boundary::Position::kOutsideBranch);
  conditional_boundary.mutable_operands().push_back(conditional);
  hoisted_boundaries[conditional_boundary] = b_old_root;
  int64_t cp_start = 0;
  if (use_index >= 0) {
    VLOG(2) << "Mapping GTE: " << tuple_use->ToString() << "\n";
    Boundary tuple_use_boundary(Boundary::Position::kOutsideBranch);
    tuple_use_boundary.mutable_operands().push_back(tuple_use);
    hoisted_boundaries[tuple_use_boundary] = b_opd_use;
  }
  cp_start = (tuple_use != nullptr) ? 1 : 0;
  for (int64_t to_move_index = cp_start; to_move_index < to_move_in_size;
       to_move_index++) {
    Boundary b_to_move = to_move_in[to_move_index];
    HloInstruction* op = b_to_move.operands()[0];
    CHECK(op != nullptr);
    bool to_be_used_outside = true;
    VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
    if (to_move_index < to_move_in_size - 1 && op->user_count() == 1 &&
        op->users()[0] == to_move_in[to_move_index + 1].operands()[0]) {
      to_be_used_outside = false;
      VLOG(2) << "Instruction is not to be used outside the branch\n";
    }
    Boundary b(Boundary::Position::kInsideBranch);
    CopyIntoConditional(b_to_move, conditional, hoisted_boundaries);
    if (to_be_used_outside) {
      for (int i = 0; i < branch_count; ++i) {
        auto computation = conditional->branch_computation(i);
        auto new_op = hoisted_boundaries[b_to_move].operands()[i];
        auto new_root = computation->root_instruction();
        new_root->AppendOperand(new_op);
        *new_root->mutable_shape()->add_tuple_shapes() = new_op->shape();
        VLOG(2) << "Extending conditional root " << i << " : "
                << new_root->ToString() << "\n";
      }
      // Modify uses of instructions outside of the conditional.
      HloInstruction* gtr = conditional->parent()->AddInstruction(
          HloInstruction::CreateGetTupleElement(op->shape(), conditional,
                                                op_index++));
      TF_RETURN_IF_ERROR(op->ReplaceAllUsesWith(gtr));
      if (conditional->parent()->root_instruction() == op) {
        conditional->parent()->set_root_instruction(gtr);
      }
    }
  }
  VLOG(2) << "Done copying instructions inside branch: "
          << conditional->ToString(HloPrintOptions::Fingerprint()) << "\n";
  // Change conditional instruction shape to the shape of the new root.
  HloInstruction* new_root =
      conditional->branch_computation(0)->root_instruction();
  *conditional->mutable_shape() = new_root->shape();
  // Keep conditional instruction sharding consistent with the branches. Note
  // that this sharding could be lost after this pass.
  conditional->set_sharding(new_root->sharding_ptr());
  VLOG(2) << "Before removing instructions:"
          << conditional->parent()->ToString() << "\n";
  // Remove hoisted instructions from the branches.
  for (int64_t i = to_move_in_size - 1; i >= 0; i--) {
    Boundary boundary_to_move_in = to_move_in[i];
    HloInstruction* op = boundary_to_move_in.operands()[0];
    if (op->user_count() == 0) {
      VLOG(2) << "Removing boundary:" << boundary_to_move_in.ToString() << "\n";
      TF_RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op));
      VLOG(2) << "Done removing boundary.\n";
    }
  }

  // Reset shapes of user gtes to the new shape.
  if (use_index != -1) {
    for (auto* user : conditional->users()) {
      if (user->opcode() == HloOpcode::kGetTupleElement) {
        VLOG(2) << "Resetting shape of user: " << user->ToString() << "\n";
        *user->mutable_shape() =
            conditional->shape().tuple_shapes(user->tuple_index());
      }
    }
  }
  VLOG(1) << "Done moving instructions inside branches\n"
          << conditional->parent()->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  return true;
}

// Groups single chains of operands or uses of boundaries into new boundaries.
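// Roughly, a connected group is a single-entry chain of boundaries that can
// move as a unit; boundaries with multiple dependents are deferred and
// recorded as "new boundaries" that delimit the moved region.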
class GroupConnectedBoundaries {
 private:
  std::vector<Boundary> connected_boundaries_, new_boundaries_;
  HloInstruction* conditional_;
  HloComputation* conditional_parent_;
  bool is_layout_sensitive_;
  // Instructions that have been visited but are not going to be moved.
  absl::flat_hash_map<HloInstruction*, int>& visited_count_;
  // The following four lines are configurations of the cost model, which will
  // be used to determine whether to move an instruction (move_config_) and how
  // strongly preferred it is to keep a pair of ops together (reuse_config_).
  // The search_config_ is used to control how to navigate the search space of
  // the cost model in the context of auto/manual tuning. The flipped_ map is
  // used to record which entries in the configuration have been changed in
  // the search/tuning process.
  std::vector<std::vector<int64_t>>& move_config_;
  std::vector<std::vector<int64_t>>& reuse_config_;
  absl::Span<int64_t> search_config_vec_;
  int64_t& search_config_;
  int64_t search_subscript_;
  absl::flat_hash_map<const int64_t*, int64_t> flipped_;

  // The FlipMutation function implements the search of alternative cost
  // models by deciding whether to flip a given configuration, saved in
  // the loc parameter. The non_zero parameter provides the new value to use
  // to flip a zero. The msg parameter is only used for debugging purposes.
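  // As used below, search_config_ is interpreted as three packed 8-bit
  // fields (via ConditionalCodeMotion::flip_start, DecrementMaxFlip, and
  // flip_stride): the low field counts down to the first flip, the middle
  // field bounds how many times an entry may be flipped, and the high field
  // gives the stride between successive flips.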
  int64_t FlipMutation(int64_t* loc, const int64_t non_zero,
                       const std::string& msg) {
    if (search_config_ == 0 || ContainsKey(flipped_, loc)) {
      VLOG(2) << "Configured not to search or loc is already flipped.";
      return *loc;
    }
    // The last 8 digits control when to start the first flip.
    int c = ConditionalCodeMotion::flip_start(search_config_);
    VLOG(2) << "flip start index = " << c << "\n";
    // Only flip the decision if c reaches 0.
    if (c > 0) {
      search_config_--;
      return *loc;
    }
    // Digits 8-16 control the maximum number of times to flip a config.
    auto flip_count = ConditionalCodeMotion::DecrementMaxFlip(&search_config_);
    VLOG(2) << "max flip count = " << flip_count << "\n";
    VLOG(2) << "Updating max flipping configuration = " << search_config_
            << "\n";
    if (flip_count == 0) {
      VLOG(2) << "Maximum flip count has been reached. ";
      if (search_subscript_ + 1 < search_config_vec_.size()) {
        VLOG(2) << "search_subscript_ = " << search_subscript_;
        VLOG(2) << "search config vec size = " << search_config_vec_.size();
        search_config_ = search_config_vec_[++search_subscript_];
      } else {
        return *loc;
      }
    }
    // Reload digits 16-23 of the configuration, which control how
    // frequently a configuration should be flipped.
    auto flip_stride = ConditionalCodeMotion::flip_stride(search_config_);
    search_config_ += flip_stride;
    VLOG(2) << "flip stride = " << flip_stride << "\n";
    VLOG(2) << "Updating flipping stride = " << search_config_ << "\n";

    flipped_[loc] = *loc;
    // Copy the last 8 bits back to the first 8 bits of the configuration.
    switch (*loc) {
      case 0:
        *loc = non_zero;
        break;
      default:
        *loc = 0;
        break;
    }
    VLOG(2) << "Flipping decision for: " << msg << ": from " << flipped_[loc]
            << " to " << *loc << "\n";
    return *loc;
  }

  static std::vector<int64_t>& EnsureSearchConfig(
      std::vector<int64_t>& search_config) {
    if (search_config.empty()) {
      search_config.push_back(0);
    }
    return search_config;
  }

 public:
  explicit GroupConnectedBoundaries(
      HloInstruction* conditional, bool is_layout_sensitive,
      absl::flat_hash_map<HloInstruction*, int>& visited_count,
      std::vector<std::vector<int64_t>>* move_config,
      std::vector<std::vector<int64_t>>* reuse_config,
      std::vector<int64_t>& search_config)
      : conditional_(conditional),
        conditional_parent_(conditional->parent()),
        is_layout_sensitive_(is_layout_sensitive),
        visited_count_(visited_count),
        move_config_(*move_config),
        reuse_config_(*reuse_config),
        search_config_vec_(EnsureSearchConfig(search_config)),
        search_config_(search_config_vec_.front()),
        search_subscript_(0) {
    VLOG(2) << "Initializing Group Connected Boundaries\n";
  }
  // Returns an estimate of the potential reuses carried by a given pair of
  // instructions. Different integers are used to classify different levels
  // of reuse. Assumes all instructions can be fused to enable data reuse.
  int64_t ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
    std::vector<int64_t>& curconfig =
        reuse_config_[static_cast<uint32_t>(op->opcode())];
    // Flip the reuse configuration if tuning the cost model.
    // When flipping, use -10 if flipping to the default reuse model. Other
    // values can be specified if needed to fine-control the decision making.
    int64_t config =
        (search_config_ < 0)
            ? FlipMutation(&curconfig[static_cast<uint32_t>(user->opcode())],
                           -10,
                           HloOpcodeString(op->opcode()) + "->" +
                               HloOpcodeString(user->opcode()))
            : curconfig[static_cast<uint32_t>(user->opcode())];
    VLOG(2) << "ConditionalCodeMotion: Add reuses carried by instr: "
            << op->ToString() << "=>" << user->ToString() << " : " << config
            << "\n";
    if (config < 0) {
      // Assume the reuse decreases with increasing user count.
      int count1 = CountNonLeafOps(op->users());
      int count2 = CountNonLeafOps(user->operands());
      return (-config) / count1 / count2;
    }
    return config;
  }
  void clear_recently_visited() {
    for (const auto& boundary : new_boundaries_) {
      visited_count_.erase(boundary.operands()[0]);
    }
  }
  // Returns true if `instruction` is worth hoisting.
  bool WorthHoisting(HloInstruction* instruction, bool is_inside_branch) {
    // This is needed for the "moving-in" transformation, to prevent the root
    // of the parent computation (which contains the conditional) from being
    // moved inside the conditional.
    HloOpcode opcode = instruction->opcode();
    if (opcode == HloOpcode::kTuple &&
        instruction == conditional_parent_->root_instruction()) {
      return false;
    }
    // It is not safe to move collective ops from outside to inside
    // conditional branches, as it may cause synchronization problems when
    // different layouts are assigned to different branches.
    if (DynCast<HloCollectiveInstruction>(instruction) && !is_inside_branch) {
      return false;
    }

    // It is not legal to move the parameter instructions.
    if (opcode == HloOpcode::kParameter) {
      return false;
    }

    // Use configuration given from outside (e.g., by autotuner).
    std::vector<int64_t>& curconfig =
        move_config_[static_cast<uint32_t>(opcode)];
    auto col = (curconfig.size() == 1) ? 0
               : (instruction->operand_count() > 0)
                   ? static_cast<uint32_t>(instruction->operand(0)->opcode())
                   : 0;
    VLOG(2) << "column = " << col << "\n";
    VLOG(2) << "config size = " << curconfig.size() << "\n";
    VLOG(2) << "search_config = " << search_config_ << "\n";
    CHECK(col < curconfig.size());
    uint32_t config = (search_config_ > 0)
                          ? FlipMutation(&curconfig[col], 1,
                                         "Move-" + HloOpcodeString(opcode))
                          : curconfig[col];
    VLOG(2) << "Checking whether instruction is worth moving: " << config
            << "\n";
    VLOG(2) << "after checking search_config = " << search_config_ << "\n";
    return (config != 0);
  }

  int64_t ReusesBeforeBoundary(HloInstruction* user) {
    int64_t reuses = 0;
    for (auto op : user->operands()) {
      // The operand must be an instruction that is not going to be moved (if
      // user is inside the conditional); otherwise it must be the conditional
      // itself and its user must be outside of the conditional.
      if (!ContainsKey(visited_count_, op) && op != conditional_) {
        continue;
      }
      if (auto tuple_gte = DynCast<HloGetTupleElementInstruction>(user)) {
        if (op->opcode() == HloOpcode::kConditional) {
          auto tuple = op->branch_computation(0)->root_instruction();
          if (tuple->opcode() == HloOpcode::kTuple) {
            auto index = tuple_gte->tuple_index();
            CHECK(index < tuple->operand_count());
            op = tuple->mutable_operand(index);
          }
        }
        reuses += ReusesCarriedBy(op, user->users()[0]);
      } else {
        reuses += ReusesCarriedBy(op, user);
      }
    }
    VLOG(2) << "Reuses before instruction " << user->ToString() << ":"
            << reuses << "\n";
    return reuses;
  }

  int64_t ReusesAfterBoundary(HloInstruction* user) {
    CHECK(user != nullptr);
    auto all_users = user->users();
    // For now, assume that if an instruction has multiple consumers, it
    // will not be reused, as the reuse may require duplication in
    // fusion and so is expensive. If the situation changes in the future,
    // some aspects of the overall algorithm need to be redesigned to
    // accommodate the change.
    if (all_users.size() > 1) {
      VLOG(2) << "Having multiple users from: " << user->ToString() << "\n";
      return 0;
    }
    if (!all_users.empty()) {
      auto op = all_users[0];
      int64_t reuses = 0;
      // Only count reuses that run through the conditional root.
      if (op == conditional_->branch_computation(0)->root_instruction()) {
        int64_t index = op->operand_index(user);
        for (auto op2 : conditional_->users()) {
          // If the use is not a get-tuple-element, do not consider it for
          // now.
          if (op2->opcode() == HloOpcode::kGetTupleElement) {
            auto tuple_opd = static_cast<HloGetTupleElementInstruction*>(op2);
            if (index == tuple_opd->tuple_index()) {
              all_users = op2->users();
              if (!all_users.empty()) {
                reuses += ReusesCarriedBy(user, all_users[0]);
                break;
              }
            }
          }
        }
      } else if (ContainsKey(visited_count_, op)) {
        reuses += ReusesCarriedBy(user, op);
      }
      VLOG(2) << "reuses after instruction " << user->ToString() << ":"
              << reuses << "\n";
      return reuses;
    }
    return 0;
  }

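  // Returns a signed benefit estimate for moving the given connected group
  // of boundaries: for a move out of the conditional, the estimate is the
  // reuses gained after the boundary minus the reuses lost before it; for a
  // move into the conditional, the reverse, minus 1, plus any copy-folding
  // benefit. A negative result vetoes the move.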
  int64_t BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries,
                                     bool perform_reuse_analysis = true) {
    int64_t reuses_before = 0, reuses_after = 0;
    if (boundaries.size() == 1) {
      if (boundaries[0].IsOutsideBranch() &&
          boundaries[0].operands()[0]->opcode() ==
              HloOpcode::kGetTupleElement) {
        // The only boundary to move in is the get_tuple_element op.
        return -1;
      }
      if (boundaries[0].IsInsideBranch() &&
          boundaries[0].operands()[0]->opcode() == HloOpcode::kTuple) {
        // The only boundary to move out is the tuple op inside branches.
        return -1;
      }
    }
    // If trying alternative moving configurations, turn off reuse analysis.
    if (!perform_reuse_analysis) {
      return 1;
    }
    // For cases like:
    // branch0 {
    //   ROOT copy
    // }
    // branch1 {
    //   ...
    // }
    // cond = conditional(branch0, branch1)
    // copy = copy(cond)
    //
    // We can fold the two copies, thus reducing computation.
    auto get_copy_folding_benefit = [&](HloInstruction* hlo) -> int64_t {
      if (hlo->opcode() != HloOpcode::kCopy) {
        return 0;
      }
      const HloGetTupleElementInstruction* gte =
          DynCast<HloGetTupleElementInstruction>(hlo->operand(0));
      if (gte == nullptr) {
        return 0;
      }
      const HloInstruction* conditional = gte->operand(0);
      if (conditional != conditional_) {
        return 0;
      }
      int64_t benefit = 0;
      for (auto* branch : conditional->called_computations()) {
        HloInstruction* root = branch->root_instruction();
        if (root->opcode() == HloOpcode::kTuple) {
          const auto* tuple_operand = root->operand(gte->tuple_index());
          if (tuple_operand->opcode() == HloOpcode::kCopy) {
            if (Shape::Equal()(tuple_operand->operand(0)->shape(),
                               hlo->shape())) {
              benefit += 10;
            }
          }
        }
      }
      return benefit;
    };
    for (const Boundary& b : boundaries) {
      auto op = b.operands()[0];
      if (op == conditional_->branch_computation(0)->root_instruction()) {
        continue;
      }
      VLOG(2) << "Benefit for " << op->ToString();
      reuses_before += ReusesBeforeBoundary(op);
      VLOG(2) << "Reuses before boundary so far: " << reuses_before << "\n";
      reuses_after += ReusesAfterBoundary(op);
      VLOG(2) << "Reuses after boundary so far: " << reuses_after << "\n";
    }

    int64_t copy_folding_benefit = 0;
    if (boundaries[0].IsOutsideBranch()) {
      for (const Boundary& b : boundaries) {
        auto op = b.operands()[0];
        copy_folding_benefit += get_copy_folding_benefit(op);
      }
    }
    VLOG(2) << "Copy folding benefit: " << copy_folding_benefit;

    if (reuses_after == 0 && reuses_before == 0 && copy_folding_benefit == 0) {
      return -1;
    } else if (boundaries[0].IsInsideBranch()) {
      return reuses_after - reuses_before;
    } else {
      return reuses_before - reuses_after - 1 + copy_folding_benefit;
    }
  }

  Boundary GetNextBoundary(const Boundary& b, int64_t op_index) {
    Boundary b2(b.GetPosition());
    for (int j = 0; j < b.operands().size(); ++j) {
      HloInstruction* inst = b.operands()[j];
      CHECK(inst != nullptr);
      HloInstruction* op = (b.IsInsideBranch()) ? inst->operands()[op_index]
                                                : inst->users()[op_index];
      CHECK(op != nullptr);
      b2.mutable_operands().push_back(op);
    }
    return b2;
  }

  // Checking whether it is safe to move a boundary when visited through a
  // dependent already considered for moving.
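  // A boundary with multiple dependents is deferred on first visit and
  // recorded as a new boundary; each later visit through another dependent
  // increments its count, and once every dependent has been visited the
  // boundary is recovered for moving.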
  bool IsSafeToMoveBoundary(const Boundary& next_boundary) {
    int64_t next_boundary_count =
        (next_boundary.IsInsideBranch())
            ? next_boundary.operands()[0]->user_count()
            : CountNonLeafOps(next_boundary.operands()[0]->operands());
    if (next_boundary_count <= 1) {
      // If the boundary has only a single dependent or none, it is safe to
      // move.
      return true;
    } else {
      if (!ContainsKey(visited_count_, next_boundary.operands()[0])) {
        VLOG(2) << "Skip next boundary " << next_boundary.ToString() << "\n"
                << " because it has multiple dependents: "
                << next_boundary_count << "\n";
        visited_count_[next_boundary.operands()[0]] = 1;
        new_boundaries_.push_back(next_boundary);
      } else {
        auto pos = std::find(new_boundaries_.begin(), new_boundaries_.end(),
                             next_boundary);
        if (pos != new_boundaries_.end() ||
            next_boundary.operands().size() == 1) {
          int count = ++visited_count_[next_boundary.operands()[0]];
          if (count == next_boundary_count) {
            VLOG(2) << "Recovering next boundary " << next_boundary.ToString()
                    << "\n"
                    << " because all of its dependents have been visited: "
                    << next_boundary_count << "\n";
            visited_count_.erase(next_boundary.operands()[0]);
            if (pos != new_boundaries_.end()) {
              new_boundaries_.erase(pos);
            }
            return true;
          }
        } else {
          VLOG(2) << "Skip incompatible multi-dependent boundary: "
                  << next_boundary.ToString() << ":" << next_boundary_count
                  << "\n";
        }
      }
    }
    return false;
  }
  // This function is reused both for moving the boundary outside or into a
  // conditional. As a result, the readability is somewhat compromised.
  // It might be nice to refactor this function to factor the outside-inside
  // considerations into separate function pointer parameters to improve
  // readability.
AddBoundaries(const Boundary & boundary)1295 void AddBoundaries(const Boundary& boundary) {
1296 BoundaryVisitor visitor;
1297 visitor.AddToWorkList(boundary);
1298 while (visitor.HasNextBoundary()) {
1299 Boundary b = visitor.PopNextBoundary();
1300 VLOG(2) << "visiting boundary " << b.ToString() << "\n";
1301 if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical(
1302 b.operands(), is_layout_sensitive_)) &&
1303 IsSafeToMoveBoundary(b) &&
1304 WorthHoisting(b.operands()[0], b.IsInsideBranch())) {
1305 connected_boundaries_.push_back(b);
1306 VLOG(2) << "boundary can be moved\n";
1307 int64_t operand_count = (b.IsInsideBranch())
1308 ? b.operands()[0]->operand_count()
1309 : b.operands()[0]->users().size();
1310 for (int i = 0; i < operand_count; i++) {
1311 Boundary next_boundary = GetNextBoundary(b, i);
1312 VLOG(2) << "Add operand/user " << i << " to visit later\n";
1313 visitor.AddToWorkList(next_boundary);
1314 }
1315 } else {
1316 VLOG(2) << "boundary cannot be moved\n";
1317 visited_count_[b.operands()[0]] = 1;
1318 new_boundaries_.push_back(b);
1319 }
1320 }
1321 }
  std::vector<Boundary> BoundariesToMoveInOrOut(HloInstruction* conditional,
                                                const Boundary& b) {
    // At the beginning of optimization, a conditional itself is added to a
    // worklist. Here the conditional is expanded into two sets of boundaries:
    // the first set contains the boundary that is inside branches and
    // contains the root of all branches; the second set of boundaries
    // contains all the users of the conditional.
    HloInstruction* inst = b.operands()[0];
    if (inst == conditional) {
      int branch_count = inst->branch_count();
      // Add conditional roots as a new boundary to visit.
      Boundary boundary_in(Boundary::Position::kInsideBranch);
      for (int i = 0; i < branch_count; i++) {
        HloComputation* branch_computation = inst->branch_computation(i);
        HloInstruction* root_inst = branch_computation->root_instruction();
        CHECK(root_inst != nullptr);
        boundary_in.mutable_operands().push_back(root_inst);
      }
      new_boundaries_.push_back(boundary_in);
      // Add conditional users as new boundaries to visit.
      for (auto u : inst->users()) {
        Boundary boundary_out(Boundary::Position::kOutsideBranch);
        boundary_out.mutable_operands().push_back(u);
        new_boundaries_.push_back(boundary_out);
      }
    } else {
      AddBoundaries(b);
    }
    return connected_boundaries_;
  }
  void AddNewBoundaries(std::vector<Boundary>& b) {
    b.insert(b.end(), new_boundaries_.begin(), new_boundaries_.end());
  }
};

ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
    HloInstruction* conditional, const Boundary& cur_boundary,
    std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries,
    absl::flat_hash_map<HloInstruction*, int>& visited_count) {
  GroupConnectedBoundaries connect(conditional, is_layout_sensitive_,
                                   visited_count, &move_config_, &reuse_config_,
                                   search_config_);
  auto move_in_or_out =
      connect.BoundariesToMoveInOrOut(conditional, cur_boundary);
  if (!move_in_or_out.empty()) {
    auto benefit = connect.BenefitForMovingBoundaries(
        move_in_or_out, search_config_map_.empty());
    VLOG(2) << "benefit of moving in or out "
            << cur_boundary.operands()[0]->ToString() << ":" << benefit << "\n";
    if (benefit >= 0) {
      new_boundaries.clear();
      connect.AddNewBoundaries(new_boundaries);
      // The whole sequence in move_in_or_out is either all moving into a
      // conditional, or all moving out of a conditional. So looking only
      // at the first entry of the sequence is sufficient to know in which
      // direction the move is intended.
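      // An inside-branch boundary means the sequence is hoisted out of the
      // conditional; an outside-branch boundary means it is sunk into it.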
      to_move = move_in_or_out;
      return Decision(to_move[0].IsInsideBranch()
                          ? Decision::Direction::kMoveOutOfBranch
                          : Decision::Direction::kMoveIntoBranch,
                      benefit);
    } else {
      connect.clear_recently_visited();
    }
  } else {
    connect.AddNewBoundaries(new_boundaries);
  }
  return Decision(Decision::Direction::kNoChange, 0);
}

StatusOr<bool> ConditionalCodeMotion::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  VLOG(2) << "Begin a new pass of conditional code motion optimization.\n";
  // Used to support debugging of the optimization, by disabling it after it
  // has been applied a pre-determined number of times (to isolate the impact
  // of transformations).
  if (!ConsumeFuel("conditional_code_motion", [&] {
        return "Skipping conditional opt after the allowed limit reached 0.\n";
      })) {
    return false;
  }
  bool changed = false;
  bool cleanup_changed = false;
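  // Clean up the module before the analysis; CSE and DCE presumably make
  // identical instructions across branches easier to recognize and remove
  // dead code that would otherwise be considered for motion.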
  {
    HloPassPipeline subpipeline("before_conditional_code_motion");
    subpipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/is_layout_sensitive_);
    subpipeline.AddPass<HloDCE>();
    TF_ASSIGN_OR_RETURN(auto cleanup_changed_now,
                        subpipeline.Run(module, execution_threads));
    cleanup_changed |= cleanup_changed_now;
  }
  // Gather all the conditional ops in the module ahead of time, to avoid
  // potential complications from modifying the code in ways that would affect
  // the traversal.
  std::vector<HloInstruction*> conditional_ops;
  // Track how many times each branch computation is shared.
  absl::flat_hash_map<HloComputation*, int> conditional_computations;
  for (auto* comp : module->MakeComputationPostOrder(execution_threads)) {
    for (auto* instr : comp->MakeInstructionPostOrder()) {
      if (instr->opcode() == HloOpcode::kConditional) {
        int branch_count = instr->branch_count();
        for (int i = 0; i < branch_count; ++i) {
          HloComputation* branch_i = instr->branch_computation(i);
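          // A computation is recorded with count 0 when first seen, so a
          // positive count means it is shared by more than one branch or
          // conditional.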
          if (ContainsKey(conditional_computations, branch_i)) {
            conditional_computations[branch_i]++;
          } else {
            conditional_computations[branch_i] = 0;
          }
        }
        if (instr->shape().IsTuple()) {
          bool can_change_tuple_shape = true;
          for (auto user : instr->users()) {
            VLOG(2) << "user is: " << user->ToString() << "\n";
            if (user->opcode() != HloOpcode::kGetTupleElement) {
              can_change_tuple_shape = false;
            }
          }
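          // Code motion may change the conditional's tuple shape, which is
          // only safe when every user is a get-tuple-element that can be
          // redirected to the new shape.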
          if (can_change_tuple_shape) {
            conditional_ops.push_back(instr);
          }
        } else {
          conditional_ops.push_back(instr);
        }
      }
    }
  }

  int64_t conditional_index = 0;
  // Used to collect mappings between cloned instructions.
  HloCloneContext clone_context(module);
  for (HloInstruction* conditional : conditional_ops) {
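    // The default move/reuse configuration is computed once, for the first
    // conditional, and recomputed only when per-conditional search
    // configurations may override it.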
    if (conditional_index == 0 || !search_config_map_.empty()) {
      auto config_entry = search_config_map_.find(conditional_index);
      if (config_entry != search_config_map_.end()) {
        search_config_ = config_entry->second;
        VLOG(2) << "config entry value extracted:" << search_config_.size();
        search_config_index_ = 0;
      }
      VLOG(2) << "Obtaining default configuration for conditional "
              << conditional_index << "\n";
      SetDefaultMoveConfig();
      VLOG(2) << "Done obtaining default configuration\n";
      conditional_index++;
    }
    int branch_count = conditional->branch_count();
    // Check for shared conditional computations.
    bool conditional_is_shared = false;
    for (int i = 0; i < branch_count; ++i) {
      HloComputation* branch_i = conditional->branch_computation(i);
      if (conditional_computations[branch_i] > 0) {
        conditional_is_shared = true;
        break;
      }
    }

    // Boundaries to move out of or into the branches.
    std::vector<std::vector<Boundary>> to_move_out, to_move_in;
    std::vector<std::vector<Boundary>> new_boundaries_for_moveout;
    std::vector<std::vector<Boundary>> new_boundaries_for_movein;
    // Number of times each instruction has been visited for moving.
    absl::flat_hash_map<HloInstruction*, int> visited_count;
    int benefit_move_out = 0, benefit_move_in = 0;
    Decision::Direction final_d = Decision::Direction::kNoChange;
    // The conditional is moved into a worklist as the seed (starting point).
    // The conditional will be expanded into multiple seeds (starting points),
    // its roots and its users, when it is visited by GroupConnectedBoundaries.
    // A NO_CHANGE decision will always be returned for the conditional itself,
    // so that the other seeding boundaries can be visited in turn.
    BoundaryVisitor visitor(conditional);
    VLOG(2) << "Analyzing conditional:" << conditional->ToString() << "\n";
    // Visit all the boundaries, collect the analysis results, and save all
    // the beneficial non-conflicting decisions. If two decisions conflict
    // with each other, save the more beneficial one.
    while (visitor.HasNextBoundary()) {
      std::vector<Boundary> to_move, next_boundary;
      Boundary boundary = visitor.PopNextBoundary();
      VLOG(2) << "Analyzing boundary:" << boundary.ToString() << "\n";
      auto d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary,
                                  visited_count);
      switch (d.GetDirection()) {
        case Decision::Direction::kMoveOutOfBranch:
          VLOG(2) << "Local Decision is move out of branch\n";
          to_move_out.push_back(to_move);
          new_boundaries_for_moveout.push_back(next_boundary);
          benefit_move_out += d.GetBenefit();
          if (benefit_move_out >= benefit_move_in) {
            final_d = Decision::Direction::kMoveOutOfBranch;
            VLOG(2) << "Current Decision is move out of branch ("
                    << to_move_out.size() << ")\n";
          } else {
            VLOG(2) << "Current Decision remains move into branch\n";
          }
          break;
        case Decision::Direction::kMoveIntoBranch:
          VLOG(2) << "Decision is move into branch\n";
          to_move_in.push_back(to_move);
          new_boundaries_for_movein.push_back(next_boundary);
          benefit_move_in += d.GetBenefit();
          if (benefit_move_out >= benefit_move_in) {
            VLOG(2) << "Current Decision remains move out of branch\n";
          } else {
            final_d = Decision::Direction::kMoveIntoBranch;
            VLOG(2) << "Current Decision is move into branch ("
                    << to_move_in.size() << ")\n";
          }
          break;
        case Decision::Direction::kNoChange:
          VLOG(2) << "Decision is no change\n";
          for (const Boundary& b : next_boundary) {
            visitor.AddToWorkList(b);
            VLOG(2) << "Adding new boundary to worklist:" << b.ToString()
                    << "\n";
          }
          break;
      }
    }
    // If a modification is to be made, the shared branches need to be cloned.
    if (final_d != Decision::Direction::kNoChange && conditional_is_shared) {
      for (int i = 0; i < branch_count; ++i) {
        HloComputation* branch_i = conditional->branch_computation(i);
        if (conditional_computations[branch_i] > 0) {
          // Cloning is absolutely needed if the computation is shared by
          // different branches, but the cloning can be potentially avoided
          // if the sharing is only among branches of the same conditional.
          // If cloning these branches causes a problem due to space issues,
          // a fix can pass a vector of unique branches to the actual
          // transformations, as an alternative representation of the
          // conditional branches to be modified. Right now we assume the
          // overhead of cloning is minimal since later stages of the compiler
          // inline all the computations anyway.
          HloComputation* clone_i =
              conditional->parent()->parent()->AddEmbeddedComputation(
                  branch_i->Clone("clone", &clone_context));
          conditional->set_branch_computation(i, clone_i);
          conditional_computations[branch_i]--;
          // Translate the analysis results so they refer to the cloned
          // instructions.
          auto update_boundary = [&](Boundary& boundary) {
            auto cloned_instr =
                clone_context.FindInstruction(boundary.operands()[i]);
            CHECK(cloned_instr != nullptr);
            VLOG(2) << "boundary before cloning:" << boundary.operands()[i]
                    << "\n";
            boundary.mutable_operands()[i] = cloned_instr;
            VLOG(2) << "boundary after cloning:" << boundary.operands()[i]
                    << "\n";
          };
          // Only boundaries to move out need to be updated.
          if (final_d == Decision::Direction::kMoveOutOfBranch) {
            for (int j = 0; j < to_move_out.size(); ++j) {
              std::vector<Boundary>& m = to_move_out[j];
              std::for_each(m.begin(), m.end(), update_boundary);
            }
            for (int j = 0; j < new_boundaries_for_moveout.size(); ++j) {
              std::vector<Boundary>& m = new_boundaries_for_moveout[j];
              std::for_each(m.begin(), m.end(), update_boundary);
            }
          }
        }
      }
      VLOG(2) << "Cloned branches as needed: " << conditional->ToString()
              << "\n";
    }
    // Apply only the transformations corresponding to the final decision;
    // any boundaries collected for the direction that was not chosen are
    // simply discarded.
    if (final_d == Decision::Direction::kMoveOutOfBranch) {
      CHECK(to_move_out.size() == new_boundaries_for_moveout.size());
      for (int i = 0; i < to_move_out.size(); ++i) {
        TF_ASSIGN_OR_RETURN(bool result,
                            MoveInstructionOut(conditional, to_move_out[i],
                                               new_boundaries_for_moveout[i]));
        changed |= result;
      }
      VLOG(2) << "Done moving out of branches " << to_move_out.size()
              << " times. \n";
      if (!ConsumeFuel("conditional_code_motion", [&] {
            return "Skipping conditional opt after the allowed limit reached "
                   "0.\n";
          })) {
        break;
      }
    } else if (final_d == Decision::Direction::kMoveIntoBranch) {
      CHECK(to_move_in.size() == new_boundaries_for_movein.size());
      for (int i = 0; i < to_move_in.size(); ++i) {
        TF_ASSIGN_OR_RETURN(bool result,
                            MoveInstructionIn(conditional, to_move_in[i],
                                              new_boundaries_for_movein[i]));
        changed |= result;
      }
      VLOG(2) << "Done moving into branches " << to_move_in.size()
              << " times. \n";
      if (!ConsumeFuel("conditional_code_motion", [&] {
            return "Skipping conditional opt after the allowed limit reached "
                   "0.\n";
          })) {
        break;
      }
    } else if (pursue_full_conditional_code_motion_ && !conditional_is_shared) {
      // Invoke special handling for convert rematerialization/hoisting.
      // We need to make sure no sharing is present in the branches, because no
      // cloning has been done by the earlier analysis.
      // TODO(b/165848866): extend the solution to handle cloning for the
      // special move.
      TF_ASSIGN_OR_RETURN(
          bool convert_result,
          ConvertSpecialMove(conditional, is_layout_sensitive_));
      if (convert_result) {
        VLOG(2) << "Done special moving of convert\n";
        if (!ConsumeFuel("conditional_code_motion", [&] {
              return "Skipping conditional opt after the allowed limit "
                     "reached 0.\n";
            })) {
          break;
        }
      }
      changed |= convert_result;
    }
  }
  if (changed) {
    HloPassPipeline subpipeline(
        "after_conditional_code_motion_after_convert_hoisting");
    VLOG(2) << "starting after motion passes: DCE\n";
    subpipeline.AddPass<HloDCE>();
    subpipeline.AddPass<TupleSimplifier>();
    subpipeline.AddPass<HloDCE>();
    TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module));
    cleanup_changed |= cleanup_changed_now;
  }
  if (cleanup_changed) {
    VLOG(2) << "subpipeline cleanup has modified the code\n";
  }
  return changed;
}

void ConditionalCodeMotion::SetDefaultMoveConfig() {
  VLOG(2) << "search_config_index = " << search_config_index_ << "\n";
  VLOG(2) << "search_config_ size = " << search_config_.size() << "\n";
  int64_t cur_search_config = (search_config_index_ < 0 ||
                               search_config_index_ >= search_config_.size())
                                  ? 0
                                  : search_config_[search_config_index_];
  enum class TuningOption {
    kDoNotTune = 0,
    kTuneTransformationDecision = 1,
    kTuneReuseModel = 2,
  };
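  // The sign of the current search configuration selects the tuning mode:
  // zero disables tuning, a positive value tunes the transformation
  // decisions, and a negative value tunes the reuse model.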
  TuningOption tuning_option =
      (cur_search_config == 0) ? TuningOption::kDoNotTune
      : (cur_search_config > 0) ? TuningOption::kTuneTransformationDecision
                                : TuningOption::kTuneReuseModel;

  auto row = HloOpcodeCount();
  auto col = row;
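  // Both tables are square over the opcode space: entry (opcode, j) pairs an
  // instruction's opcode with the opcode of a related instruction (the first
  // operand, in the case of the move table).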
  VLOG(2) << "Start setting default configuration\n";
  reuse_config_.clear();
  move_config_.clear();
  reuse_config_.reserve(row);
  move_config_.reserve(row);
  for (int64_t opcode = 0; opcode < row; ++opcode) {
    // Record the default reuse model for this opcode paired with every other
    // opcode.
    std::vector<int64_t> reuse_vec(col, 0);
    for (uint32_t j = 0; j < col; ++j) {
      reuse_vec[j] = ReusesCarriedBy(static_cast<HloOpcode>(opcode),
                                     static_cast<HloOpcode>(j));
    }
    reuse_config_.push_back(reuse_vec);
    std::vector<int64_t> move_vec;
    switch (tuning_option) {
      case TuningOption::kTuneTransformationDecision:
        // Tuning transformation decision --- start with all yes.
        // Only a single entry is needed if we don't consider operands of an op
        // when searching/tuning transformation decisions.
        move_vec.push_back(1);
        break;
      // Tune the ReusesCarriedBy results only.
      case TuningOption::kTuneReuseModel:
      case TuningOption::kDoNotTune:
        // No tuning --- use the default configuration.
        // Use the opcode of the first operand to configure the default.
        move_vec.reserve(col);
        for (uint32_t j = 0; j < col; ++j) {
          move_vec.push_back(WorthHoisting(static_cast<HloOpcode>(opcode),
                                           static_cast<HloOpcode>(j)));
        }
        break;
    }
    move_config_.push_back(move_vec);
  }
}

// The search configuration is specified using a string in the format
// 'config1;config2;...;config_n', where each config_i is in the format
// 'index,start,max,stride' (four integers separated by commas), which specify
// the index number of the conditional being configured, the index of the first
// transformation decision to flip for the conditional, the max number of
// decisions to flip, and how many decisions to skip in between the flips.
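// For example (illustrative), the config string "0,1,2,1" asks that for
// conditional 0, up to two transformation decisions be flipped, starting at
// decision 1 and skipping one decision between consecutive flips.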
void ConditionalCodeMotion::ParseSearchConfiguration(
    const std::string& search_config) {
  if (search_config.empty()) {
    return;
  }
  search_config_index_ = 0;
  std::vector<std::string> configs = absl::StrSplit(search_config, ';');
  for (const std::string& config : configs) {
    std::vector<std::string> specs = absl::StrSplit(config, ',');
    CHECK_EQ(specs.size(), 4);
    int64_t condition_index;
    CHECK(absl::SimpleAtoi(specs[0], &condition_index));
    auto& cur_config_entry = search_config_map_[condition_index];
    int64_t flip_start, max_flip, flip_stride;
    CHECK(absl::SimpleAtoi(specs[1], &flip_start));
    CHECK(absl::SimpleAtoi(specs[2], &max_flip));
    CHECK(absl::SimpleAtoi(specs[3], &flip_stride));
    int64_t cur_config = MakeSearchConfig(flip_start, max_flip, flip_stride);
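    // MakeSearchConfig combines the three flip parameters into a single
    // int64_t configuration entry.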
    cur_config_entry.push_back(cur_config);
    VLOG(2) << "Setting search config " << condition_index << "->" << cur_config
            << "\n";
  }
}

}  // namespace conditional_opt

}  // namespace xla
