1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
17
18 #include "absl/base/casts.h"
19 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
21 #include "tensorflow/compiler/xla/service/hlo_module.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
24 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
25
26 namespace xla {
27
28 using std::nullopt;
29 using std::optional;
30 namespace m = match;
31
32 // Finds and returns the non-constant operand in instr.
33 //
34 // CHECK-fails if instr doesn't have exactly one unique non-constant operand.
NonConstantOperand(const HloInstruction * instr)35 static const HloInstruction* NonConstantOperand(const HloInstruction* instr) {
36 const HloInstruction* result = nullptr;
37 for (const HloInstruction* operand : instr->operands()) {
38 if (!operand->IsConstant()) {
39 if (result != nullptr) {
40 CHECK_EQ(result, operand);
41 }
42 result = operand;
43 }
44 }
45 CHECK_NE(result, nullptr);
46 return result;
47 }
48
49 // If all of instr's operands are either constants or have the form
50 // get-tuple-element(gte_operand, N)
51 // for the same value N, returns N. Otherwise, returns nullopt.
GetGTEOperandIndex(const HloInstruction * instr,const HloInstruction * gte_operand)52 static optional<int64_t> GetGTEOperandIndex(const HloInstruction* instr,
53 const HloInstruction* gte_operand) {
54 VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", "
55 << gte_operand->ToString() << ")";
56
57 // All operands of `instr` must be either constants or of the form
58 // get-tuple-element(gte_operand, tuple_idx)
59 // for the same value tuple_idx. We also support the case where GTE feeds a
60 // copy that is then used.
61 optional<int64_t> tuple_idx;
62 for (const HloInstruction* operand : instr->operands()) {
63 if (Match(operand, m::Constant())) {
64 continue;
65 }
66 auto possibly_gte_operand = operand;
67
68 if (operand->opcode() == HloOpcode::kCopy) {
69 possibly_gte_operand = operand->operand(0);
70 }
71
72 if (possibly_gte_operand->opcode() != HloOpcode::kGetTupleElement) {
73 return nullopt;
74 }
75
76 if (!Match(possibly_gte_operand,
77 m::GetTupleElement(m::Op().Is(gte_operand)))) {
78 return nullopt;
79 }
80
81 int64_t operand_tuple_idx = possibly_gte_operand->tuple_index();
82 // This is the first GTE we are seeing. Set tuple_idx.
83 if (!tuple_idx.has_value()) {
84 tuple_idx = operand_tuple_idx;
85 } else {
86 if (operand_tuple_idx != tuple_idx) {
87 return nullopt;
88 }
89 }
90 }
91 return tuple_idx;
92 }
93
94 // The below function identifies a subset of all possible auxiliary
95 // induction variables (AIV). Specifically, candidates are gtes, e.g.,
96 // gte(param0, N)
97 // The function checks if the loop body plumbs the AIV
98 // through the same tuple index at root, and that ops involving AIV
99 // involve constants.
100 // op2 = op(constants, gte(param0, N), constants)
101 // op3 = op(constants, f(op2, gte(param0, N), constants)
102 // op4 = op(constants, f(op3, constants)
103 // root = tuple(..., op4, ...)
104 // Further, the ops are restricted to basic math ops (+,-,*,/).
105 // Finally, loop invariant GTEs are excluded from AIVs.
106 // We can expand the ops category/nature of AIVs as needed.
GetAuxiliaryLoopInductionVars(const HloInstruction * while_op)107 std::vector<const HloInstruction*> GetAuxiliaryLoopInductionVars(
108 const HloInstruction* while_op) {
109 std::vector<const HloInstruction*> aux_ind_gte;
110 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
111 auto* while_body = while_op->while_body();
112 auto* while_body_param = while_body->parameter_instruction(0);
113 VLOG(2) << "Aux Induction Variables for loop:" << while_op->ToShortString();
114 VLOG(2) << "the parameter instr:" << while_body_param->ToShortString();
115 VLOG(2) << "the parameter user count:" << while_body_param->users().size();
116 if (while_body_param == nullptr) return aux_ind_gte;
117
118 // candidates_pairs = pair<inst, inst>(
119 // operands of the root while body,
120 // GTE only operands that index into the same position in the parameter)
121 // for each candidate_pair (x, y)
122 // find all paths between x and y,
123 // each paths should satisfy the above listed criterion
124 // index that x and y used is added as a aux variable index
125 std::map<int64_t, const HloInstruction*> extractions;
126 for (const HloInstruction* indx_instr : while_body_param->users()) {
127 if (indx_instr->opcode() != HloOpcode::kGetTupleElement) {
128 continue;
129 }
130 auto it = extractions.find(indx_instr->tuple_index());
131 // if we find two extractions at the same index, we ignore such
132 // a candidate
133 if (it != extractions.end()) {
134 it->second = nullptr;
135 VLOG(2) << "two extractions at same index:" << indx_instr->ToString();
136 } else {
137 extractions.insert(std::make_pair(indx_instr->tuple_index(), indx_instr));
138 VLOG(2) << "inserting extraction :" << indx_instr->ToString();
139 }
140 }
141 VLOG(2) << "total extractions size:" << extractions.size() << std::endl;
142 if (extractions.empty()) {
143 return aux_ind_gte;
144 }
145
146 auto* while_body_root = while_body->root_instruction();
147 if (while_body_root->opcode() != HloOpcode::kTuple) {
148 VLOG(2) << "While body root is not a tuple:" << while_body_root->ToString();
149 return aux_ind_gte;
150 }
151 int64_t index = -1;
152 std::map<int64_t, const HloInstruction*> insertions;
153 for (const HloInstruction* operand : while_body_root->operands()) {
154 index++;
155 if (!operand->IsConstant()) {
156 auto it = insertions.find(index);
157 if (it != insertions.end()) {
158 it->second = nullptr;
159 VLOG(2) << "two insertions at same index:" << operand->ToString();
160 } else {
161 insertions.insert(std::make_pair(index, operand));
162 VLOG(2) << "inserting insertions:" << operand->ToString();
163 }
164 }
165 }
166 if (insertions.empty()) {
167 return aux_ind_gte;
168 }
169
170 std::map<int64_t, std::pair<const HloInstruction*, const HloInstruction*>>
171 candidate_pairs;
172 for (; index >= 0; --index) {
173 const HloInstruction *ext, *inst;
174 ext = (extractions.find(index) != extractions.end())
175 ? extractions.find(index)->second
176 : nullptr;
177 inst = (insertions.find(index) != insertions.end())
178 ? insertions.find(index)->second
179 : nullptr;
180 if (ext != nullptr && inst != nullptr) {
181 // Filter out trivial aux, i.e., extract directly to an insert.
182 if (ext != inst) {
183 candidate_pairs.insert(
184 std::make_pair(index, std::make_pair(ext, inst)));
185 }
186 }
187 }
188 VLOG(2) << "total candidate pairs:" << candidate_pairs.size() << std::endl;
189
190 // Passed to ReachabilityMap to decide the type of produce-consumer edges
191 // along the reachability path.
192 const auto add_dependencies = [](const HloInstruction* hlo,
193 std::vector<HloInstruction*>* inputs) {
194 HloInstruction* non_const_operand = nullptr;
195 int num_non_constants = 0;
196 for (HloInstruction* operand : hlo->operands()) {
197 if (!operand->IsConstant()) {
198 num_non_constants++;
199 non_const_operand = operand;
200 }
201 }
202 if (num_non_constants == 1 &&
203 (hlo->opcode() == HloOpcode::kGetTupleElement ||
204 hlo->opcode() == HloOpcode::kAdd ||
205 hlo->opcode() == HloOpcode::kMultiply ||
206 hlo->opcode() == HloOpcode::kDivide ||
207 hlo->opcode() == HloOpcode::kSubtract)) {
208 inputs->push_back(non_const_operand);
209 }
210 };
211
212 std::unique_ptr<HloReachabilityMap> hrm =
213 HloReachabilityMap::BuildWithRestrictions(
214 while_body,
215 absl::FunctionRef<void(const HloInstruction* hlo,
216 std::vector<HloInstruction*>* inputs)>(
217 add_dependencies));
218
219 for (auto candidates : candidate_pairs) {
220 VLOG(2) << "are reachable?:" << (candidates.second.first)->ToString()
221 << "*************" << (candidates.second.second)->ToString()
222 << std::endl;
223 if (hrm->IsReachable(candidates.second.first, candidates.second.second)) {
224 aux_ind_gte.push_back(candidates.second.first);
225 VLOG(2) << "YES";
226 } else {
227 VLOG(2) << "NO";
228 }
229 }
230 VLOG(2) << "num auxiliary candidates :" << aux_ind_gte.size();
231 return aux_ind_gte;
232 }
233
234 // Tries to get the tuple index of the induction variable of a while loop.
235 //
236 // Checks that the loop condition and body both plumb the induction variable
237 // through the same tuple index, and that they both apply exactly one op to the
238 // induction variable before deciding whether to do another loop iteration (in
239 // the loop condition's case) or packing the induction variable into the result
240 // tuple (in the loop body's case).
241 //
242 // Specifically, checks that the loop condition has structure
243 //
244 // root = op(constants, get-tuple-elem(param0, N), constants)
245 //
246 // and the loop body has the structure
247 //
248 // inc = op(constants, get-tuple-elem(param0, N), constants)
249 // root = tuple(..., inc, ...) // inc is N'th operand of tuple().
250 //
251 // If so, returns N. Otherwise, returns nullopt.
GetLoopInductionVarTupleIdx(const HloInstruction * while_op)252 optional<int64_t> GetLoopInductionVarTupleIdx(const HloInstruction* while_op) {
253 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
254 VLOG(2) << "Finding induction variable for loop "
255 << while_op->ToShortString();
256
257 // The while_cond computation should have the form
258 //
259 // while_cond_root =
260 // op(constants, get-tuple-elem(while_cond_param, N), constants).
261 //
262 // If it does, set indvar_tuple_idx to N.
263 auto* while_cond = while_op->while_condition();
264 auto* while_cond_root = while_cond->root_instruction();
265 auto* while_cond_param = while_cond->parameter_instruction(0);
266 optional<int64_t> indvar_tuple_idx =
267 GetGTEOperandIndex(while_cond_root, while_cond_param);
268 if (!indvar_tuple_idx) {
269 VLOG(2) << "Induction variable not found in loop condition: "
270 << while_cond->root_instruction()->ToString();
271 return nullopt;
272 }
273
274 // The while_body computation should have the form
275 //
276 // while_body_inc =
277 // op(constants, get-tuple-elem(while_body_param, N), constants)
278 // while_body_root = tuple(..., while_body_inc, ...)
279 //
280 // where while_body_inc is operand N of while_body_root.
281 auto* while_body = while_op->while_body();
282 auto* while_body_root = while_body->root_instruction();
283 if (while_body_root->opcode() != HloOpcode::kTuple) {
284 VLOG(2) << "While body's root is not a tuple instruction: "
285 << while_body_root->ToString();
286 return nullopt;
287 }
288
289 auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx);
290 auto* while_body_param = while_body->parameter_instruction(0);
291 optional<int64_t> while_body_indvar_tuple_idx =
292 GetGTEOperandIndex(while_body_inc, while_body_param);
293 if (!while_body_indvar_tuple_idx) {
294 VLOG(2)
295 << "Induction variable not found in while body increment instruction: "
296 << while_body_inc->ToString();
297 return nullopt;
298 }
299 if (while_body_indvar_tuple_idx != indvar_tuple_idx) {
300 VLOG(2) << "Tuple index of induction variable does not match between loop "
301 "condition ("
302 << *indvar_tuple_idx << ") and while body ("
303 << *while_body_indvar_tuple_idx << ")";
304 return nullopt;
305 }
306
307 // Finally, check that the while loop's initial value is a tuple with enough
308 // elements.
309 auto* while_init = while_op->operand(0);
310 if (while_init->opcode() != HloOpcode::kTuple) {
311 VLOG(2) << "While init expected to be a tuple: " << while_init->ToString();
312 return nullopt;
313 }
314
315 VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx;
316 return indvar_tuple_idx;
317 }
318
319 // Converts the given literal to a scalar int64_t, if possible.
320 //
321 // Fails if the literal is not an integral type or if the value it contains
322 // cannot be represented in an int64_t.
LiteralAsScalarInt64(const Literal & l)323 static optional<int64_t> LiteralAsScalarInt64(const Literal& l) {
324 if (!ShapeUtil::IsEffectiveScalar(l.shape())) {
325 VLOG(2) << "literal is not an effective scalar: " << l.ToString();
326 return nullopt;
327 }
328 switch (l.shape().element_type()) {
329 case S8:
330 return l.GetFirstElement<int8_t>();
331 case S16:
332 return l.GetFirstElement<int16_t>();
333 case S32:
334 return l.GetFirstElement<int32_t>();
335 case S64:
336 return l.GetFirstElement<int64_t>();
337 case U8:
338 return l.GetFirstElement<uint8_t>();
339 case U16:
340 return l.GetFirstElement<uint16_t>();
341 case U32:
342 return l.GetFirstElement<uint32_t>();
343 case U64: {
344 uint64_t v = l.GetFirstElement<uint64_t>();
345 if (v > static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
346 VLOG(2) << "uint64_t literal is out of range for int64_t: " << v;
347 return nullopt;
348 }
349 return v;
350 }
351 default:
352 VLOG(2) << "literal is of non-integral type " << l.shape().ToString();
353 return nullopt;
354 }
355 }
356
357 // Computes a + b, returning nullopt if it overflows.
CheckedAdd(int64_t a,int64_t b)358 optional<int64_t> CheckedAdd(int64_t a, int64_t b) {
359 // Overflow occurred iff `a` and `b` have the same sign and `a + b` has a
360 // different sign, see Hacker's Delignt 2nd Ed. pp 28.
361 uint64_t aa = absl::bit_cast<uint64_t>(a);
362 uint64_t bb = absl::bit_cast<uint64_t>(b);
363 int64_t result = absl::bit_cast<int64_t>(aa + bb);
364 if (a >= 0 == b >= 0 && result >= 0 != a >= 0) {
365 return nullopt;
366 }
367 return result;
368 }
369
370 // Computes a - b, returning nullopt if it overflows.
CheckedSubtract(int64_t a,int64_t b)371 optional<int64_t> CheckedSubtract(int64_t a, int64_t b) {
372 uint64_t aa = absl::bit_cast<uint64_t>(a);
373 uint64_t bb = absl::bit_cast<uint64_t>(b);
374 int64_t result = absl::bit_cast<int64_t>(aa - bb);
375 // Overflow occurred iff `a` and `b` have different signs and the sign of
376 // `a - b` is the same as that of `b`, see Hacker's Delight 2nd Ed. pp 29.
377 if (a >= 0 != b >= 0 && result >= 0 == b >= 0) {
378 return nullopt;
379 }
380 return result;
381 }
382
383 // Check if
384 // - `i` is initialized to a scalar constant K (namely, `indvar_init`),
385 // - the while condition does `i < N` or `i <= N`, and
386 // - the while body does `i++`.
387 // If so, it's trivial to compute the loop bound.
PatternMatchLoopTripCount(HloInstruction * while_op,int64_t indvar_tuple_idx,const Literal & indvar_init)388 static optional<int64_t> PatternMatchLoopTripCount(HloInstruction* while_op,
389 int64_t indvar_tuple_idx,
390 const Literal& indvar_init) {
391 // First, find the scalar constant K that `i` is initialized to.
392 optional<int64_t> indvar_init_val = LiteralAsScalarInt64(indvar_init);
393 if (!indvar_init_val) {
394 VLOG(2) << "Pattern-match failed: induction variable init is not a "
395 "constant scalar representable as an int64_t: "
396 << indvar_init.ToString();
397 return nullopt;
398 }
399
400 // Check that `i` goes as `i++` in the while body.
401 //
402 // TODO(jlebar): We could also handle i-- and other idioms.
403 auto* while_body = while_op->while_body();
404 auto* while_body_indvar_update =
405 while_body->root_instruction()->operand(indvar_tuple_idx);
406 auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
407 if (!Match(while_body_indvar_update,
408 m::AddAnyOrder(m::Op().Is(while_body_indvar),
409 m::ConstantEffectiveScalar(1)))) {
410 VLOG(2) << "Pattern-match failed: induction variable does not go as i++: "
411 << while_body_indvar_update->ToString();
412 return nullopt;
413 }
414
415 // Check that we do op(i, N) or op(N, i) as the while condition. Capture the
416 // value N.
417 auto* while_cond = while_op->while_condition();
418 auto* while_cond_root = while_cond->root_instruction();
419 auto* while_cond_indvar = NonConstantOperand(while_cond_root);
420 HloInstruction* while_cond_bound = nullptr;
421 if (!Match(while_cond_root,
422 m::Op().WithBinaryOperandsAnyOrder(
423 m::Op().Is(while_cond_indvar),
424 m::ConstantEffectiveScalar(&while_cond_bound)))) {
425 VLOG(2) << "Pattern-match failed: while condition is not of the form "
426 "op(i, N) or op(N, i).";
427 return nullopt;
428 }
429 // Note: If this succeeds, the constant `N` is representable as an int64_t --
430 // that is, if it's an XLA U64, it fits within an int64_t.
431 optional<int64_t> while_cond_bound_val =
432 LiteralAsScalarInt64(while_cond_bound->literal());
433 if (!while_cond_bound_val) {
434 VLOG(2) << "Pattern-match failed: while condition induction variable is "
435 "not a constant scalar representable as an int64_t.";
436 return nullopt;
437 }
438
439 // Handle `i = K; i < N; ++i`.
440 if (Match(while_cond_root,
441 m::Op()
442 .WithComparisonDirection(ComparisonDirection::kLt)
443 .WithOperand(0, m::Op().Is(while_cond_indvar)))) {
444 VLOG(2) << "Pattern-match succeeded: loop condition is i < N: "
445 << while_cond_root->ToString();
446 optional<int64_t> trips =
447 CheckedSubtract(*while_cond_bound_val, *indvar_init_val);
448 if (trips) {
449 return std::max(int64_t{0}, *trips);
450 } else {
451 VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX.";
452 return nullopt;
453 }
454 }
455
456 // Handle `i = K; i <= N; ++i`.
457 if (Match(while_cond_root,
458 m::Op()
459 .WithComparisonDirection(ComparisonDirection::kLe)
460 .WithOperand(0, m::Op().Is(while_cond_indvar)))) {
461 VLOG(2) << "Pattern-match succeeded: loop condition is i <= N: "
462 << while_cond_root->ToString();
463 optional<int64_t> trips =
464 CheckedSubtract(*while_cond_bound_val, *indvar_init_val);
465 if (!trips) {
466 VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX";
467 return nullopt;
468 }
469 trips = CheckedAdd(*trips, 1);
470 if (!trips) {
471 VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX";
472 return nullopt;
473 }
474 return std::max<int64_t>(0, *trips);
475 }
476
477 VLOG(2) << "Pattern-match failed: while condition follows unknown pattern: "
478 << while_cond_root->ToString();
479 return nullopt;
480 }
481
ComputeWhileLoopTripCount(HloInstruction * while_op,int64_t max_brute_force_iters)482 optional<int64_t> ComputeWhileLoopTripCount(HloInstruction* while_op,
483 int64_t max_brute_force_iters) {
484 VLOG(2) << "Getting trip count for loop " << while_op->ToString();
485
486 // The loop's induction variable is found at
487 //
488 // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx),
489 //
490 // where comp is while_op->while_body() or while_op->while_condition().
491 optional<int64_t> indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op);
492 if (!indvar_tuple_idx) {
493 return nullopt;
494 }
495
496 // Now that we know the index of the induction variable, we can we can try to
497 // compute how many times the loop executes. Start by computing the induction
498 // variable's initial value.
499 HloEvaluator evaluator(/*max_loop_iterations=*/0);
500 auto* while_init = while_op->mutable_operand(0);
501 auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
502 StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init);
503 if (!indvar_init_result.ok()) {
504 VLOG(2) << "Couldn't evaluate induction variable init, "
505 << indvar_init_result.status() << ", " << indvar_init->ToString();
506 return nullopt;
507 }
508 Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie();
509
510 // First, try to pattern-match.
511 if (auto trip_count = PatternMatchLoopTripCount(while_op, *indvar_tuple_idx,
512 indvar_iter_val)) {
513 return trip_count;
514 }
515
516 // If our pattern-match failed, try brute-forcing the loop trip count.
517 auto* while_body = while_op->while_body();
518 auto* while_body_indvar_update =
519 while_body->root_instruction()->operand(*indvar_tuple_idx);
520 auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
521
522 auto* while_cond = while_op->while_condition();
523 auto* while_cond_root = while_cond->root_instruction();
524 auto* while_cond_indvar = NonConstantOperand(while_cond_root);
525
526 for (int64_t trip_count = 0; trip_count != max_brute_force_iters + 1;
527 ++trip_count) {
528 StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions(
529 while_cond_root, {{while_cond_indvar, &indvar_iter_val}});
530 if (!result.ok()) {
531 VLOG(2) << "Couldn't evaluate while cond: " << result.status();
532 return nullopt;
533 }
534 if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) {
535 VLOG(2) << "Loop has static trip count of " << trip_count;
536 return trip_count;
537 }
538
539 // Calculate the value of the induction variable after one iteration of the
540 // loop, and check whether the while condition is true with this new value.
541 StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions(
542 while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}});
543 if (!indvar_next_result.ok()) {
544 VLOG(2) << "Couldn't evaluate induction variable update: "
545 << indvar_next_result.status();
546 return nullopt;
547 }
548 indvar_iter_val = std::move(indvar_next_result).ValueOrDie();
549 }
550
551 VLOG(2) << "Loop has unknown trip count.";
552 return nullopt;
553 }
554
555 // If the only user of this instruction is a get-tuple-element, return that
556 // get-tuple-element, otherwise return null. If this runs before CSE/DCE, we may
557 // get a false negative if there are several copies of the same GTE, or there
558 // are unused GTEs, but we can live with this.
GetOnlyGTE(HloInstruction * inst)559 static HloInstruction* GetOnlyGTE(HloInstruction* inst) {
560 if (inst->user_count() != 1) {
561 return nullptr;
562 }
563
564 HloInstruction* user = inst->users().back();
565 if (user->opcode() != HloOpcode::kGetTupleElement) {
566 return nullptr;
567 }
568 return user;
569 }
570
ComputeWhileLoopTripCountUpperBound(HloInstruction * while_op)571 optional<int64_t> ComputeWhileLoopTripCountUpperBound(
572 HloInstruction* while_op) {
573 // If we know the exact trip count, it's also the upper bound.
574 auto exact_trip_count = ComputeWhileLoopTripCount(while_op);
575 if (exact_trip_count) {
576 VLOG(2) << "Loop has exact trip count.";
577 return exact_trip_count;
578 }
579
580 // There is one more case we know how to handle. If the loop condition only
581 // looks at one element of the tuple, and the loop body sets this element to a
582 // constant, there are two options:
583 // 1) Evaluating the condition on this constant returns true. In this case,
584 // the loop either executes 0 times, or is an infinite loop, depending on the
585 // init value.
586 // 2) Evaluating the condition on this constant returns false. In this case,
587 // the loop executes 0 or 1 times, depending on the init value. This means
588 // that, regardless of the init value, the upper bound on the trip count is 1.
589
590 // Check whether the condition depends on a single parameter, and find out
591 // which.
592 auto* while_cond = while_op->while_condition();
593 auto* while_cond_param = while_cond->parameter_instruction(0);
594 auto* cond_gte = GetOnlyGTE(while_cond_param);
595 if (!cond_gte) {
596 VLOG(2) << "Induction variable not found in loop condition: "
597 << while_cond->root_instruction()->ToString();
598 return nullopt;
599 }
600
601 // Now check whether this gets set to a constant by the while body.
602 auto* while_body = while_op->while_body();
603 auto* while_body_root = while_body->root_instruction();
604 if (while_body_root->opcode() != HloOpcode::kTuple) {
605 VLOG(3) << "While body's root is not a tuple instruction: "
606 << while_body_root->ToString();
607 return nullopt;
608 }
609
610 int64_t indvar_index = cond_gte->tuple_index();
611 auto* while_body_indvar = while_body_root->operand(indvar_index);
612 if (while_body_indvar->opcode() != HloOpcode::kConstant) {
613 VLOG(3) << "While body does not set the IV to a constant: "
614 << while_body_indvar->ToString();
615 return nullopt;
616 }
617
618 // We have a constant. Evaluate the condition on this constant.
619 HloEvaluator evaluator(/*max_loop_iterations=*/0);
620 Literal fake_input = Literal::CreateFromShape(while_cond_param->shape());
621 TF_CHECK_OK(fake_input.CopyFrom(while_body_indvar->literal(),
622 /*dest_shape_index=*/{indvar_index},
623 /*src_shape_index=*/{}));
624 StatusOr<Literal> eval_result =
625 evaluator.Evaluate(*while_cond, {std::move(fake_input)});
626
627 if (!eval_result.ok()) {
628 VLOG(2) << "Couldn't evaluate while loop condition.";
629 return nullopt;
630 }
631
632 Literal cond_result_pred = std::move(eval_result.ValueOrDie());
633 CHECK(Shape::Equal().IgnoreLayout()(cond_result_pred.shape(),
634 ShapeUtil::MakeShape(PRED, {})));
635
636 // Per the explanation above, if the evaluated condition returns false, the
637 // loop executes at most once.
638 bool cond_returns_true = cond_result_pred.GetFirstElement<bool>();
639 if (!cond_returns_true) {
640 VLOG(2) << "Upper bound on the trip count is 1";
641 return 1;
642 }
643
644 VLOG(2) << "Loop has no known upper bound on the trip count.";
645 return nullopt;
646 }
647
648 } // namespace xla
649