xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "absl/container/inlined_vector.h"
20 #include "tensorflow/compiler/xla/service/while_util.h"
21 #include "tensorflow/compiler/xla/util.h"
22 
23 namespace xla {
24 namespace {
25 
26 // Replaces all uses of old_instr with new_instr except the use at
27 // `while_body_root` (which must be a tuple instruction) at index `tuple_index`.
28 // This utility helps us replace an instruction in the while body with a
29 // constant while still keeping it trivially loop invariant.
ReplaceUsesWhileKeepingLoopInvariance(HloInstruction * old_instr,HloInstruction * new_instr,HloInstruction * while_body_root,int64_t tuple_index)30 Status ReplaceUsesWhileKeepingLoopInvariance(HloInstruction* old_instr,
31                                              HloInstruction* new_instr,
32                                              HloInstruction* while_body_root,
33                                              int64_t tuple_index) {
34   CHECK_EQ(while_body_root->opcode(), HloOpcode::kTuple);
35 
36   std::vector<HloInstruction*> users;
37   users.reserve(old_instr->user_count());
38   absl::c_copy(old_instr->users(), std::back_inserter(users));
39 
40   for (auto* user : users) {
41     for (int64_t i = 0, e = user->operand_count(); i < e; i++) {
42       if (user->operand(i) == old_instr &&
43           !(user == while_body_root && i == tuple_index)) {
44         TF_RETURN_IF_ERROR(user->ReplaceOperandWith(i, new_instr));
45       }
46     }
47   }
48 
49   return OkStatus();
50 }
51 
// Adds a copy of `instruction` to `computation` and returns the new
// instruction.  Two opcodes are handled:
//   * kConstant:  cloned with the name suffix ".sunk".
//   * kBroadcast: operand 0 is cloned recursively into `computation` first,
//                 then the broadcast is re-created on top of that clone with
//                 the original shape.
// Any other opcode is a fatal error.
HloInstruction* CloneHelper(const HloInstruction* instruction,
                            HloComputation* computation) {
  switch (instruction->opcode()) {
    case HloOpcode::kConstant:
      return computation->AddInstruction(
          instruction->Clone(/*suffix=*/".sunk"));
    case HloOpcode::kBroadcast: {
      HloInstruction* cloned_operand =
          CloneHelper(instruction->operand(0), computation);
      return computation->AddInstruction(instruction->CloneWithNewOperands(
          instruction->shape(), {cloned_operand}));
    }
    default:
      break;
  }
  LOG(FATAL) << "Unexpected instruction.";
}
64 
65 }  // namespace
66 
TrySinkingConstantsIntoWhileLoop(HloInstruction * while_instr)67 StatusOr<bool> WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop(
68     HloInstruction* while_instr) {
69   HloComputation* while_cond = while_instr->while_condition();
70   HloComputation* while_body = while_instr->while_body();
71 
72   const HloInstruction& init_value = *while_instr->operand(0);
73   if (init_value.opcode() != HloOpcode::kTuple) {
74     return false;
75   }
76 
77   bool changed = false;
78 
79   absl::flat_hash_map<int64_t, absl::InlinedVector<HloInstruction*, 1>>
80       conditional_gte_index_to_insts =
81           WhileUtil::GetGTEsMapForWhileConditional(*while_cond);
82   std::vector<HloInstruction*> invariant_body_gtes =
83       WhileUtil::GetInvariantGTEsForWhileBody(*while_body);
84 
85   for (HloInstruction* invariant_body_gte : invariant_body_gtes) {
86     int64_t index = invariant_body_gte->tuple_index();
87     const HloInstruction& invariant_value = *init_value.operand(index);
88 
89     // Original value should be a constant or broadcast of constant.
90     if (invariant_value.opcode() != HloOpcode::kConstant &&
91         (!sink_broadcast_of_constants_ ||
92          invariant_value.opcode() != HloOpcode::kBroadcast ||
93          invariant_value.operand(0)->opcode() != HloOpcode::kConstant)) {
94       continue;
95     }
96 
97     // Sink into the while_body.
98     // Should have at least one user that's not while_body_root.
99     if (invariant_body_gte->user_count() > 1) {
100       HloInstruction* constant_instr =
101           CloneHelper(&invariant_value, while_body);
102       TF_RETURN_IF_ERROR(ReplaceUsesWhileKeepingLoopInvariance(
103           invariant_body_gte, constant_instr, while_body->root_instruction(),
104           index));
105       changed = true;
106     }
107 
108     // Check if there is a corresponding GTE in while_conditional.
109     auto it = conditional_gte_index_to_insts.find(index);
110     if (it == conditional_gte_index_to_insts.end()) {
111       continue;
112     }
113 
114     for (HloInstruction* invariant_cond_gte : it->second) {
115       // Should have at least one user.
116       if (invariant_cond_gte->user_count() > 0) {
117         HloInstruction* constant_instr =
118             CloneHelper(&invariant_value, while_cond);
119         TF_RETURN_IF_ERROR(
120             invariant_cond_gte->ReplaceAllUsesWith(constant_instr));
121         changed = true;
122       }
123     }
124   }
125 
126   return changed;
127 }
128 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)129 StatusOr<bool> WhileLoopConstantSinking::Run(
130     HloModule* module,
131     const absl::flat_hash_set<absl::string_view>& execution_threads) {
132   VLOG(2) << "HLO module before WhileLoopConstantSinking:";
133   XLA_VLOG_LINES(2, module->ToString());
134 
135   bool changed = false;
136   std::vector<HloInstruction*> while_instrs;
137   for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
138     // Right now we don't particularly care about optimizing while-of-while
139     // patterns.  If/When we do, we'll want to visit the outer while (while_0)
140     // before we visit the inner while (while_1):
141     //
142     // while_1_body(state) {
143     //   val = gte(state, 0) // Loop invariant
144     //   use(val)
145     // }
146     //
147     // while_0_body(state) {
148     //   val = gte(state, 0) // Loop invariant
149     //   while_1 = while(init=tuple(val, ...), body=while_1_body, ...)
150     //   ...
151     // }
152     //
153     // main {
154     //   while_0 = while(init=(constant, ...), body=while_0_body, ...)
155     // }
156     //
157     // This will let us sink the constant into the outer while first and then
158     // into the inner while in a single run of this pass.
159     absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
160                     [](const HloInstruction* instr) {
161                       return instr->opcode() == HloOpcode::kWhile;
162                     });
163   }
164 
165   for (HloInstruction* while_instr : while_instrs) {
166     TF_ASSIGN_OR_RETURN(bool result,
167                         TrySinkingConstantsIntoWhileLoop(while_instr));
168     changed |= result;
169   }
170 
171   if (changed) {
172     VLOG(2) << "HLO module after WhileLoopConstantSinking:";
173     XLA_VLOG_LINES(2, module->ToString());
174   } else {
175     VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking";
176   }
177 
178   return changed;
179 }
180 }  // namespace xla
181