/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/conditional_simplifier.h"

#include <iterator>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

namespace {

// A computation is considered empty if it contains only parameter, tuple, and
// get-tuple-element instructions and its root shape contains an array.
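//
// For example (illustrative HLO; names are hypothetical), this branch is
// empty with an array root, since it merely repackages its parameter:
//
//   branch {
//     p = (f32[10], f32[20]) parameter(0)
//     x = f32[10] get-tuple-element(p), index=0
//     ROOT t = (f32[10]) tuple(x)
//   }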
bool ComputationIsEmptyWithArrayRoot(const HloComputation* computation) {
  bool empty_operations = absl::c_all_of(
      computation->MakeInstructionPostOrder(), [](const HloInstruction* inst) {
        return inst->opcode() == HloOpcode::kTuple ||
               inst->opcode() == HloOpcode::kGetTupleElement ||
               inst->opcode() == HloOpcode::kParameter;
      });
  bool contains_array = false;
  ShapeUtil::ForEachSubshape(computation->root_instruction()->shape(),
                             [&](const Shape& shape, const ShapeIndex& index) {
                               if (shape.IsArray()) {
                                 contains_array = true;
                               }
                             });
  return empty_operations && contains_array;
}

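// Removes tuple elements of the branch parameter that are not read by any
// get-tuple-element in `computation`, rewiring every conditional in
// `calling_conditionals` that calls this computation as a branch. Returns
// true if anything changed.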
StatusOr<bool> TryRemoveUnusedConditionalOperands(
    HloComputation* computation,
    const absl::flat_hash_set<HloInstruction*>& calling_conditionals) {
  HloInstruction* param = computation->parameter_instruction(0);
  // Do not remove from the root instruction.
  if (param == computation->root_instruction()) {
    return false;
  }
  // There is nothing to be removed for non-tuple operands.
  if (!param->shape().IsTuple()) {
    return false;
  }
  std::set<int64_t> tuple_indices_to_keep;
  for (HloInstruction* user : param->users()) {
    // If the user is not a get tuple element, assume it is unsafe to remove
    // elements from the tuple.
    if (user->opcode() != HloOpcode::kGetTupleElement) {
      return false;
    }
    tuple_indices_to_keep.insert(user->tuple_index());
  }
  // If all tuple elements are used in this conditional branch, there is nothing
  // to be removed.
  int64_t old_tuple_element_count =
      ShapeUtil::TupleElementCount(param->shape());
  if (tuple_indices_to_keep.size() == old_tuple_element_count) {
    return false;
  }

  // Create a new tuple shape based on the indices actually used by this
  // computation branch.
  std::vector<const Shape*> new_tuple_shapes;
  new_tuple_shapes.reserve(tuple_indices_to_keep.size());
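  // map[i] is the new tuple index for old index i, or -1 if element i is
  // dropped.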
  std::vector<int64_t> map(old_tuple_element_count, -1);
  for (int64_t i : tuple_indices_to_keep) {
    map[i] = new_tuple_shapes.size();
    new_tuple_shapes.push_back(&param->shape().tuple_shapes(i));
  }
  Shape tuple_shape = ShapeUtil::MakeTupleShapeWithPtrs(new_tuple_shapes);
  // Clone the computation in case it is called by another non-conditional
  // instruction.
  HloComputation* new_computation =
      computation->parent()->AddEmbeddedComputation(computation->Clone());
  param = new_computation->parameter_instruction(0);
  // Reset the parameter shape of the computation.
  *param->mutable_shape() = tuple_shape;

  // Reroute the GTE instructions to new tuple indices.
  for (HloInstruction* user : param->users()) {
    user->set_tuple_index(map[user->tuple_index()]);
  }

  // Adjust the operand shape of all calling conditionals.
  for (HloInstruction* conditional : calling_conditionals) {
    // Avoid dealing with sharding.
    if (conditional->has_sharding()) {
      continue;
    }
    for (int64_t branch = 0; branch < conditional->branch_count(); ++branch) {
      if (conditional->branch_computation(branch) != computation) {
        continue;
      }
      conditional->set_branch_computation(branch, new_computation);
      const Shape& old_shape = conditional->operand(branch + 1)->shape();

      // Reroute the operand tuple through a tuple of gte instructions of the
      // original operand tuple.
      std::vector<HloInstruction*> new_tuple_operands;
      new_tuple_operands.reserve(tuple_indices_to_keep.size());
      for (int64_t i : tuple_indices_to_keep) {
        new_tuple_operands.push_back(conditional->parent()->AddInstruction(
            HloInstruction::CreateGetTupleElement(
                old_shape.tuple_shapes(i),
                conditional->mutable_operand(branch + 1), i)));
      }
      HloInstruction* new_tuple = conditional->parent()->AddInstruction(
          HloInstruction::CreateTuple(new_tuple_operands));
      TF_RETURN_IF_ERROR(
          conditional->ReplaceOperandWithDifferentShape(branch + 1, new_tuple));
      CHECK(ShapeUtil::Compatible(conditional->operand(branch + 1)->shape(),
                                  conditional->branch_computation(branch)
                                      ->parameter_instruction(0)
                                      ->shape()));
      CHECK(ShapeUtil::Compatible(
          conditional->shape(),
          conditional->branch_computation(branch)->root_instruction()->shape()))
          << conditional->branch_computation(branch)->ToString();
    }
  }
  return true;
}

// Replaces the roots of all branches with an empty tuple if the conditional op
// has no users. Returns true if anything is changed.
bool ReplaceRootWithEmptyTupleIfNoUsers(HloInstruction* conditional_op) {
  const Shape empty_tuple = ShapeUtil::MakeTupleShape({});
  if (conditional_op->user_count() == 0 &&
      conditional_op != conditional_op->parent()->root_instruction() &&
      !ShapeUtil::Compatible(empty_tuple, conditional_op->shape())) {
    for (int64_t branch_id = 0; branch_id < conditional_op->branch_count();
         ++branch_id) {
      auto branch_computation =
          conditional_op->GetModule()->AddEmbeddedComputation(
              conditional_op->branch_computation(branch_id)->Clone());
      conditional_op->set_branch_computation(branch_id, branch_computation);
      auto new_empty_root =
          branch_computation->AddInstruction(HloInstruction::CreateTuple({}));
      branch_computation->set_root_instruction(new_empty_root,
                                               /*accept_different_shape=*/true);
    }
    *conditional_op->mutable_shape() = empty_tuple;
    return true;
  }
  return false;
}

// Removes all unused elements from result tuple. Returns true if anything is
// changed.
//
// Computes and only keeps a subset of result tuple indices which are actually
// being used. This simplification frees up some data-dependencies in branches'
// sub-computations and enables further optimizations.
//
// *) If any user of the conditional is not a get-tuple-element (e.g. a
//    kWhile), the whole tuple is considered used and nothing is removed:
//
//        kTuple-result
//              |
//              |
//           kWhile
//
// *) Only index=0 is used, so change (f32[10,10], f32[20,20]) to (f32[10,10])
//    and drop f32[20,20].
//
//        kTuple-result (f32[10,10], f32[20,20])
//              |
//              |
//        get-tuple-element, index=0
//
bool RemoveUnusedTupleElements(HloInstruction* conditional_op) {
  if (conditional_op->user_count() == 0 ||
      conditional_op == conditional_op->parent()->root_instruction() ||
      !conditional_op->shape().IsTuple()) {
    VLOG(3) << "Skip RemoveUnusedTupleElements due to no users, root "
               "position, or non-tuple result:\n"
            << conditional_op->ToShortString();
    return false;
  }

  const int old_tuple_shapes_size = conditional_op->shape().tuple_shapes_size();

  // Select indices that are actually used by some GTE instructions.
  std::vector<bool> used_indices(old_tuple_shapes_size, false);
  for (const HloInstruction* user : conditional_op->users()) {
    // We only deal with the case where all users are GTE instructions.
    if (user->opcode() != HloOpcode::kGetTupleElement) {
      VLOG(3) << "Skip RemoveUnusedTupleElements due to non-GTE user:\n"
              << user->ToShortString();
      return false;
    }
    used_indices[user->tuple_index()] = true;
  }

  const int new_tuple_shapes_size =
      std::count(used_indices.begin(), used_indices.end(), true);
  if (new_tuple_shapes_size == old_tuple_shapes_size) {
    VLOG(3) << "Skip RemoveUnusedTupleElements because every index is in use.";
    return false;
  }

  // Compute the old-to-new and new-to-old index mappings.
  absl::flat_hash_map<int, int> new_to_old_mapping, old_to_new_mapping;
  auto old_iter = used_indices.begin();
  for (int new_index = 0; new_index < new_tuple_shapes_size; ++new_index) {
    old_iter = std::find(old_iter, used_indices.end(), true);
    const int old_index = std::distance(used_indices.begin(), old_iter);
    new_to_old_mapping[new_index] = old_index;
    old_to_new_mapping[old_index] = new_index;
    ++old_iter;
  }
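  // For example, used_indices = [true, false, true] yields
  // new_to_old_mapping = {0: 0, 1: 2} and old_to_new_mapping = {0: 0, 2: 1}.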

  // Create new tuple shape, only keep active indices.
  const Shape old_shape = conditional_op->shape();
  std::vector<const Shape*> new_tuple_shapes;
  new_tuple_shapes.reserve(new_tuple_shapes_size);
  for (int new_index = 0; new_index < new_tuple_shapes_size; ++new_index) {
    new_tuple_shapes.push_back(
        &old_shape.tuple_shapes(new_to_old_mapping[new_index]));
  }
  const Shape new_shape = ShapeUtil::MakeTupleShapeWithPtrs(new_tuple_shapes);

  // Double-check that every branch root has a tuple shape compatible with the
  // old conditional shape.
  for (HloComputation* branch : conditional_op->branch_computations()) {
    const HloInstruction* root = branch->root_instruction();
    if (!root->shape().IsTuple() ||
        !ShapeUtil::Compatible(branch->root_instruction()->shape(),
                               old_shape)) {
      VLOG(3) << "Skip RemoveUnusedTupleElements because branch "
              << branch->name() << " has an incompatible root shape; expected "
              << old_shape.ToString() << ", but got "
              << root->shape().ToString() << "\n"
              << conditional_op->ToString();
      return false;
    }
270 
271   // Replace all branches with new tuple shape. Add 'gtes' for active indices
272   // and create a new root gathering them.
273   //
274   //  non-kTuple-root
275   //    |      |
276   //   gte   gte
277   //     \    /
278   //    new_root
279   for (int branch_id = 0; branch_id < conditional_op->branch_count();
280        ++branch_id) {
281     HloComputation* old_branch = conditional_op->branch_computation(branch_id);
282     HloComputation* cloned_branch =
283         conditional_op->GetModule()->AddEmbeddedComputation(
284             old_branch->Clone());
285     conditional_op->set_branch_computation(branch_id, cloned_branch);
286 
287     HloInstruction* old_root = cloned_branch->root_instruction();
288     std::vector<HloInstruction*> new_tuple_root_operands;
289     for (int old_index = 0; old_index < old_tuple_shapes_size; ++old_index) {
290       if (used_indices[old_index]) {
291         new_tuple_root_operands.push_back(
292             cloned_branch->AddInstruction(HloInstruction::CreateGetTupleElement(
293                 old_shape.tuple_shapes(old_index), old_root, old_index)));
294       }
295     }
296     HloInstruction* new_tuple_root = cloned_branch->AddInstruction(
297         HloInstruction::CreateTuple(new_tuple_root_operands));
298     cloned_branch->set_root_instruction(new_tuple_root,
299                                         /*accept_different_shape=*/true);
300   }
301 
302   // Replace the conditional instruction itself.
303   *conditional_op->mutable_shape() = new_shape;
304 
305   // Reroute all user GTE instructions to new tuple indices.
306   for (HloInstruction* user : conditional_op->users()) {
307     const int old_index = user->tuple_index();
308     const int new_index = old_to_new_mapping[old_index];
309     user->set_tuple_index(new_index);
310   }
311   return true;
312 }
313 
// Merges duplicate (identical) elements in the result tuple.
//
// Two tuple elements (indices) are duplicates if they return an identical
// value (from the same HloInstruction source) in every branch. In other
// words, if replacing the j-th tuple index with the i-th is an invariant,
// the i-th and j-th elements are identical and we can safely replace all GTEs
// of the j-th index (users of this conditional instruction) with GTEs of the
// i-th.
//
// Afterwards, any unused j-th tuple index will be removed by
// RemoveUnusedTupleElements and the size of the tuple shape will be reduced.
// E.g.
//
// Before:
//       gte          add
//      /   \        /   \
//      |   |        |   |
//     on_true      on_false
//    (f32, f32)   (f32, f32)
//         |           |
//          \         /
//          conditional
//          (f32, f32)
//            |    |
//           gte  gte
//            \    /
//            tuple
//          (f32, f32)
//
// After:
//       gte          add
//        |            |
//     on_true      on_false
//      (f32)        (f32)
//         |           |
//          \         /
//          conditional
//             (f32)
//               |
//              gte
//              |  \
//              |   |
//              tuple
//            (f32, f32)
bool MergeDuplicateTupleElements(HloInstruction* conditional) {
  if (conditional->user_count() == 0 ||
      conditional == conditional->parent()->root_instruction() ||
      !conditional->shape().IsTuple()) {
    VLOG(3) << "Skip MergeDuplicateTupleElements because the conditional has "
               "no users, is the root instruction, or is not tuple-shaped:\n"
            << conditional->ToShortString();
    return false;
  }

  for (const HloInstruction* user : conditional->users()) {
    if (user->opcode() != HloOpcode::kGetTupleElement) {
      VLOG(3) << "Skip MergeDuplicateTupleElements because not all users are "
                 "kGetTupleElement:\n"
              << conditional->ToShortString();
      return false;
    }
  }

  for (const HloComputation* branch : conditional->branch_computations()) {
    if (branch->root_instruction()->opcode() != HloOpcode::kTuple) {
      VLOG(3) << "Skip MergeDuplicateTupleElements because not all branch "
                 "roots are kTuple:\n"
              << conditional->ToShortString();
      return false;
    }
  }

  // For example,
  //
  //    tuple index   |         0      1      2
  //    ------------------------------------------
  //    branch #0 root: tuple(gte-0, add-0, add-0)
  //    branch #1 root: tuple(rng-1, add-1, add-1)
  //    branch #2 root: tuple(add-2, add-2, add-2)
  //
  // vectorize(0) will be [gte-0, rng-1, add-2]
  // vectorize(1) will be [add-0, add-1, add-2]
  // vectorize(2) will be [add-0, add-1, add-2]
  //
  // In this case, vectorize(1), vectorize(2) are equal and index 1, 2 are
  // identical.
  auto vectorize_branches_root_tuple_ith_operand = [conditional](int64_t i) {
    std::vector<const HloInstruction*> operands;
    absl::c_transform(conditional->branch_computations(),
                      std::back_inserter(operands),
                      [i](const HloComputation* branch) {
                        return branch->root_instruction()->operand(i);
                      });
    return operands;
  };

  auto replace_root_user_gte_jth_with_gte_ith = [conditional](int64_t i,
                                                              int64_t j) {
    bool changed = false;
    for (HloInstruction* user : conditional->users()) {
      if (user->tuple_index() == j) {
        user->set_tuple_index(i);
        changed |= true;
      }
    }
    return changed;
  };

  bool changed = false;
  absl::flat_hash_map<std::vector<const HloInstruction*>, int64_t>
      index_collision_table;
  for (int i = 0; i < conditional->shape().tuple_shapes_size(); ++i) {
    const std::vector<const HloInstruction*> ith_operands_vector =
        vectorize_branches_root_tuple_ith_operand(i);
    const auto emplace_res =
        index_collision_table.emplace(ith_operands_vector, i);
    if (!emplace_res.second) {
      changed |=
          replace_root_user_gte_jth_with_gte_ith(emplace_res.first->second, i);
    }
  }
  return changed;
}
}  // namespace

// Tries to replace a conditional with a call operation of the corresponding
// computation. If the given conditional has a constant branch_index, tries to
// replace it with a call to its corresponding branch computation and then
// inline that computation.
//
// Returns true if it made a change to the graph.
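//
// For example (illustrative pseudo-HLO; names are hypothetical), a
// conditional whose branch index is the constant true
//
//   c = conditional(constant(true), t_arg, f_arg),
//       branch_computations={on_true, on_false}
//
// is rewritten to call(t_arg), to_apply=on_true, which is then inlined.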
StatusOr<bool> ConditionalSimplifier::TryRemoveConditional(
    HloInstruction* conditional) {
  CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
  // Do not remove conditionals that contain side-effecting instructions or
  // have control predecessors/successors in either true/false computation.
  if (!conditional->parent()->IsSafelyRemovable(conditional) ||
      conditional->HasSideEffect()) {
    VLOG(2) << "Not attempting to remove conditional as it is not removable or "
               "has side effect: "
            << conditional->ToShortString();
    return false;
  }

  // We can always inline a 1-branch conditional due to default branch fallback.
  auto computation = conditional->parent();
  auto create_call = [&](int64_t branch) {
    auto call = computation->AddInstruction(HloInstruction::CreateCall(
        conditional->shape(), {conditional->mutable_operand(1 + branch)},
        conditional->branch_computation(branch)));
    conditional->SetupDerivedInstruction(call);
    return call;
  };

  if (conditional->branch_count() == 1) {
    HloInstruction* call_op = create_call(0);
    TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
    TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());
    return true;
  }

  if (conditional->operand(0)->opcode() == HloOpcode::kConstant) {
    int branch_index = 0;
    if (conditional->operand(0)->shape().element_type() == PRED) {
      branch_index = conditional->operand(0)->literal().Get<bool>({}) ? 0 : 1;
    } else {
      branch_index = conditional->operand(0)->literal().Get<int32_t>({});
      if (branch_index < 0 || branch_index >= conditional->branch_count()) {
        branch_index = conditional->branch_count() - 1;
      }
    }
    HloInstruction* call_op = create_call(branch_index);
    TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
    TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());

    return true;
  }

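  // Heuristic for whether both branches are cheap enough to execute
  // speculatively: the ops listed below and elementwise ops are treated as
  // inexpensive; everything else is expensive.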
  auto instruction_is_expensive = [](const HloInstruction* hlo) {
    switch (hlo->opcode()) {
      case HloOpcode::kBroadcast:
      case HloOpcode::kConcatenate:
      case HloOpcode::kDynamicSlice:
      case HloOpcode::kGetTupleElement:
      case HloOpcode::kReduce:
      case HloOpcode::kReshape:
      case HloOpcode::kPad:
      case HloOpcode::kParameter:
      case HloOpcode::kSlice:
      case HloOpcode::kTuple:
        return false;
      default:
        return !hlo->IsElementwise();
    }
  };

  if (conditional->branch_count() != 2 ||
      conditional->operand(0)->shape().element_type() != PRED ||
      absl::c_any_of(conditional->branch_computation(0)->instructions(),
                     instruction_is_expensive) ||
      absl::c_any_of(conditional->branch_computation(1)->instructions(),
                     instruction_is_expensive)) {
    VLOG(2)
        << "Not attempting to remove conditional as its branch_index is not a "
           "compile-time constant or its branches contain expensive "
           "instructions: "
        << conditional->ToShortString();
    return false;
  }

  bool branch_empty =
      ComputationIsEmptyWithArrayRoot(conditional->branch_computation(0)) ||
      ComputationIsEmptyWithArrayRoot(conditional->branch_computation(1));
  // An empty branch is faster to execute than a select.
  if (branch_empty) {
    return false;
  }

  HloInstruction* true_call_op = create_call(0);
  HloInstruction* false_call_op = create_call(1);
  auto condition_broadcast = [&](const Shape& shape) {
    if (ShapeUtil::IsScalar(shape)) {
      return conditional->mutable_operand(0);
    }
    Shape new_shape = ShapeUtil::ChangeElementType(shape, PRED);
    UpdateLayout(&new_shape);
    return computation->AddInstruction(HloInstruction::CreateBroadcast(
        new_shape, conditional->mutable_operand(0), {}));
  };

  auto gte = [&](HloInstruction* hlo, int64_t i) {
    return computation->AddInstruction(HloInstruction::CreateGetTupleElement(
        hlo->shape().tuple_shapes(i), hlo, i));
  };

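  // Recursively selects between the true and false results: tokens are joined
  // with an after-all, arrays with a kSelect (the predicate is broadcast to
  // non-scalar array shapes), and tuples element by element. For example
  // (illustrative), a result of shape (f32[2], s32[]) becomes:
  //
  //   tuple(select(pred, gte(t, 0), gte(f, 0)),
  //         select(pred, gte(t, 1), gte(f, 1)))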
  std::function<HloInstruction*(HloInstruction*, HloInstruction*)> select =
      [&](HloInstruction* t, HloInstruction* f) {
        if (f->shape().IsToken()) {
          return computation->AddInstruction(
              HloInstruction::CreateAfterAll({t, f}));
        }
        if (f->shape().IsArray()) {
          return computation->AddInstruction(HloInstruction::CreateTernary(
              f->shape(), HloOpcode::kSelect, condition_broadcast(f->shape()),
              t, f));
        }
        std::vector<HloInstruction*> selects;
        const int64_t tuple_element_count =
            ShapeUtil::TupleElementCount(f->shape());
        selects.reserve(tuple_element_count);
        for (int64_t i = 0; i < tuple_element_count; ++i) {
          selects.push_back(select(gte(t, i), gte(f, i)));
        }
        return computation->AddInstruction(
            HloInstruction::CreateTuple(selects));
      };

  TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
      conditional, select(true_call_op, false_call_op)));

  TF_RETURN_IF_ERROR(CallInliner::Inline(false_call_op).status());
  TF_RETURN_IF_ERROR(CallInliner::Inline(true_call_op).status());
  return true;
}

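// Returns true if `computation`, or any computation it transitively calls,
// contains a channel instruction (e.g. send/recv).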
static bool ComputationCallsChannelInstructions(
    const HloComputation& computation) {
  std::vector<const HloComputation*> worklist = {&computation};
  while (!worklist.empty()) {
    const HloComputation* work = worklist.back();
    worklist.pop_back();
    for (const HloInstruction* instruction : work->instructions()) {
      if (DynCast<HloChannelInstruction>(instruction) != nullptr) {
        return true;
      }
      worklist.insert(worklist.end(),
                      instruction->called_computations().begin(),
                      instruction->called_computations().end());
    }
  }
  return false;
}

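// Returns true if any computation called by `instruction` transitively
// contains a channel instruction.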
static bool InstructionCallsChannelInstructions(
    const HloInstruction& instruction) {
  for (const HloComputation* called_computation :
       instruction.called_computations()) {
    if (ComputationCallsChannelInstructions(*called_computation)) {
      return true;
    }
  }
  return false;
}

StatusOr<bool> ConditionalSimplifier::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  XLA_VLOG_LINES(
      3, "ConditionalSimplifier::Run(), before:\n" + module->ToString());
  bool changed = false;

  // Gather all the conditional ops in our module. We do this ahead of time so
  // we don't have to worry about mutating the lists of computations or
  // instructions as we iterate.
  std::vector<HloInstruction*> conditional_ops;
  for (auto* comp : module->computations(execution_threads)) {
    for (auto* instr : comp->MakeInstructionPostOrder()) {
      if (instr->opcode() == HloOpcode::kConditional) {
        // The verifier requires a single send/recv per channel, and this pass
        // clones computations, which could violate that invariant.
        if (InstructionCallsChannelInstructions(*instr)) {
          continue;
        }
        if (instr->has_sharding()) {
          // The code below doesn't handle sharding properly.
          continue;
        }
        conditional_ops.push_back(instr);
      }
    }
  }

  absl::flat_hash_set<HloInstruction*> removed_conditionals;
  for (HloInstruction* conditional_op : conditional_ops) {
    changed |= MergeDuplicateTupleElements(conditional_op);
    changed |= RemoveUnusedTupleElements(conditional_op);
    changed |= ReplaceRootWithEmptyTupleIfNoUsers(conditional_op);
    TF_ASSIGN_OR_RETURN(bool result, TryRemoveConditional(conditional_op));
    if (result) {
      removed_conditionals.insert(conditional_op);
      changed = true;
    }
  }
  // Try to remove unused conditional operands from branch computations. We
  // need to be careful to adjust *all* calling conditional ops if we do that,
  // so let's collect them first.
  absl::flat_hash_map<HloComputation*, absl::flat_hash_set<HloInstruction*>>
      calling_conditionals;
  // Keys of calling_conditionals to get a deterministic ordering.
  std::vector<HloComputation*> calling_computationals_vector;
  for (HloInstruction* conditional : conditional_ops) {
    if (removed_conditionals.contains(conditional)) {
      continue;
    }

    for (int64_t branch = 0; branch < conditional->branch_count(); ++branch) {
      auto* branch_comp = conditional->branch_computation(branch);
      if (!calling_conditionals.contains(branch_comp)) {
        calling_computationals_vector.push_back(branch_comp);
      }
      calling_conditionals[branch_comp].insert(conditional);
    }
  }

  for (auto* comp : calling_computationals_vector) {
    auto entry = calling_conditionals.find(comp);
    CHECK(entry != calling_conditionals.end());
    TF_ASSIGN_OR_RETURN(bool result, TryRemoveUnusedConditionalOperands(
                                         entry->first, entry->second));
    changed |= result;
  }

  XLA_VLOG_LINES(3,
                 "ConditionalSimplifier::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla