xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/while_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/while_util.h"
17 
18 #include <memory>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/inlined_vector.h"
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/tuple_util.h"
29 
30 namespace xla {
31 
32 using absl::StrCat;
33 
WidenWhileCondition(HloComputation * narrow_condition,const Shape & wide_shape)34 static StatusOr<HloComputation*> WidenWhileCondition(
35     HloComputation* narrow_condition, const Shape& wide_shape) {
36   const Shape& narrow_shape =
37       narrow_condition->parameter_instruction(0)->shape();
38 
39   HloComputation* wide_while_cond = [&]() {
40     HloComputation::Builder builder(StrCat("wide.", narrow_condition->name()));
41     builder.AddInstruction(
42         HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
43 
44     // This is needed so that the root instruction is shaped as a PRED[] -- we
45     // need to get this right to begin with since we can't mutate the type of
46     // the root instruction later.  We later change the root instruction to
47     // something more appropriate.
48     builder.AddInstruction(
49         HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
50     return narrow_condition->parent()->AddEmbeddedComputation(builder.Build());
51   }();
52 
53   HloInstruction* truncated_parameter =
54       TupleUtil::ExtractPrefix(wide_while_cond->parameter_instruction(0),
55                                narrow_shape.tuple_shapes_size());
56   HloInstruction* call_narrow_cond = wide_while_cond->AddInstruction(
57       HloInstruction::CreateCall(ShapeUtil::MakeShape(PRED, {}),
58                                  {truncated_parameter}, narrow_condition));
59 
60   wide_while_cond->set_root_instruction(call_narrow_cond);
61 
62   TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_cond).status());
63   return wide_while_cond;
64 }
65 
66 static StatusOr<std::pair<HloComputation*, CallInliner::InlinedInstructionMap>>
WidenWhileBody(HloComputation * narrow_body,const Shape & wide_shape)67 WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) {
68   const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape();
69 
70   HloComputation* wide_while_body = [&]() {
71     HloComputation::Builder builder(StrCat("wide.", narrow_body->name()));
72     builder.AddInstruction(
73         HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
74     return narrow_body->parent()->AddEmbeddedComputation(builder.Build());
75   }();
76 
77   HloInstruction* wide_parameter = wide_while_body->parameter_instruction(0);
78   HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
79       wide_parameter, narrow_shape.tuple_shapes_size());
80   HloInstruction* call_narrow_body =
81       wide_while_body->AddInstruction(HloInstruction::CreateCall(
82           narrow_shape, {truncated_parameter}, narrow_body));
83 
84   std::vector<HloInstruction*> live_through_values;
85   for (int i = narrow_shape.tuple_shapes_size();
86        i < wide_shape.tuple_shapes_size(); i++) {
87     live_through_values.push_back(
88         wide_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
89             wide_shape.tuple_shapes(i), wide_parameter, i)));
90   }
91 
92   wide_while_body->set_root_instruction(
93       TupleUtil::AppendSuffix(call_narrow_body, live_through_values));
94 
95   TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
96                       CallInliner::Inline(call_narrow_body));
97   return {{wide_while_body, std::move(inlined_instructions_map)}};
98 }
99 
100 /*static*/ StatusOr<WhileUtil::MakeInstructionsLiveInResult>
MakeInstructionsLiveIn(HloInstruction * while_instr,absl::Span<HloInstruction * const> instructions)101 WhileUtil::MakeInstructionsLiveIn(
102     HloInstruction* while_instr,
103     absl::Span<HloInstruction* const> instructions) {
104   CHECK(while_instr->shape().IsTuple());
105 
106   int elements_in_old_while_shape = while_instr->shape().tuple_shapes_size();
107   Shape new_while_shape = while_instr->shape();
108   for (auto* instruction : instructions) {
109     *new_while_shape.add_tuple_shapes() = instruction->shape();
110   }
111 
112   TF_ASSIGN_OR_RETURN(
113       HloComputation * new_while_condition,
114       WidenWhileCondition(while_instr->while_condition(), new_while_shape));
115 
116   HloComputation* new_while_body;
117   CallInliner::InlinedInstructionMap inlined_instructions_map;
118   TF_ASSIGN_OR_RETURN(
119       std::tie(new_while_body, inlined_instructions_map),
120       WidenWhileBody(while_instr->while_body(), new_while_shape));
121 
122   HloInstruction* new_while_init =
123       TupleUtil::AppendSuffix(while_instr->mutable_operand(0), instructions);
124   HloComputation* containing_computation = while_instr->parent();
125   HloInstruction* new_while = containing_computation->AddInstruction(
126       HloInstruction::CreateWhile(new_while_shape, new_while_condition,
127                                   new_while_body, new_while_init));
128 
129   // We want to get rid of the old while instruction even if it has side
130   // effecting operations so we do a manual HloComputation::RemoveInstruction
131   // instead of relying on HloComputation::ReplaceInstruction.
132   HloInstruction* replacement_instr = TupleUtil::ExtractPrefix(
133       new_while, while_instr->shape().tuple_shapes_size());
134   TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(replacement_instr));
135   TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr));
136 
137   HloInstruction* while_body_param = new_while_body->parameter_instruction(0);
138   std::vector<HloInstruction*> live_in_instructions;
139   for (int64_t i = elements_in_old_while_shape;
140        i < new_while_shape.tuple_shapes_size(); i++) {
141     live_in_instructions.push_back(
142         new_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
143             instructions[i - elements_in_old_while_shape]->shape(),
144             while_body_param, i)));
145   }
146 
147   WhileUtil::MakeInstructionsLiveInResult result;
148 
149   result.new_while_instr = new_while;
150   result.replacement_instr = replacement_instr;
151   result.while_body_live_in_values = std::move(live_in_instructions);
152   result.while_body_instruction_map = std::move(inlined_instructions_map);
153 
154   return std::move(result);
155 }
156 
157 static StatusOr<std::unique_ptr<HloComputation>>
MakeCountedLoopConditionComputation(const Shape & loop_state_shape,int32_t trip_count)158 MakeCountedLoopConditionComputation(const Shape& loop_state_shape,
159                                     int32_t trip_count) {
160   Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
161 
162   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> cond_computation,
163                       CreateComputationWithSignature(
164                           {&loop_state_shape}, scalar_pred, "while_cond"));
165 
166   HloInstruction* trip_count_constant =
167       cond_computation->AddInstruction(HloInstruction::CreateConstant(
168           LiteralUtil::CreateR0<int32_t>(trip_count)));
169 
170   HloInstruction* param = cond_computation->parameter_instruction(0);
171   TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
172                       MakeGetTupleElementHlo(param, 0));
173 
174   TF_ASSIGN_OR_RETURN(
175       HloInstruction * compare,
176       MakeCompareHlo(ComparisonDirection::kLt, indvar, trip_count_constant));
177   cond_computation->set_root_instruction(compare);
178   return std::move(cond_computation);
179 }
180 
MakeCountedLoopBodyComputation(const Shape & loop_state_shape,const std::function<StatusOr<WhileUtil::LoopStateTy> (HloInstruction *,const WhileUtil::LoopStateTy &)> & loop_body_generator)181 static StatusOr<std::unique_ptr<HloComputation>> MakeCountedLoopBodyComputation(
182     const Shape& loop_state_shape,
183     const std::function<StatusOr<WhileUtil::LoopStateTy>(
184         HloInstruction*, const WhileUtil::LoopStateTy&)>& loop_body_generator) {
185   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> body_computation,
186                       CreateComputationWithSignature(
187                           {&loop_state_shape}, loop_state_shape, "while_body"));
188   HloInstruction* one = body_computation->AddInstruction(
189       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
190   HloInstruction* param = body_computation->parameter_instruction(0);
191   TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
192                       MakeGetTupleElementHlo(param, 0));
193   TF_ASSIGN_OR_RETURN(HloInstruction * next_indvar,
194                       MakeBinaryHlo(HloOpcode::kAdd, indvar, one));
195 
196   std::vector<HloInstruction*> loop_body_generator_args;
197   for (int i = 1, e = loop_state_shape.tuple_shapes_size(); i < e; i++) {
198     TF_ASSIGN_OR_RETURN(HloInstruction * tuple_element,
199                         MakeGetTupleElementHlo(param, i));
200     loop_body_generator_args.push_back(tuple_element);
201   }
202   TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> next_state,
203                       loop_body_generator(indvar, loop_body_generator_args));
204   next_state.insert(next_state.begin(), next_indvar);
205   HloInstruction* next_state_tuple =
206       body_computation->AddInstruction(HloInstruction::CreateTuple(next_state));
207   body_computation->set_root_instruction(next_state_tuple);
208 
209   return std::move(body_computation);
210 }
211 
212 static std::pair<std::unique_ptr<HloInstruction>,
213                  std::unique_ptr<HloInstruction>>
MakeInitTupleFromInitValues(const WhileUtil::LoopStateTy & init_values)214 MakeInitTupleFromInitValues(const WhileUtil::LoopStateTy& init_values) {
215   std::vector<HloInstruction*> init_values_with_indvar;
216   init_values_with_indvar.reserve(init_values.size() + 1);
217   std::unique_ptr<HloInstruction> zero =
218       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0));
219   init_values_with_indvar.push_back(zero.get());
220   absl::c_copy(init_values, std::back_inserter(init_values_with_indvar));
221   return std::make_pair(std::move(zero),
222                         HloInstruction::CreateTuple(init_values_with_indvar));
223 }
224 
225 // Returns a tuple shape containing a S32, and a shape from each value in
226 // `init_values`. If a shape from a value in `init_values` doesn't have a
227 // layout, use a default layout for the shape.
MakeLoopStateShapeWithLayout(const WhileUtil::LoopStateTy & init_values)228 static Shape MakeLoopStateShapeWithLayout(
229     const WhileUtil::LoopStateTy& init_values) {
230   std::vector<Shape> loop_state_shape_components;
231   loop_state_shape_components.reserve(init_values.size() + 1);
232   loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {}));
233   absl::c_transform(init_values,
234                     std::back_inserter(loop_state_shape_components),
235                     [](HloInstruction* instr) {
236                       Shape shape = instr->shape();
237                       if (!shape.has_layout()) {
238                         LayoutUtil::SetToDefaultLayout(&shape);
239                       }
240                       return shape;
241                     });
242   return ShapeUtil::MakeTupleShape(loop_state_shape_components);
243 }
244 
MakeCountedLoop(HloModule * module,int32_t trip_count,const WhileUtil::LoopStateTy & init_values,const WhileUtil::LoopBodyGeneratorTy & loop_body_generator,const OpMetadata & metadata)245 /*static*/ StatusOr<WhileUtil::OwningLoopStateTy> WhileUtil::MakeCountedLoop(
246     HloModule* module, int32_t trip_count,
247     const WhileUtil::LoopStateTy& init_values,
248     const WhileUtil::LoopBodyGeneratorTy& loop_body_generator,
249     const OpMetadata& metadata) {
250   CHECK_GE(trip_count, 0);
251 
252   // Both MakeCountedLoopConditionComputation and MakeCountedLoopBodyComputation
253   // use loop_state_shape to create a literal, which requires loop_state_shape
254   // to have a layout.
255   Shape loop_state_shape = MakeLoopStateShapeWithLayout(init_values);
256   TF_ASSIGN_OR_RETURN(
257       std::unique_ptr<HloComputation> cond,
258       MakeCountedLoopConditionComputation(loop_state_shape, trip_count));
259   TF_ASSIGN_OR_RETURN(
260       std::unique_ptr<HloComputation> body,
261       MakeCountedLoopBodyComputation(loop_state_shape, loop_body_generator));
262   std::unique_ptr<HloInstruction> owned_indvar;
263   std::unique_ptr<HloInstruction> owned_init_tuple;
264   std::tie(owned_indvar, owned_init_tuple) =
265       MakeInitTupleFromInitValues(init_values);
266   std::unique_ptr<HloInstruction> owned_while = HloInstruction::CreateWhile(
267       loop_state_shape, module->AddEmbeddedComputation(std::move(cond)),
268       module->AddEmbeddedComputation(std::move(body)), owned_init_tuple.get());
269   owned_while->set_metadata(metadata);
270   HloInstruction* while_instr = owned_while.get();
271 
272   std::vector<std::unique_ptr<HloInstruction>> owned;
273   owned.push_back(std::move(owned_indvar));
274   owned.push_back(std::move(owned_init_tuple));
275   owned.push_back(std::move(owned_while));
276   std::vector<HloInstruction*> while_results;
277   for (int64_t i = 0, e = init_values.size(); i < e; i++) {
278     std::unique_ptr<HloInstruction> user_state =
279         HloInstruction::CreateGetTupleElement(init_values[i]->shape(),
280                                               while_instr, i + 1);
281     while_results.push_back(user_state.get());
282     owned.push_back(std::move(user_state));
283   }
284   return WhileUtil::OwningLoopStateTy{std::move(owned), while_results};
285 }
286 
MakeCountedLoop(HloComputation * computation,int32_t trip_count,const WhileUtil::LoopStateTy & init_values,const WhileUtil::LoopBodyGeneratorTy & loop_body_generator,const OpMetadata & metadata)287 /*static*/ StatusOr<WhileUtil::LoopStateTy> WhileUtil::MakeCountedLoop(
288     HloComputation* computation, int32_t trip_count,
289     const WhileUtil::LoopStateTy& init_values,
290     const WhileUtil::LoopBodyGeneratorTy& loop_body_generator,
291     const OpMetadata& metadata) {
292   TF_ASSIGN_OR_RETURN(
293       auto owning_loop_state,
294       MakeCountedLoop(computation->parent(), trip_count, init_values,
295                       loop_body_generator, metadata));
296   for (auto& instruction_to_add : owning_loop_state.instructions_to_add) {
297     computation->AddInstruction(std::move(instruction_to_add));
298   }
299   return owning_loop_state.while_results;
300 }
301 
GetInvariantGTEsForWhileBody(const HloComputation & while_body)302 /*static*/ std::vector<HloInstruction*> WhileUtil::GetInvariantGTEsForWhileBody(
303     const HloComputation& while_body) {
304   std::vector<HloInstruction*> result;
305   const HloInstruction::InstructionVector root_operands =
306       while_body.root_instruction()->operands();
307   for (int i = 0; i < root_operands.size(); i++) {
308     HloInstruction* instr = root_operands[i];
309     if (instr->opcode() == HloOpcode::kGetTupleElement &&
310         instr->tuple_index() == i &&
311         instr->operand(0) == while_body.parameter_instruction(0)) {
312       result.push_back(instr);
313     }
314   }
315   return result;
316 }
317 
318 /*static*/ absl::flat_hash_map<int64_t, absl::InlinedVector<HloInstruction*, 1>>
GetGTEsMapForWhileConditional(const HloComputation & while_conditional)319 WhileUtil::GetGTEsMapForWhileConditional(
320     const HloComputation& while_conditional) {
321   absl::flat_hash_map<int64_t, absl::InlinedVector<HloInstruction*, 1>> result;
322   for (HloInstruction* user :
323        while_conditional.parameter_instruction(0)->users()) {
324     if (user->opcode() == HloOpcode::kGetTupleElement) {
325       result[user->tuple_index()].push_back(user);
326     }
327   }
328   return result;
329 }
330 
331 }  // namespace xla
332