xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/while_loop_analysis.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
17 
18 #include "absl/base/casts.h"
19 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
21 #include "tensorflow/compiler/xla/service/hlo_module.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
24 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
25 
26 namespace xla {
27 
28 using std::nullopt;
29 using std::optional;
30 namespace m = match;
31 
32 // Finds and returns the non-constant operand in instr.
33 //
34 // CHECK-fails if instr doesn't have exactly one unique non-constant operand.
NonConstantOperand(const HloInstruction * instr)35 static const HloInstruction* NonConstantOperand(const HloInstruction* instr) {
36   const HloInstruction* result = nullptr;
37   for (const HloInstruction* operand : instr->operands()) {
38     if (!operand->IsConstant()) {
39       if (result != nullptr) {
40         CHECK_EQ(result, operand);
41       }
42       result = operand;
43     }
44   }
45   CHECK_NE(result, nullptr);
46   return result;
47 }
48 
49 // If all of instr's operands are either constants or have the form
50 //   get-tuple-element(gte_operand, N)
51 // for the same value N, returns N.  Otherwise, returns nullopt.
GetGTEOperandIndex(const HloInstruction * instr,const HloInstruction * gte_operand)52 static optional<int64_t> GetGTEOperandIndex(const HloInstruction* instr,
53                                             const HloInstruction* gte_operand) {
54   VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", "
55           << gte_operand->ToString() << ")";
56 
57   // All operands of `instr` must be either constants or of the form
58   //   get-tuple-element(gte_operand, tuple_idx)
59   // for the same value tuple_idx. We also support the case where GTE feeds a
60   // copy that is then used.
61   optional<int64_t> tuple_idx;
62   for (const HloInstruction* operand : instr->operands()) {
63     if (Match(operand, m::Constant())) {
64       continue;
65     }
66     auto possibly_gte_operand = operand;
67 
68     if (operand->opcode() == HloOpcode::kCopy) {
69       possibly_gte_operand = operand->operand(0);
70     }
71 
72     if (possibly_gte_operand->opcode() != HloOpcode::kGetTupleElement) {
73       return nullopt;
74     }
75 
76     if (!Match(possibly_gte_operand,
77                m::GetTupleElement(m::Op().Is(gte_operand)))) {
78       return nullopt;
79     }
80 
81     int64_t operand_tuple_idx = possibly_gte_operand->tuple_index();
82     // This is the first GTE we are seeing. Set tuple_idx.
83     if (!tuple_idx.has_value()) {
84       tuple_idx = operand_tuple_idx;
85     } else {
86       if (operand_tuple_idx != tuple_idx) {
87         return nullopt;
88       }
89     }
90   }
91   return tuple_idx;
92 }
93 
94 // The below function identifies a subset of all possible auxiliary
95 // induction variables (AIV). Specifically, candidates are gtes, e.g.,
96 // gte(param0, N)
97 // The function checks if the loop body plumbs the AIV
98 // through the same tuple index at root, and that ops involving AIV
99 // involve constants.
100 //   op2 = op(constants, gte(param0, N), constants)
101 //   op3 = op(constants, f(op2, gte(param0, N), constants)
102 //   op4 = op(constants, f(op3, constants)
103 //   root = tuple(..., op4, ...)
104 // Further, the ops are restricted to basic math ops (+,-,*,/).
105 // Finally, loop invariant GTEs are excluded from AIVs.
106 // We can expand the ops category/nature of AIVs as needed.
GetAuxiliaryLoopInductionVars(const HloInstruction * while_op)107 std::vector<const HloInstruction*> GetAuxiliaryLoopInductionVars(
108     const HloInstruction* while_op) {
109   std::vector<const HloInstruction*> aux_ind_gte;
110   CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
111   auto* while_body = while_op->while_body();
112   auto* while_body_param = while_body->parameter_instruction(0);
113   VLOG(2) << "Aux Induction Variables for loop:" << while_op->ToShortString();
114   VLOG(2) << "the parameter instr:" << while_body_param->ToShortString();
115   VLOG(2) << "the parameter user count:" << while_body_param->users().size();
116   if (while_body_param == nullptr) return aux_ind_gte;
117 
118   // candidates_pairs = pair<inst, inst>(
119   //   operands of the root while body,
120   //   GTE only operands that index into the same position in the parameter)
121   // for each candidate_pair (x, y)
122   //  find all paths between x and y,
123   //  each paths should satisfy the above listed criterion
124   //  index that x and y used is added as a aux variable index
125   std::map<int64_t, const HloInstruction*> extractions;
126   for (const HloInstruction* indx_instr : while_body_param->users()) {
127     if (indx_instr->opcode() != HloOpcode::kGetTupleElement) {
128       continue;
129     }
130     auto it = extractions.find(indx_instr->tuple_index());
131     // if we find two extractions at the same index, we ignore such
132     // a candidate
133     if (it != extractions.end()) {
134       it->second = nullptr;
135       VLOG(2) << "two extractions at same index:" << indx_instr->ToString();
136     } else {
137       extractions.insert(std::make_pair(indx_instr->tuple_index(), indx_instr));
138       VLOG(2) << "inserting extraction :" << indx_instr->ToString();
139     }
140   }
141   VLOG(2) << "total extractions size:" << extractions.size() << std::endl;
142   if (extractions.empty()) {
143     return aux_ind_gte;
144   }
145 
146   auto* while_body_root = while_body->root_instruction();
147   if (while_body_root->opcode() != HloOpcode::kTuple) {
148     VLOG(2) << "While body root is not a tuple:" << while_body_root->ToString();
149     return aux_ind_gte;
150   }
151   int64_t index = -1;
152   std::map<int64_t, const HloInstruction*> insertions;
153   for (const HloInstruction* operand : while_body_root->operands()) {
154     index++;
155     if (!operand->IsConstant()) {
156       auto it = insertions.find(index);
157       if (it != insertions.end()) {
158         it->second = nullptr;
159         VLOG(2) << "two insertions at same index:" << operand->ToString();
160       } else {
161         insertions.insert(std::make_pair(index, operand));
162         VLOG(2) << "inserting insertions:" << operand->ToString();
163       }
164     }
165   }
166   if (insertions.empty()) {
167     return aux_ind_gte;
168   }
169 
170   std::map<int64_t, std::pair<const HloInstruction*, const HloInstruction*>>
171       candidate_pairs;
172   for (; index >= 0; --index) {
173     const HloInstruction *ext, *inst;
174     ext = (extractions.find(index) != extractions.end())
175               ? extractions.find(index)->second
176               : nullptr;
177     inst = (insertions.find(index) != insertions.end())
178                ? insertions.find(index)->second
179                : nullptr;
180     if (ext != nullptr && inst != nullptr) {
181       // Filter out trivial aux, i.e., extract directly to an insert.
182       if (ext != inst) {
183         candidate_pairs.insert(
184             std::make_pair(index, std::make_pair(ext, inst)));
185       }
186     }
187   }
188   VLOG(2) << "total candidate pairs:" << candidate_pairs.size() << std::endl;
189 
190   // Passed to ReachabilityMap to decide the type of produce-consumer edges
191   // along the reachability path.
192   const auto add_dependencies = [](const HloInstruction* hlo,
193                                    std::vector<HloInstruction*>* inputs) {
194     HloInstruction* non_const_operand = nullptr;
195     int num_non_constants = 0;
196     for (HloInstruction* operand : hlo->operands()) {
197       if (!operand->IsConstant()) {
198         num_non_constants++;
199         non_const_operand = operand;
200       }
201     }
202     if (num_non_constants == 1 &&
203         (hlo->opcode() == HloOpcode::kGetTupleElement ||
204          hlo->opcode() == HloOpcode::kAdd ||
205          hlo->opcode() == HloOpcode::kMultiply ||
206          hlo->opcode() == HloOpcode::kDivide ||
207          hlo->opcode() == HloOpcode::kSubtract)) {
208       inputs->push_back(non_const_operand);
209     }
210   };
211 
212   std::unique_ptr<HloReachabilityMap> hrm =
213       HloReachabilityMap::BuildWithRestrictions(
214           while_body,
215           absl::FunctionRef<void(const HloInstruction* hlo,
216                                  std::vector<HloInstruction*>* inputs)>(
217               add_dependencies));
218 
219   for (auto candidates : candidate_pairs) {
220     VLOG(2) << "are reachable?:" << (candidates.second.first)->ToString()
221             << "*************" << (candidates.second.second)->ToString()
222             << std::endl;
223     if (hrm->IsReachable(candidates.second.first, candidates.second.second)) {
224       aux_ind_gte.push_back(candidates.second.first);
225       VLOG(2) << "YES";
226     } else {
227       VLOG(2) << "NO";
228     }
229   }
230   VLOG(2) << "num auxiliary candidates :" << aux_ind_gte.size();
231   return aux_ind_gte;
232 }
233 
234 // Tries to get the tuple index of the induction variable of a while loop.
235 //
236 // Checks that the loop condition and body both plumb the induction variable
237 // through the same tuple index, and that they both apply exactly one op to the
238 // induction variable before  deciding whether to do another loop iteration (in
239 // the loop condition's case) or packing the induction variable into the result
240 // tuple (in the loop body's case).
241 //
242 // Specifically, checks that the loop condition has structure
243 //
244 //   root = op(constants, get-tuple-elem(param0, N), constants)
245 //
246 // and the loop body has the structure
247 //
248 //   inc = op(constants, get-tuple-elem(param0, N), constants)
249 //   root = tuple(..., inc, ...)  // inc is N'th operand of tuple().
250 //
251 // If so, returns N.  Otherwise, returns nullopt.
GetLoopInductionVarTupleIdx(const HloInstruction * while_op)252 optional<int64_t> GetLoopInductionVarTupleIdx(const HloInstruction* while_op) {
253   CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
254   VLOG(2) << "Finding induction variable for loop "
255           << while_op->ToShortString();
256 
257   // The while_cond computation should have the form
258   //
259   //   while_cond_root =
260   //       op(constants, get-tuple-elem(while_cond_param, N), constants).
261   //
262   // If it does, set indvar_tuple_idx to N.
263   auto* while_cond = while_op->while_condition();
264   auto* while_cond_root = while_cond->root_instruction();
265   auto* while_cond_param = while_cond->parameter_instruction(0);
266   optional<int64_t> indvar_tuple_idx =
267       GetGTEOperandIndex(while_cond_root, while_cond_param);
268   if (!indvar_tuple_idx) {
269     VLOG(2) << "Induction variable not found in loop condition: "
270             << while_cond->root_instruction()->ToString();
271     return nullopt;
272   }
273 
274   // The while_body computation should have the form
275   //
276   //   while_body_inc =
277   //       op(constants, get-tuple-elem(while_body_param, N), constants)
278   //   while_body_root = tuple(..., while_body_inc, ...)
279   //
280   // where while_body_inc is operand N of while_body_root.
281   auto* while_body = while_op->while_body();
282   auto* while_body_root = while_body->root_instruction();
283   if (while_body_root->opcode() != HloOpcode::kTuple) {
284     VLOG(2) << "While body's root is not a tuple instruction: "
285             << while_body_root->ToString();
286     return nullopt;
287   }
288 
289   auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx);
290   auto* while_body_param = while_body->parameter_instruction(0);
291   optional<int64_t> while_body_indvar_tuple_idx =
292       GetGTEOperandIndex(while_body_inc, while_body_param);
293   if (!while_body_indvar_tuple_idx) {
294     VLOG(2)
295         << "Induction variable not found in while body increment instruction: "
296         << while_body_inc->ToString();
297     return nullopt;
298   }
299   if (while_body_indvar_tuple_idx != indvar_tuple_idx) {
300     VLOG(2) << "Tuple index of induction variable does not match between loop "
301                "condition ("
302             << *indvar_tuple_idx << ") and while body ("
303             << *while_body_indvar_tuple_idx << ")";
304     return nullopt;
305   }
306 
307   // Finally, check that the while loop's initial value is a tuple with enough
308   // elements.
309   auto* while_init = while_op->operand(0);
310   if (while_init->opcode() != HloOpcode::kTuple) {
311     VLOG(2) << "While init expected to be a tuple: " << while_init->ToString();
312     return nullopt;
313   }
314 
315   VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx;
316   return indvar_tuple_idx;
317 }
318 
319 // Converts the given literal to a scalar int64_t, if possible.
320 //
321 // Fails if the literal is not an integral type or if the value it contains
322 // cannot be represented in an int64_t.
LiteralAsScalarInt64(const Literal & l)323 static optional<int64_t> LiteralAsScalarInt64(const Literal& l) {
324   if (!ShapeUtil::IsEffectiveScalar(l.shape())) {
325     VLOG(2) << "literal is not an effective scalar: " << l.ToString();
326     return nullopt;
327   }
328   switch (l.shape().element_type()) {
329     case S8:
330       return l.GetFirstElement<int8_t>();
331     case S16:
332       return l.GetFirstElement<int16_t>();
333     case S32:
334       return l.GetFirstElement<int32_t>();
335     case S64:
336       return l.GetFirstElement<int64_t>();
337     case U8:
338       return l.GetFirstElement<uint8_t>();
339     case U16:
340       return l.GetFirstElement<uint16_t>();
341     case U32:
342       return l.GetFirstElement<uint32_t>();
343     case U64: {
344       uint64_t v = l.GetFirstElement<uint64_t>();
345       if (v > static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
346         VLOG(2) << "uint64_t literal is out of range for int64_t: " << v;
347         return nullopt;
348       }
349       return v;
350     }
351     default:
352       VLOG(2) << "literal is of non-integral type " << l.shape().ToString();
353       return nullopt;
354   }
355 }
356 
// Computes a + b, returning nullopt if the sum is not representable as an
// int64_t.
std::optional<int64_t> CheckedAdd(int64_t a, int64_t b) {
  constexpr int64_t kMax = std::numeric_limits<int64_t>::max();
  constexpr int64_t kMin = std::numeric_limits<int64_t>::min();
  // A positive `b` can only push the sum past the top of the range and a
  // negative `b` only past the bottom. Comparing `a` against the remaining
  // headroom detects both cases without ever evaluating an overflowing
  // expression.
  const bool exceeds_max = b > 0 && a > kMax - b;
  const bool exceeds_min = b < 0 && a < kMin - b;
  if (exceeds_max || exceeds_min) {
    return std::nullopt;
  }
  return a + b;
}
369 
// Computes a - b, returning nullopt if the difference is not representable as
// an int64_t.
std::optional<int64_t> CheckedSubtract(int64_t a, int64_t b) {
  constexpr int64_t kMax = std::numeric_limits<int64_t>::max();
  constexpr int64_t kMin = std::numeric_limits<int64_t>::min();
  // Subtracting a negative `b` can only push the result past the top of the
  // range and subtracting a positive `b` only past the bottom. Both bounds
  // are rearranged so that the comparison itself cannot overflow.
  const bool exceeds_max = b < 0 && a > kMax + b;
  const bool exceeds_min = b > 0 && a < kMin + b;
  if (exceeds_max || exceeds_min) {
    return std::nullopt;
  }
  return a - b;
}
382 
383 // Check if
384 //  - `i` is initialized to a scalar constant K (namely, `indvar_init`),
385 //  - the while condition does `i < N` or `i <= N`, and
386 //  - the while body does `i++`.
387 // If so, it's trivial to compute the loop bound.
PatternMatchLoopTripCount(HloInstruction * while_op,int64_t indvar_tuple_idx,const Literal & indvar_init)388 static optional<int64_t> PatternMatchLoopTripCount(HloInstruction* while_op,
389                                                    int64_t indvar_tuple_idx,
390                                                    const Literal& indvar_init) {
391   // First, find the scalar constant K that `i` is initialized to.
392   optional<int64_t> indvar_init_val = LiteralAsScalarInt64(indvar_init);
393   if (!indvar_init_val) {
394     VLOG(2) << "Pattern-match failed: induction variable init is not a "
395                "constant scalar representable as an int64_t: "
396             << indvar_init.ToString();
397     return nullopt;
398   }
399 
400   // Check that `i` goes as `i++` in the while body.
401   //
402   // TODO(jlebar): We could also handle i-- and other idioms.
403   auto* while_body = while_op->while_body();
404   auto* while_body_indvar_update =
405       while_body->root_instruction()->operand(indvar_tuple_idx);
406   auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
407   if (!Match(while_body_indvar_update,
408              m::AddAnyOrder(m::Op().Is(while_body_indvar),
409                             m::ConstantEffectiveScalar(1)))) {
410     VLOG(2) << "Pattern-match failed: induction variable does not go as i++: "
411             << while_body_indvar_update->ToString();
412     return nullopt;
413   }
414 
415   // Check that we do op(i, N) or op(N, i) as the while condition.  Capture the
416   // value N.
417   auto* while_cond = while_op->while_condition();
418   auto* while_cond_root = while_cond->root_instruction();
419   auto* while_cond_indvar = NonConstantOperand(while_cond_root);
420   HloInstruction* while_cond_bound = nullptr;
421   if (!Match(while_cond_root,
422              m::Op().WithBinaryOperandsAnyOrder(
423                  m::Op().Is(while_cond_indvar),
424                  m::ConstantEffectiveScalar(&while_cond_bound)))) {
425     VLOG(2) << "Pattern-match failed: while condition is not of the form "
426                "op(i, N) or op(N, i).";
427     return nullopt;
428   }
429   // Note: If this succeeds, the constant `N` is representable as an int64_t --
430   // that is, if it's an XLA U64, it fits within an int64_t.
431   optional<int64_t> while_cond_bound_val =
432       LiteralAsScalarInt64(while_cond_bound->literal());
433   if (!while_cond_bound_val) {
434     VLOG(2) << "Pattern-match failed: while condition induction variable is "
435                "not a constant scalar representable as an int64_t.";
436     return nullopt;
437   }
438 
439   // Handle `i = K; i < N; ++i`.
440   if (Match(while_cond_root,
441             m::Op()
442                 .WithComparisonDirection(ComparisonDirection::kLt)
443                 .WithOperand(0, m::Op().Is(while_cond_indvar)))) {
444     VLOG(2) << "Pattern-match succeeded: loop condition is i < N: "
445             << while_cond_root->ToString();
446     optional<int64_t> trips =
447         CheckedSubtract(*while_cond_bound_val, *indvar_init_val);
448     if (trips) {
449       return std::max(int64_t{0}, *trips);
450     } else {
451       VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX.";
452       return nullopt;
453     }
454   }
455 
456   // Handle `i = K; i <= N; ++i`.
457   if (Match(while_cond_root,
458             m::Op()
459                 .WithComparisonDirection(ComparisonDirection::kLe)
460                 .WithOperand(0, m::Op().Is(while_cond_indvar)))) {
461     VLOG(2) << "Pattern-match succeeded: loop condition is i <= N: "
462             << while_cond_root->ToString();
463     optional<int64_t> trips =
464         CheckedSubtract(*while_cond_bound_val, *indvar_init_val);
465     if (!trips) {
466       VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX";
467       return nullopt;
468     }
469     trips = CheckedAdd(*trips, 1);
470     if (!trips) {
471       VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX";
472       return nullopt;
473     }
474     return std::max<int64_t>(0, *trips);
475   }
476 
477   VLOG(2) << "Pattern-match failed: while condition follows unknown pattern: "
478           << while_cond_root->ToString();
479   return nullopt;
480 }
481 
ComputeWhileLoopTripCount(HloInstruction * while_op,int64_t max_brute_force_iters)482 optional<int64_t> ComputeWhileLoopTripCount(HloInstruction* while_op,
483                                             int64_t max_brute_force_iters) {
484   VLOG(2) << "Getting trip count for loop " << while_op->ToString();
485 
486   // The loop's induction variable is found at
487   //
488   //   get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx),
489   //
490   // where comp is while_op->while_body() or while_op->while_condition().
491   optional<int64_t> indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op);
492   if (!indvar_tuple_idx) {
493     return nullopt;
494   }
495 
496   // Now that we know the index of the induction variable, we can we can try to
497   // compute how many times the loop executes.  Start by computing the induction
498   // variable's initial value.
499   HloEvaluator evaluator(/*max_loop_iterations=*/0);
500   auto* while_init = while_op->mutable_operand(0);
501   auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
502   StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init);
503   if (!indvar_init_result.ok()) {
504     VLOG(2) << "Couldn't evaluate induction variable init, "
505             << indvar_init_result.status() << ", " << indvar_init->ToString();
506     return nullopt;
507   }
508   Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie();
509 
510   // First, try to pattern-match.
511   if (auto trip_count = PatternMatchLoopTripCount(while_op, *indvar_tuple_idx,
512                                                   indvar_iter_val)) {
513     return trip_count;
514   }
515 
516   // If our pattern-match failed, try brute-forcing the loop trip count.
517   auto* while_body = while_op->while_body();
518   auto* while_body_indvar_update =
519       while_body->root_instruction()->operand(*indvar_tuple_idx);
520   auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
521 
522   auto* while_cond = while_op->while_condition();
523   auto* while_cond_root = while_cond->root_instruction();
524   auto* while_cond_indvar = NonConstantOperand(while_cond_root);
525 
526   for (int64_t trip_count = 0; trip_count != max_brute_force_iters + 1;
527        ++trip_count) {
528     StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions(
529         while_cond_root, {{while_cond_indvar, &indvar_iter_val}});
530     if (!result.ok()) {
531       VLOG(2) << "Couldn't evaluate while cond: " << result.status();
532       return nullopt;
533     }
534     if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) {
535       VLOG(2) << "Loop has static trip count of " << trip_count;
536       return trip_count;
537     }
538 
539     // Calculate the value of the induction variable after one iteration of the
540     // loop, and check whether the while condition is true with this new value.
541     StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions(
542         while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}});
543     if (!indvar_next_result.ok()) {
544       VLOG(2) << "Couldn't evaluate induction variable update: "
545               << indvar_next_result.status();
546       return nullopt;
547     }
548     indvar_iter_val = std::move(indvar_next_result).ValueOrDie();
549   }
550 
551   VLOG(2) << "Loop has unknown trip count.";
552   return nullopt;
553 }
554 
555 // If the only user of this instruction is a get-tuple-element, return that
556 // get-tuple-element, otherwise return null. If this runs before CSE/DCE, we may
557 // get a false negative if there are several copies of the same GTE, or there
558 // are unused GTEs, but we can live with this.
GetOnlyGTE(HloInstruction * inst)559 static HloInstruction* GetOnlyGTE(HloInstruction* inst) {
560   if (inst->user_count() != 1) {
561     return nullptr;
562   }
563 
564   HloInstruction* user = inst->users().back();
565   if (user->opcode() != HloOpcode::kGetTupleElement) {
566     return nullptr;
567   }
568   return user;
569 }
570 
ComputeWhileLoopTripCountUpperBound(HloInstruction * while_op)571 optional<int64_t> ComputeWhileLoopTripCountUpperBound(
572     HloInstruction* while_op) {
573   // If we know the exact trip count, it's also the upper bound.
574   auto exact_trip_count = ComputeWhileLoopTripCount(while_op);
575   if (exact_trip_count) {
576     VLOG(2) << "Loop has exact trip count.";
577     return exact_trip_count;
578   }
579 
580   // There is one more case we know how to handle. If the loop condition only
581   // looks at one element of the tuple, and the loop body sets this element to a
582   // constant, there are two options:
583   // 1) Evaluating the condition on this constant returns true. In this case,
584   // the loop either executes 0 times, or is an infinite loop, depending on the
585   // init value.
586   // 2) Evaluating the condition on this constant returns false. In this case,
587   // the loop executes 0 or 1 times, depending on the init value. This means
588   // that, regardless of the init value, the upper bound on the trip count is 1.
589 
590   // Check whether the condition depends on a single parameter, and find out
591   // which.
592   auto* while_cond = while_op->while_condition();
593   auto* while_cond_param = while_cond->parameter_instruction(0);
594   auto* cond_gte = GetOnlyGTE(while_cond_param);
595   if (!cond_gte) {
596     VLOG(2) << "Induction variable not found in loop condition: "
597             << while_cond->root_instruction()->ToString();
598     return nullopt;
599   }
600 
601   // Now check whether this gets set to a constant by the while body.
602   auto* while_body = while_op->while_body();
603   auto* while_body_root = while_body->root_instruction();
604   if (while_body_root->opcode() != HloOpcode::kTuple) {
605     VLOG(3) << "While body's root is not a tuple instruction: "
606             << while_body_root->ToString();
607     return nullopt;
608   }
609 
610   int64_t indvar_index = cond_gte->tuple_index();
611   auto* while_body_indvar = while_body_root->operand(indvar_index);
612   if (while_body_indvar->opcode() != HloOpcode::kConstant) {
613     VLOG(3) << "While body does not set the IV to a constant: "
614             << while_body_indvar->ToString();
615     return nullopt;
616   }
617 
618   // We have a constant. Evaluate the condition on this constant.
619   HloEvaluator evaluator(/*max_loop_iterations=*/0);
620   Literal fake_input = Literal::CreateFromShape(while_cond_param->shape());
621   TF_CHECK_OK(fake_input.CopyFrom(while_body_indvar->literal(),
622                                   /*dest_shape_index=*/{indvar_index},
623                                   /*src_shape_index=*/{}));
624   StatusOr<Literal> eval_result =
625       evaluator.Evaluate(*while_cond, {std::move(fake_input)});
626 
627   if (!eval_result.ok()) {
628     VLOG(2) << "Couldn't evaluate while loop condition.";
629     return nullopt;
630   }
631 
632   Literal cond_result_pred = std::move(eval_result.ValueOrDie());
633   CHECK(Shape::Equal().IgnoreLayout()(cond_result_pred.shape(),
634                                       ShapeUtil::MakeShape(PRED, {})));
635 
636   // Per the explanation above, if the evaluated condition returns false, the
637   // loop executes at most once.
638   bool cond_returns_true = cond_result_pred.GetFirstElement<bool>();
639   if (!cond_returns_true) {
640     VLOG(2) << "Upper bound on the trip count is 1";
641     return 1;
642   }
643 
644   VLOG(2) << "Loop has no known upper bound on the trip count.";
645   return nullopt;
646 }
647 
648 }  // namespace xla
649