/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/conditional_simplifier.h"

#include <iterator>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

namespace {

// A computation that contains only parameter, tuple, and get-tuple-element
// instructions, and whose root shape contains an array, is considered empty.
bool ComputationIsEmptyWithArrayRoot(const HloComputation* computation) {
  bool empty_operations = absl::c_all_of(
      computation->MakeInstructionPostOrder(), [](const HloInstruction* inst) {
        return inst->opcode() == HloOpcode::kTuple ||
               inst->opcode() == HloOpcode::kGetTupleElement ||
               inst->opcode() == HloOpcode::kParameter;
      });
  bool contains_array = false;
  ShapeUtil::ForEachSubshape(computation->root_instruction()->shape(),
                             [&](const Shape& shape, const ShapeIndex& index) {
                               if (shape.IsArray()) {
                                 contains_array = true;
                               }
                             });
  return empty_operations && contains_array;
}

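// Narrows the parameter tuple of `computation` to only the elements that are
// actually read through get-tuple-element, and rewrites every conditional in
// `calling_conditionals` to pass a correspondingly narrowed operand tuple.
// Returns true if anything changed. For example (sketch): if the parameter is
// (f32[10], f32[20]) but only index 1 is read, the cloned branch takes
// (f32[20]) and each caller passes tuple(get-tuple-element(operand, 1)).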
StatusOr<bool> TryRemoveUnusedConditionalOperands(
    HloComputation* computation,
    const absl::flat_hash_set<HloInstruction*>& calling_conditionals) {
  HloInstruction* param = computation->parameter_instruction(0);
  // Do not remove anything if the parameter is the root instruction.
  if (param == computation->root_instruction()) {
    return false;
  }
  // There is nothing to be removed for non-tuple operands.
  if (!param->shape().IsTuple()) {
    return false;
  }
  std::set<int64_t> tuple_indices_to_keep;
  for (HloInstruction* user : param->users()) {
    // If the user is not a get-tuple-element, assume it is unsafe to remove
    // elements from the tuple.
    if (user->opcode() != HloOpcode::kGetTupleElement) {
      return false;
    }
    tuple_indices_to_keep.insert(user->tuple_index());
  }
  // If all tuple elements are used in this conditional branch, there is
  // nothing to be removed.
  int64_t old_tuple_element_count =
      ShapeUtil::TupleElementCount(param->shape());
  if (tuple_indices_to_keep.size() == old_tuple_element_count) {
    return false;
  }

  // Create a new tuple shape based on the indices actually used by this
  // computation branch.
  std::vector<const Shape*> new_tuple_shapes;
  new_tuple_shapes.reserve(tuple_indices_to_keep.size());
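  // map[old_index] holds the new index of each kept tuple element, or -1 if
  // the element is dropped; it is used below to re-point the branch's
  // get-tuple-element users.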
  std::vector<int64_t> map(old_tuple_element_count, -1);
  for (int64_t i : tuple_indices_to_keep) {
    map[i] = new_tuple_shapes.size();
    new_tuple_shapes.push_back(&param->shape().tuple_shapes(i));
  }
  Shape tuple_shape = ShapeUtil::MakeTupleShapeWithPtrs(new_tuple_shapes);
  // Clone the computation in case it is called by another non-conditional
  // instruction.
  HloComputation* new_computation =
      computation->parent()->AddEmbeddedComputation(computation->Clone());
  param = new_computation->parameter_instruction(0);
  // Reset the parameter shape of the computation.
  *param->mutable_shape() = tuple_shape;

  // Reroute the GTE instructions to new tuple indices.
  for (HloInstruction* user : param->users()) {
    user->set_tuple_index(map[user->tuple_index()]);
  }

  // Adjust the operand shape of all calling conditionals.
  for (HloInstruction* conditional : calling_conditionals) {
    // Avoid dealing with sharding.
    if (conditional->has_sharding()) {
      continue;
    }
    for (int64_t branch = 0; branch < conditional->branch_count(); ++branch) {
      if (conditional->branch_computation(branch) != computation) {
        continue;
      }
      conditional->set_branch_computation(branch, new_computation);
      const Shape& old_shape = conditional->operand(branch + 1)->shape();

      // Reroute the operand tuple through a tuple of gte instructions of the
      // original operand tuple.
      std::vector<HloInstruction*> new_tuple_operands;
      new_tuple_operands.reserve(tuple_indices_to_keep.size());
      for (int64_t i : tuple_indices_to_keep) {
        new_tuple_operands.push_back(conditional->parent()->AddInstruction(
            HloInstruction::CreateGetTupleElement(
                old_shape.tuple_shapes(i),
                conditional->mutable_operand(branch + 1), i)));
      }
      HloInstruction* new_tuple = conditional->parent()->AddInstruction(
          HloInstruction::CreateTuple(new_tuple_operands));
      TF_RETURN_IF_ERROR(
          conditional->ReplaceOperandWithDifferentShape(branch + 1, new_tuple));
      CHECK(ShapeUtil::Compatible(conditional->operand(branch + 1)->shape(),
                                  conditional->branch_computation(branch)
                                      ->parameter_instruction(0)
                                      ->shape()));
      CHECK(ShapeUtil::Compatible(
          conditional->shape(),
          conditional->branch_computation(branch)->root_instruction()->shape()))
          << conditional->branch_computation(branch)->ToString();
    }
  }
  return true;
}

// Replaces the roots of all branches with an empty tuple if the conditional op
// has no users. Returns true if anything is changed.
bool ReplaceRootWithEmptyTupleIfNoUsers(HloInstruction* conditional_op) {
  const Shape empty_tuple = ShapeUtil::MakeTupleShape({});
  if (conditional_op->user_count() == 0 &&
      conditional_op != conditional_op->parent()->root_instruction() &&
      !ShapeUtil::Compatible(empty_tuple, conditional_op->shape())) {
    for (int64_t branch_id = 0; branch_id < conditional_op->branch_count();
         ++branch_id) {
      auto branch_computation =
          conditional_op->GetModule()->AddEmbeddedComputation(
              conditional_op->branch_computation(branch_id)->Clone());
      conditional_op->set_branch_computation(branch_id, branch_computation);
      auto new_empty_root =
          branch_computation->AddInstruction(HloInstruction::CreateTuple({}));
      branch_computation->set_root_instruction(new_empty_root,
                                               /*accept_different_shape=*/true);
    }
    *conditional_op->mutable_shape() = empty_tuple;
    return true;
  }
  return false;
}

// Removes all unused elements from the result tuple. Returns true if anything
// is changed.
//
// Computes and only keeps the subset of result tuple indices that are actually
// being used. This simplification frees up some data dependencies in the
// branches' sub-computations and enables further optimizations.
//
// *) If some user is not a get-tuple-element (e.g. the kWhile below), the
//    whole tuple is considered used, and nothing is removed:
//
//        kTuple-result
//              |
//              |
//           kWhile
//
// *) Only index=0 is used, so change (f32[10,10], f32[20,20]) to (f32[10,10])
//    and drop f32[20,20]:
//
//        kTuple-result (f32[10,10], f32[20,20])
//              |
//              |
//        get-tuple-element, index=0
//
bool RemoveUnusedTupleElements(HloInstruction* conditional_op) {
  if (conditional_op->user_count() == 0 ||
      conditional_op == conditional_op->parent()->root_instruction() ||
      !conditional_op->shape().IsTuple()) {
    VLOG(3) << "Skip RemoveUnusedTupleElements because the conditional has no "
               "users, is the root instruction, or is not tuple-shaped:\n"
            << conditional_op->ToShortString();
    return false;
  }

  const int old_tuple_shapes_size = conditional_op->shape().tuple_shapes_size();

  // Select indices that are actually used by some GTE instructions.
  std::vector<bool> used_indices(old_tuple_shapes_size, false);
  for (const HloInstruction* user : conditional_op->users()) {
    // We only deal with the case where all users are GTE instructions.
    if (user->opcode() != HloOpcode::kGetTupleElement) {
      VLOG(3) << "Skip RemoveUnusedTupleElements due to non-GTE user:\n"
              << user->ToShortString();
      return false;
    }
    used_indices[user->tuple_index()] = true;
  }

  const int new_tuple_shapes_size =
      std::count(used_indices.begin(), used_indices.end(), true);
  if (new_tuple_shapes_size == old_tuple_shapes_size) {
    VLOG(3) << "Skip RemoveUnusedTupleElements because every index is in use.";
    return false;
  }

  // Compute the old-to-new and new-to-old index mappings.
  absl::flat_hash_map<int, int> new_to_old_mapping, old_to_new_mapping;
  auto old_iter = used_indices.begin();
  for (int new_index = 0; new_index < new_tuple_shapes_size; ++new_index) {
    old_iter = std::find(old_iter, used_indices.end(), true);
    const int old_index = std::distance(used_indices.begin(), old_iter);
    new_to_old_mapping[new_index] = old_index;
    old_to_new_mapping[old_index] = new_index;
    ++old_iter;
  }

  // Create the new tuple shape, keeping only the active indices.
  const Shape old_shape = conditional_op->shape();
  std::vector<const Shape*> new_tuple_shapes;
  new_tuple_shapes.reserve(new_tuple_shapes_size);
  for (int new_index = 0; new_index < new_tuple_shapes_size; ++new_index) {
    new_tuple_shapes.push_back(
        &old_shape.tuple_shapes(new_to_old_mapping[new_index]));
  }
  const Shape new_shape = ShapeUtil::MakeTupleShapeWithPtrs(new_tuple_shapes);

  // Double-check that the old branch root shape is compatible (tuple-like).
  for (HloComputation* branch : conditional_op->branch_computations()) {
    const HloInstruction* root = branch->root_instruction();
    if (!root->shape().IsTuple() ||
        !ShapeUtil::Compatible(branch->root_instruction()->shape(),
                               old_shape)) {
      VLOG(3) << "Skip RemoveUnusedTupleElements because branch "
              << branch->name() << " has an incompatible root shape; expected "
              << old_shape.ToString() << ", but got "
              << root->shape().ToString() << "\n"
              << conditional_op->ToString();
      return false;
    }
  }

  // Replace all branches with the new tuple shape. Add 'gtes' for active
  // indices and create a new root gathering them.
  //
  //   non-kTuple-root
  //       |      |
  //      gte    gte
  //        \    /
  //        new_root
  for (int branch_id = 0; branch_id < conditional_op->branch_count();
       ++branch_id) {
    HloComputation* old_branch = conditional_op->branch_computation(branch_id);
    HloComputation* cloned_branch =
        conditional_op->GetModule()->AddEmbeddedComputation(
            old_branch->Clone());
    conditional_op->set_branch_computation(branch_id, cloned_branch);

    HloInstruction* old_root = cloned_branch->root_instruction();
    std::vector<HloInstruction*> new_tuple_root_operands;
    for (int old_index = 0; old_index < old_tuple_shapes_size; ++old_index) {
      if (used_indices[old_index]) {
        new_tuple_root_operands.push_back(
            cloned_branch->AddInstruction(HloInstruction::CreateGetTupleElement(
                old_shape.tuple_shapes(old_index), old_root, old_index)));
      }
    }
    HloInstruction* new_tuple_root = cloned_branch->AddInstruction(
        HloInstruction::CreateTuple(new_tuple_root_operands));
    cloned_branch->set_root_instruction(new_tuple_root,
                                        /*accept_different_shape=*/true);
  }

  // Replace the conditional instruction itself.
  *conditional_op->mutable_shape() = new_shape;

  // Reroute all user GTE instructions to new tuple indices.
  for (HloInstruction* user : conditional_op->users()) {
    const int old_index = user->tuple_index();
    const int new_index = old_to_new_mapping[old_index];
    user->set_tuple_index(new_index);
  }
  return true;
}

// Merges duplicate (identical) elements in the result tuple.
//
// Two tuple elements (indices) are duplicates if they return an identical
// value (from the same HloInstruction source) in every branch. In other words,
// if replacing the j-th tuple index with the i-th leaves the program
// invariant, then the i-th and j-th indices are identical and we can safely
// redirect all GTEs of the j-th index (users of this conditional instruction)
// to the i-th index.
//
// Afterwards, any unused j-th tuple index will be removed by
// RemoveUnusedTupleElements and the size of the tuple shape will be reduced.
// E.g.
//
// Before:
//       gte          add
//      /   \        /   \
//      |   |        |   |
//     on_true      on_false
//    (f32, f32)   (f32, f32)
//         |            |
//          \          /
//          conditional
//          (f32, f32)
//            |    |
//           gte  gte
//            \    /
//            tuple
//          (f32, f32)
//
// After:
//       gte          add
//        |            |
//     on_true      on_false
//      (f32)        (f32)
//         |            |
//          \          /
//          conditional
//            (f32)
//              |
//             gte
//            |   \
//            |    |
//            tuple
//          (f32, f32)
bool MergeDuplicateTupleElements(HloInstruction* conditional) {
  if (conditional->user_count() == 0 ||
      conditional == conditional->parent()->root_instruction() ||
      !conditional->shape().IsTuple()) {
    VLOG(3) << "Skip MergeDuplicateTupleElements because the conditional has "
               "no users, is the root instruction, or is not tuple-shaped:\n"
            << conditional->ToShortString();
    return false;
  }

  for (const HloInstruction* user : conditional->users()) {
    if (user->opcode() != HloOpcode::kGetTupleElement) {
      VLOG(3) << "Skip MergeDuplicateTupleElements because not all users are "
                 "kGetTupleElement:\n"
              << conditional->ToShortString();
      return false;
    }
  }

  for (const HloComputation* branch : conditional->branch_computations()) {
    if (branch->root_instruction()->opcode() != HloOpcode::kTuple) {
      VLOG(3) << "Skip MergeDuplicateTupleElements because not all branch "
                 "roots are kTuple:\n"
              << conditional->ToShortString();
      return false;
    }
  }

  // For example,
  //
  //  tuple index    |    0      1      2
  //  ------------------------------------------
  //  branch #0 root:  tuple(gte-0, add-0, add-0)
  //  branch #1 root:  tuple(rng-1, add-1, add-1)
  //  branch #2 root:  tuple(add-2, add-2, add-2)
  //
  //  vectorize(0) will be [gte-0, rng-1, add-2]
  //  vectorize(1) will be [add-0, add-1, add-2]
  //  vectorize(2) will be [add-0, add-1, add-2]
  //
  //  In this case, vectorize(1) and vectorize(2) are equal, so indices 1 and 2
  //  are identical.
  auto vectorize_branches_root_tuple_ith_operand = [conditional](int64_t i) {
    std::vector<const HloInstruction*> operands;
    absl::c_transform(conditional->branch_computations(),
                      std::back_inserter(operands),
                      [i](const HloComputation* branch) {
                        return branch->root_instruction()->operand(i);
                      });
    return operands;
  };

  auto replace_root_user_gte_jth_with_gte_ith = [conditional](int64_t i,
                                                              int64_t j) {
    bool changed = false;
    for (HloInstruction* user : conditional->users()) {
      if (user->tuple_index() == j) {
        user->set_tuple_index(i);
        changed |= true;
      }
    }
    return changed;
  };

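  // Walk the tuple indices in order: the first index that produces a given
  // per-branch operand vector is kept as the canonical one, and any later
  // index producing the same vector has its GTE users redirected to it.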
  bool changed = false;
  absl::flat_hash_map<std::vector<const HloInstruction*>, int64_t>
      index_collision_table;
  for (int i = 0; i < conditional->shape().tuple_shapes_size(); ++i) {
    const std::vector<const HloInstruction*> ith_operands_vector =
        vectorize_branches_root_tuple_ith_operand(i);
    const auto emplace_res =
        index_collision_table.emplace(ith_operands_vector, i);
    if (!emplace_res.second) {
      changed |=
          replace_root_user_gte_jth_with_gte_ith(emplace_res.first->second, i);
    }
  }
  return changed;
}
}  // namespace

// Tries to replace a conditional with a call operation of the corresponding
// computation. If the given conditional has a constant branch_index, tries to
// replace it with a call to its corresponding branch computation and then
// inline that computation.
//
// Returns true if it made a change to the graph.
StatusOr<bool> ConditionalSimplifier::TryRemoveConditional(
    HloInstruction* conditional) {
  CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
  // Do not remove conditionals that contain side-effecting instructions or
  // have control predecessors/successors in either true/false computation.
  if (!conditional->parent()->IsSafelyRemovable(conditional) ||
      conditional->HasSideEffect()) {
    VLOG(2) << "Not attempting to remove conditional as it is not removable "
               "or has side effects: "
            << conditional->ToShortString();
    return false;
  }

  // We can always inline a 1-branch conditional due to default branch
  // fallback.
  auto computation = conditional->parent();
  auto create_call = [&](int64_t branch) {
    auto call = computation->AddInstruction(HloInstruction::CreateCall(
        conditional->shape(), {conditional->mutable_operand(1 + branch)},
        conditional->branch_computation(branch)));
    conditional->SetupDerivedInstruction(call);
    return call;
  };

  if (conditional->branch_count() == 1) {
    HloInstruction* call_op = create_call(0);
    TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
    TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());
    return true;
  }

  if (conditional->operand(0)->opcode() == HloOpcode::kConstant) {
    int branch_index = 0;
    if (conditional->operand(0)->shape().element_type() == PRED) {
      branch_index = conditional->operand(0)->literal().Get<bool>({}) ? 0 : 1;
    } else {
      branch_index = conditional->operand(0)->literal().Get<int32_t>({});
      if (branch_index < 0 || branch_index >= conditional->branch_count()) {
        branch_index = conditional->branch_count() - 1;
      }
    }
    HloInstruction* call_op = create_call(branch_index);
    TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
    TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());

    return true;
  }

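  // Classifies instructions that make speculative execution of a branch
  // unattractive: everything except the cheap data-movement opcodes listed
  // below and elementwise ops counts as expensive.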
  auto instruction_is_expensive = [](const HloInstruction* hlo) {
    switch (hlo->opcode()) {
      case HloOpcode::kBroadcast:
      case HloOpcode::kConcatenate:
      case HloOpcode::kDynamicSlice:
      case HloOpcode::kGetTupleElement:
      case HloOpcode::kReduce:
      case HloOpcode::kReshape:
      case HloOpcode::kPad:
      case HloOpcode::kParameter:
      case HloOpcode::kSlice:
      case HloOpcode::kTuple:
        return false;
      default:
        return !hlo->IsElementwise();
    }
  };

  if (conditional->branch_count() != 2 ||
      conditional->operand(0)->shape().element_type() != PRED ||
      absl::c_any_of(conditional->branch_computation(0)->instructions(),
                     instruction_is_expensive) ||
      absl::c_any_of(conditional->branch_computation(1)->instructions(),
                     instruction_is_expensive)) {
    VLOG(2)
        << "Not attempting to remove conditional as its branch_index is not a "
           "compile-time constant or its branches contain expensive "
           "instructions: "
        << conditional->ToShortString();
    return false;
  }

  bool branch_empty =
      ComputationIsEmptyWithArrayRoot(conditional->branch_computation(0)) ||
      ComputationIsEmptyWithArrayRoot(conditional->branch_computation(1));
  // An empty branch is faster to execute than a select.
  if (branch_empty) {
    return false;
  }

  HloInstruction* true_call_op = create_call(0);
  HloInstruction* false_call_op = create_call(1);
  auto condition_broadcast = [&](const Shape& shape) {
    if (ShapeUtil::IsScalar(shape)) {
      return conditional->mutable_operand(0);
    }
    Shape new_shape = ShapeUtil::ChangeElementType(shape, PRED);
    UpdateLayout(&new_shape);
    return computation->AddInstruction(HloInstruction::CreateBroadcast(
        new_shape, conditional->mutable_operand(0), {}));
  };

  auto gte = [&](HloInstruction* hlo, int64_t i) {
    return computation->AddInstruction(HloInstruction::CreateGetTupleElement(
        hlo->shape().tuple_shapes(i), hlo, i));
  };

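  // Recursively builds the replacement value for the conditional: array shapes
  // become a kSelect between the two call results, token shapes become an
  // after-all, and tuple shapes recurse element by element before being
  // reassembled into a tuple.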
  std::function<HloInstruction*(HloInstruction*, HloInstruction*)> select =
      [&](HloInstruction* t, HloInstruction* f) {
        if (f->shape().IsToken()) {
          return computation->AddInstruction(
              HloInstruction::CreateAfterAll({t, f}));
        }
        if (f->shape().IsArray()) {
          return computation->AddInstruction(HloInstruction::CreateTernary(
              f->shape(), HloOpcode::kSelect, condition_broadcast(f->shape()),
              t, f));
        }
        std::vector<HloInstruction*> selects;
        const int64_t tuple_element_count =
            ShapeUtil::TupleElementCount(f->shape());
        selects.reserve(tuple_element_count);
        for (int64_t i = 0; i < tuple_element_count; ++i) {
          selects.push_back(select(gte(t, i), gte(f, i)));
        }
        return computation->AddInstruction(
            HloInstruction::CreateTuple(selects));
      };

  TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
      conditional, select(true_call_op, false_call_op)));

  TF_RETURN_IF_ERROR(CallInliner::Inline(false_call_op).status());
  TF_RETURN_IF_ERROR(CallInliner::Inline(true_call_op).status());
  return true;
}

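// Returns true if `computation`, or any computation it transitively calls,
// contains an instruction that carries a channel id (e.g. send/recv).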
static bool ComputationCallsChannelInstructions(
    const HloComputation& computation) {
  std::vector<const HloComputation*> worklist = {&computation};
  while (!worklist.empty()) {
    const HloComputation* work = worklist.back();
    worklist.pop_back();
    for (const HloInstruction* instruction : work->instructions()) {
      if (DynCast<HloChannelInstruction>(instruction) != nullptr) {
        return true;
      }
      worklist.insert(worklist.end(),
                      instruction->called_computations().begin(),
                      instruction->called_computations().end());
    }
  }
  return false;
}

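// Returns true if any computation called by `instruction` (transitively)
// contains channel instructions.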
static bool InstructionCallsChannelInstructions(
    const HloInstruction& instruction) {
  for (const HloComputation* called_computation :
       instruction.called_computations()) {
    if (ComputationCallsChannelInstructions(*called_computation)) {
      return true;
    }
  }
  return false;
}

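// Entry point of the pass: applies the tuple simplifications and conditional
// removal to every eligible kConditional in `module`, then prunes unused
// operands of the remaining conditionals' branch computations.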
StatusOr<bool> ConditionalSimplifier::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  XLA_VLOG_LINES(
      3, "ConditionalSimplifier::Run(), before:\n" + module->ToString());
  bool changed = false;

  // Gather all the conditional ops in our module. We do this ahead of time so
  // we don't have to worry about mutating the lists of computations or
  // instructions as we iterate.
  std::vector<HloInstruction*> conditional_ops;
  for (auto* comp : module->computations(execution_threads)) {
    for (auto* instr : comp->MakeInstructionPostOrder()) {
      if (instr->opcode() == HloOpcode::kConditional) {
        // Verifier wants a single send/recv with a given channel. This pass
        // clones computations which can result in that getting violated.
        if (InstructionCallsChannelInstructions(*instr)) {
          continue;
        }
        if (instr->has_sharding()) {
          // The code below doesn't handle sharding properly.
          continue;
        }
        conditional_ops.push_back(instr);
      }
    }
  }

  absl::flat_hash_set<HloInstruction*> removed_conditionals;
  for (HloInstruction* conditional_op : conditional_ops) {
    changed |= MergeDuplicateTupleElements(conditional_op);
    changed |= RemoveUnusedTupleElements(conditional_op);
    changed |= ReplaceRootWithEmptyTupleIfNoUsers(conditional_op);
    TF_ASSIGN_OR_RETURN(bool result, TryRemoveConditional(conditional_op));
    if (result) {
      removed_conditionals.insert(conditional_op);
      changed = true;
    }
  }
  // Try to remove unused conditional operands from branch computations. We
  // need to be careful to adjust *all* calling conditional ops if we do that,
  // so let's collect them first.
  absl::flat_hash_map<HloComputation*, absl::flat_hash_set<HloInstruction*>>
      calling_conditionals;
  // Keys of calling_conditionals, kept to get a deterministic ordering.
  std::vector<HloComputation*> calling_computationals_vector;
  for (HloInstruction* conditional : conditional_ops) {
    if (removed_conditionals.contains(conditional)) {
      continue;
    }

    for (int64_t branch = 0; branch < conditional->branch_count(); ++branch) {
      auto* branch_comp = conditional->branch_computation(branch);
      if (!calling_conditionals.contains(branch_comp)) {
        calling_computationals_vector.push_back(branch_comp);
      }
      calling_conditionals[branch_comp].insert(conditional);
    }
  }

  for (auto* comp : calling_computationals_vector) {
    auto entry = calling_conditionals.find(comp);
    CHECK(entry != calling_conditionals.end());
    TF_ASSIGN_OR_RETURN(bool result, TryRemoveUnusedConditionalOperands(
                                         entry->first, entry->second));
    changed |= result;
  }

  XLA_VLOG_LINES(3,
                 "ConditionalSimplifier::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla