1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/while_util.h"
17
18 #include <memory>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/inlined_vector.h"
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/tuple_util.h"
29
30 namespace xla {
31
32 using absl::StrCat;
33
WidenWhileCondition(HloComputation * narrow_condition,const Shape & wide_shape)34 static StatusOr<HloComputation*> WidenWhileCondition(
35 HloComputation* narrow_condition, const Shape& wide_shape) {
36 const Shape& narrow_shape =
37 narrow_condition->parameter_instruction(0)->shape();
38
39 HloComputation* wide_while_cond = [&]() {
40 HloComputation::Builder builder(StrCat("wide.", narrow_condition->name()));
41 builder.AddInstruction(
42 HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
43
44 // This is needed so that the root instruction is shaped as a PRED[] -- we
45 // need to get this right to begin with since we can't mutate the type of
46 // the root instruction later. We later change the root instruction to
47 // something more appropriate.
48 builder.AddInstruction(
49 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
50 return narrow_condition->parent()->AddEmbeddedComputation(builder.Build());
51 }();
52
53 HloInstruction* truncated_parameter =
54 TupleUtil::ExtractPrefix(wide_while_cond->parameter_instruction(0),
55 narrow_shape.tuple_shapes_size());
56 HloInstruction* call_narrow_cond = wide_while_cond->AddInstruction(
57 HloInstruction::CreateCall(ShapeUtil::MakeShape(PRED, {}),
58 {truncated_parameter}, narrow_condition));
59
60 wide_while_cond->set_root_instruction(call_narrow_cond);
61
62 TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_cond).status());
63 return wide_while_cond;
64 }
65
66 static StatusOr<std::pair<HloComputation*, CallInliner::InlinedInstructionMap>>
WidenWhileBody(HloComputation * narrow_body,const Shape & wide_shape)67 WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) {
68 const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape();
69
70 HloComputation* wide_while_body = [&]() {
71 HloComputation::Builder builder(StrCat("wide.", narrow_body->name()));
72 builder.AddInstruction(
73 HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
74 return narrow_body->parent()->AddEmbeddedComputation(builder.Build());
75 }();
76
77 HloInstruction* wide_parameter = wide_while_body->parameter_instruction(0);
78 HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
79 wide_parameter, narrow_shape.tuple_shapes_size());
80 HloInstruction* call_narrow_body =
81 wide_while_body->AddInstruction(HloInstruction::CreateCall(
82 narrow_shape, {truncated_parameter}, narrow_body));
83
84 std::vector<HloInstruction*> live_through_values;
85 for (int i = narrow_shape.tuple_shapes_size();
86 i < wide_shape.tuple_shapes_size(); i++) {
87 live_through_values.push_back(
88 wide_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
89 wide_shape.tuple_shapes(i), wide_parameter, i)));
90 }
91
92 wide_while_body->set_root_instruction(
93 TupleUtil::AppendSuffix(call_narrow_body, live_through_values));
94
95 TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
96 CallInliner::Inline(call_narrow_body));
97 return {{wide_while_body, std::move(inlined_instructions_map)}};
98 }
99
100 /*static*/ StatusOr<WhileUtil::MakeInstructionsLiveInResult>
MakeInstructionsLiveIn(HloInstruction * while_instr,absl::Span<HloInstruction * const> instructions)101 WhileUtil::MakeInstructionsLiveIn(
102 HloInstruction* while_instr,
103 absl::Span<HloInstruction* const> instructions) {
104 CHECK(while_instr->shape().IsTuple());
105
106 int elements_in_old_while_shape = while_instr->shape().tuple_shapes_size();
107 Shape new_while_shape = while_instr->shape();
108 for (auto* instruction : instructions) {
109 *new_while_shape.add_tuple_shapes() = instruction->shape();
110 }
111
112 TF_ASSIGN_OR_RETURN(
113 HloComputation * new_while_condition,
114 WidenWhileCondition(while_instr->while_condition(), new_while_shape));
115
116 HloComputation* new_while_body;
117 CallInliner::InlinedInstructionMap inlined_instructions_map;
118 TF_ASSIGN_OR_RETURN(
119 std::tie(new_while_body, inlined_instructions_map),
120 WidenWhileBody(while_instr->while_body(), new_while_shape));
121
122 HloInstruction* new_while_init =
123 TupleUtil::AppendSuffix(while_instr->mutable_operand(0), instructions);
124 HloComputation* containing_computation = while_instr->parent();
125 HloInstruction* new_while = containing_computation->AddInstruction(
126 HloInstruction::CreateWhile(new_while_shape, new_while_condition,
127 new_while_body, new_while_init));
128
129 // We want to get rid of the old while instruction even if it has side
130 // effecting operations so we do a manual HloComputation::RemoveInstruction
131 // instead of relying on HloComputation::ReplaceInstruction.
132 HloInstruction* replacement_instr = TupleUtil::ExtractPrefix(
133 new_while, while_instr->shape().tuple_shapes_size());
134 TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(replacement_instr));
135 TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr));
136
137 HloInstruction* while_body_param = new_while_body->parameter_instruction(0);
138 std::vector<HloInstruction*> live_in_instructions;
139 for (int64_t i = elements_in_old_while_shape;
140 i < new_while_shape.tuple_shapes_size(); i++) {
141 live_in_instructions.push_back(
142 new_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
143 instructions[i - elements_in_old_while_shape]->shape(),
144 while_body_param, i)));
145 }
146
147 WhileUtil::MakeInstructionsLiveInResult result;
148
149 result.new_while_instr = new_while;
150 result.replacement_instr = replacement_instr;
151 result.while_body_live_in_values = std::move(live_in_instructions);
152 result.while_body_instruction_map = std::move(inlined_instructions_map);
153
154 return std::move(result);
155 }
156
157 static StatusOr<std::unique_ptr<HloComputation>>
MakeCountedLoopConditionComputation(const Shape & loop_state_shape,int32_t trip_count)158 MakeCountedLoopConditionComputation(const Shape& loop_state_shape,
159 int32_t trip_count) {
160 Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
161
162 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> cond_computation,
163 CreateComputationWithSignature(
164 {&loop_state_shape}, scalar_pred, "while_cond"));
165
166 HloInstruction* trip_count_constant =
167 cond_computation->AddInstruction(HloInstruction::CreateConstant(
168 LiteralUtil::CreateR0<int32_t>(trip_count)));
169
170 HloInstruction* param = cond_computation->parameter_instruction(0);
171 TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
172 MakeGetTupleElementHlo(param, 0));
173
174 TF_ASSIGN_OR_RETURN(
175 HloInstruction * compare,
176 MakeCompareHlo(ComparisonDirection::kLt, indvar, trip_count_constant));
177 cond_computation->set_root_instruction(compare);
178 return std::move(cond_computation);
179 }
180
MakeCountedLoopBodyComputation(const Shape & loop_state_shape,const std::function<StatusOr<WhileUtil::LoopStateTy> (HloInstruction *,const WhileUtil::LoopStateTy &)> & loop_body_generator)181 static StatusOr<std::unique_ptr<HloComputation>> MakeCountedLoopBodyComputation(
182 const Shape& loop_state_shape,
183 const std::function<StatusOr<WhileUtil::LoopStateTy>(
184 HloInstruction*, const WhileUtil::LoopStateTy&)>& loop_body_generator) {
185 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> body_computation,
186 CreateComputationWithSignature(
187 {&loop_state_shape}, loop_state_shape, "while_body"));
188 HloInstruction* one = body_computation->AddInstruction(
189 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
190 HloInstruction* param = body_computation->parameter_instruction(0);
191 TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
192 MakeGetTupleElementHlo(param, 0));
193 TF_ASSIGN_OR_RETURN(HloInstruction * next_indvar,
194 MakeBinaryHlo(HloOpcode::kAdd, indvar, one));
195
196 std::vector<HloInstruction*> loop_body_generator_args;
197 for (int i = 1, e = loop_state_shape.tuple_shapes_size(); i < e; i++) {
198 TF_ASSIGN_OR_RETURN(HloInstruction * tuple_element,
199 MakeGetTupleElementHlo(param, i));
200 loop_body_generator_args.push_back(tuple_element);
201 }
202 TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> next_state,
203 loop_body_generator(indvar, loop_body_generator_args));
204 next_state.insert(next_state.begin(), next_indvar);
205 HloInstruction* next_state_tuple =
206 body_computation->AddInstruction(HloInstruction::CreateTuple(next_state));
207 body_computation->set_root_instruction(next_state_tuple);
208
209 return std::move(body_computation);
210 }
211
212 static std::pair<std::unique_ptr<HloInstruction>,
213 std::unique_ptr<HloInstruction>>
MakeInitTupleFromInitValues(const WhileUtil::LoopStateTy & init_values)214 MakeInitTupleFromInitValues(const WhileUtil::LoopStateTy& init_values) {
215 std::vector<HloInstruction*> init_values_with_indvar;
216 init_values_with_indvar.reserve(init_values.size() + 1);
217 std::unique_ptr<HloInstruction> zero =
218 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0));
219 init_values_with_indvar.push_back(zero.get());
220 absl::c_copy(init_values, std::back_inserter(init_values_with_indvar));
221 return std::make_pair(std::move(zero),
222 HloInstruction::CreateTuple(init_values_with_indvar));
223 }
224
225 // Returns a tuple shape containing a S32, and a shape from each value in
226 // `init_values`. If a shape from a value in `init_values` doesn't have a
227 // layout, use a default layout for the shape.
MakeLoopStateShapeWithLayout(const WhileUtil::LoopStateTy & init_values)228 static Shape MakeLoopStateShapeWithLayout(
229 const WhileUtil::LoopStateTy& init_values) {
230 std::vector<Shape> loop_state_shape_components;
231 loop_state_shape_components.reserve(init_values.size() + 1);
232 loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {}));
233 absl::c_transform(init_values,
234 std::back_inserter(loop_state_shape_components),
235 [](HloInstruction* instr) {
236 Shape shape = instr->shape();
237 if (!shape.has_layout()) {
238 LayoutUtil::SetToDefaultLayout(&shape);
239 }
240 return shape;
241 });
242 return ShapeUtil::MakeTupleShape(loop_state_shape_components);
243 }
244
MakeCountedLoop(HloModule * module,int32_t trip_count,const WhileUtil::LoopStateTy & init_values,const WhileUtil::LoopBodyGeneratorTy & loop_body_generator,const OpMetadata & metadata)245 /*static*/ StatusOr<WhileUtil::OwningLoopStateTy> WhileUtil::MakeCountedLoop(
246 HloModule* module, int32_t trip_count,
247 const WhileUtil::LoopStateTy& init_values,
248 const WhileUtil::LoopBodyGeneratorTy& loop_body_generator,
249 const OpMetadata& metadata) {
250 CHECK_GE(trip_count, 0);
251
252 // Both MakeCountedLoopConditionComputation and MakeCountedLoopBodyComputation
253 // use loop_state_shape to create a literal, which requires loop_state_shape
254 // to have a layout.
255 Shape loop_state_shape = MakeLoopStateShapeWithLayout(init_values);
256 TF_ASSIGN_OR_RETURN(
257 std::unique_ptr<HloComputation> cond,
258 MakeCountedLoopConditionComputation(loop_state_shape, trip_count));
259 TF_ASSIGN_OR_RETURN(
260 std::unique_ptr<HloComputation> body,
261 MakeCountedLoopBodyComputation(loop_state_shape, loop_body_generator));
262 std::unique_ptr<HloInstruction> owned_indvar;
263 std::unique_ptr<HloInstruction> owned_init_tuple;
264 std::tie(owned_indvar, owned_init_tuple) =
265 MakeInitTupleFromInitValues(init_values);
266 std::unique_ptr<HloInstruction> owned_while = HloInstruction::CreateWhile(
267 loop_state_shape, module->AddEmbeddedComputation(std::move(cond)),
268 module->AddEmbeddedComputation(std::move(body)), owned_init_tuple.get());
269 owned_while->set_metadata(metadata);
270 HloInstruction* while_instr = owned_while.get();
271
272 std::vector<std::unique_ptr<HloInstruction>> owned;
273 owned.push_back(std::move(owned_indvar));
274 owned.push_back(std::move(owned_init_tuple));
275 owned.push_back(std::move(owned_while));
276 std::vector<HloInstruction*> while_results;
277 for (int64_t i = 0, e = init_values.size(); i < e; i++) {
278 std::unique_ptr<HloInstruction> user_state =
279 HloInstruction::CreateGetTupleElement(init_values[i]->shape(),
280 while_instr, i + 1);
281 while_results.push_back(user_state.get());
282 owned.push_back(std::move(user_state));
283 }
284 return WhileUtil::OwningLoopStateTy{std::move(owned), while_results};
285 }
286
MakeCountedLoop(HloComputation * computation,int32_t trip_count,const WhileUtil::LoopStateTy & init_values,const WhileUtil::LoopBodyGeneratorTy & loop_body_generator,const OpMetadata & metadata)287 /*static*/ StatusOr<WhileUtil::LoopStateTy> WhileUtil::MakeCountedLoop(
288 HloComputation* computation, int32_t trip_count,
289 const WhileUtil::LoopStateTy& init_values,
290 const WhileUtil::LoopBodyGeneratorTy& loop_body_generator,
291 const OpMetadata& metadata) {
292 TF_ASSIGN_OR_RETURN(
293 auto owning_loop_state,
294 MakeCountedLoop(computation->parent(), trip_count, init_values,
295 loop_body_generator, metadata));
296 for (auto& instruction_to_add : owning_loop_state.instructions_to_add) {
297 computation->AddInstruction(std::move(instruction_to_add));
298 }
299 return owning_loop_state.while_results;
300 }
301
GetInvariantGTEsForWhileBody(const HloComputation & while_body)302 /*static*/ std::vector<HloInstruction*> WhileUtil::GetInvariantGTEsForWhileBody(
303 const HloComputation& while_body) {
304 std::vector<HloInstruction*> result;
305 const HloInstruction::InstructionVector root_operands =
306 while_body.root_instruction()->operands();
307 for (int i = 0; i < root_operands.size(); i++) {
308 HloInstruction* instr = root_operands[i];
309 if (instr->opcode() == HloOpcode::kGetTupleElement &&
310 instr->tuple_index() == i &&
311 instr->operand(0) == while_body.parameter_instruction(0)) {
312 result.push_back(instr);
313 }
314 }
315 return result;
316 }
317
318 /*static*/ absl::flat_hash_map<int64_t, absl::InlinedVector<HloInstruction*, 1>>
GetGTEsMapForWhileConditional(const HloComputation & while_conditional)319 WhileUtil::GetGTEsMapForWhileConditional(
320 const HloComputation& while_conditional) {
321 absl::flat_hash_map<int64_t, absl::InlinedVector<HloInstruction*, 1>> result;
322 for (HloInstruction* user :
323 while_conditional.parameter_instruction(0)->users()) {
324 if (user->opcode() == HloOpcode::kGetTupleElement) {
325 result[user->tuple_index()].push_back(user);
326 }
327 }
328 return result;
329 }
330
331 } // namespace xla
332