1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/copy_insertion.h"
17
18 #include <set>
19
20 #include "tensorflow/compiler/xla/debug_options_flags.h"
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/hlo_runner.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/platform/test_benchmark.h"
34
35 namespace op = xla::testing::opcode_matchers;
36
37 namespace xla {
38 namespace {
39
40 using ::testing::UnorderedElementsAre;
41
CountCopies(const HloComputation & computation)42 int64_t CountCopies(const HloComputation& computation) {
43 int64_t count = 0;
44 for (const auto& instruction : computation.instructions()) {
45 if (instruction->opcode() == HloOpcode::kCopy) {
46 count++;
47 }
48 }
49 return count;
50 }
51
CountCopies(const HloModule & module)52 int64_t CountCopies(const HloModule& module) {
53 int64_t count = 0;
54 for (const auto& computation : module.computations()) {
55 count += CountCopies(*computation);
56 }
57 return count;
58 }
59
CountControlEdges(const HloComputation & computation)60 int64_t CountControlEdges(const HloComputation& computation) {
61 int64_t count = 0;
62 for (const auto& instruction : computation.instructions()) {
63 count += instruction->control_successors().size();
64 }
65 return count;
66 }
67
CountControlEdges(const HloModule & module)68 int64_t CountControlEdges(const HloModule& module) {
69 int64_t count = 0;
70 for (const auto& computation : module.computations()) {
71 count += CountControlEdges(*computation);
72 }
73 return count;
74 }
75
76 class CopyInsertionTest : public HloTestBase {
77 protected:
InsertCopies(HloModule * module)78 void InsertCopies(HloModule* module) {
79 CopyInsertion copy_insertion;
80 VLOG(3) << "Before copy inser: " << module->ToString();
81 ASSERT_IS_OK(copy_insertion.Run(module).status());
82 VLOG(2) << "After copy inser: " << module->ToString();
83 }
84
85 const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
86 };
87
TEST_F(CopyInsertionTest, SingleParameter) {
  // The entry computation wraps a single scalar parameter in a tuple. After
  // copy insertion the tuple operand is expected to be a copy of the
  // parameter.
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "x"));
  HloInstruction* root_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({param}));

  // Before the pass, the tuple is the parameter's only user.
  EXPECT_THAT(param->users(), UnorderedElementsAre(root_tuple));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Copy(param)));
}
107
TEST_F(CopyInsertionTest, SingleConstant) {
  // The entry computation wraps a single scalar constant in a tuple. After
  // copy insertion the tuple operand is expected to be a copy of the
  // constant, and that copy is the only one in the module.
  HloComputation::Builder builder(TestName());
  HloInstruction* scalar_const = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* root_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({scalar_const}));

  // Before the pass, the tuple is the constant's only user.
  EXPECT_THAT(scalar_const->users(), UnorderedElementsAre(root_tuple));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Copy(scalar_const)));
}
128
TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) {
  // Layout-changing kCopy instructions that are present before copy insertion
  // must still be present afterwards. The graph starts with three copies: two
  // that change the layout of a constant, and one same-shape copy of their sum
  // at the root. Only the two layout-changing copies are expected to remain.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  HloInstruction* constant =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}})));

  // Same dimensions as the constant, but with the dimension order of the
  // layout reversed (the minor-to-major list is reinterpreted as
  // major-to-minor).
  Shape reversed_shape = constant->shape();
  *reversed_shape.mutable_layout() = LayoutUtil::MakeLayoutFromMajorToMinor(
      LayoutUtil::MinorToMajor(constant->shape()));

  HloInstruction* lhs_copy = builder.AddInstruction(
      HloInstruction::CreateUnary(reversed_shape, HloOpcode::kCopy, constant));
  HloInstruction* rhs_copy = builder.AddInstruction(
      HloInstruction::CreateUnary(reversed_shape, HloOpcode::kCopy, constant));
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      constant->shape(), HloOpcode::kAdd, lhs_copy, rhs_copy));
  // Root of the computation: a copy of the sum that keeps the same shape.
  builder.AddInstruction(
      HloInstruction::CreateUnary(add->shape(), HloOpcode::kCopy, add));

  module->AddEntryComputation(builder.Build());
  EXPECT_EQ(CountCopies(*module), 3);

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 2);

  // The add itself is now the root, still fed by the two original copies.
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root, add);
  EXPECT_THAT(root,
              op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())));
}
164
TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) {
  // The computation has two constants and two parameters, but the output tuple
  // refers directly to only one constant and one parameter; the remaining pair
  // reaches the tuple through an add. Only the directly referenced
  // instructions are expected to be copied.
  HloComputation::Builder builder(TestName());

  HloInstruction* added_const = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* output_const = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));

  HloInstruction* output_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "x"));
  HloInstruction* added_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "y"));

  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape_, HloOpcode::kAdd, added_const, added_param));

  builder.AddInstruction(
      HloInstruction::CreateTuple({output_const, output_param, sum}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 2);

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Copy(output_const), op::Copy(output_param),
                              op::Add(added_const, added_param)));
}
196
TEST_F(CopyInsertionTest, BitcastParameter) {
  // A bitcast's output is the same buffer as its operand, so when a bitcast
  // of a parameter is the computation result, a copy has to be added.
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x"));
  HloInstruction* bitcast = builder.AddInstruction(
      HloInstruction::CreateBitcast(ShapeUtil::MakeShape(F32, {2, 2}), param));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_THAT(param->users(), UnorderedElementsAre(bitcast));

  // Remember the pre-pass root so the new root can be checked against it.
  HloComputation* entry = module->entry_computation();
  HloInstruction* root_before = entry->root_instruction();

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Copy(root_before));
}
218
TEST_F(CopyInsertionTest, BitcastConstant) {
  // A bitcast's output is the same buffer as its operand, so when a bitcast
  // of a constant is the computation result, a copy has to be added.
  HloComputation::Builder builder(TestName());
  HloInstruction* vec_const =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1.0, 42.0})));
  HloInstruction* bitcast = builder.AddInstruction(
      HloInstruction::CreateBitcast(ShapeUtil::MakeShape(F32, {2}), vec_const));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_THAT(vec_const->users(), UnorderedElementsAre(bitcast));

  // Remember the pre-pass root so the new root can be checked against it.
  HloComputation* entry = module->entry_computation();
  HloInstruction* root_before = entry->root_instruction();

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Copy(root_before));
}
241
TEST_F(CopyInsertionTest, BitcastTupleElementParameter) {
  // Variant of BitcastParameter in which the bitcast is the single element of
  // the root tuple; the copy is expected on the tuple operand.
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x"));
  HloInstruction* bitcast = builder.AddInstruction(
      HloInstruction::CreateBitcast(ShapeUtil::MakeShape(F32, {2, 2}), param));
  builder.AddInstruction(HloInstruction::CreateTuple({bitcast}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_THAT(param->users(), UnorderedElementsAre(bitcast));

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Copy(bitcast)));
}
262
TEST_F(CopyInsertionTest, NestedTupleParameter) {
  // The computation consists solely of a nested tuple-shaped parameter, which
  // is therefore also the root. The pass is expected to deep-copy the
  // parameter (one copy per leaf) and make the rebuilt tuple the new root.
  HloComputation::Builder builder(TestName());

  // Parameter shape: ((F32[], S32[1,2,3]), F32[42])
  const Shape inner_tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(S32, {1, 2, 3})});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {inner_tuple_shape, ShapeUtil::MakeShape(F32, {42})});
  builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param0"));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  HloInstruction* root_before = module->entry_computation()->root_instruction();
  EXPECT_EQ(HloOpcode::kParameter, root_before->opcode());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 3);

  HloInstruction* root_after = module->entry_computation()->root_instruction();
  EXPECT_NE(root_before, root_after);

  // Each leaf is reached through get-tuple-element(s) of the old root and then
  // copied; the copies are reassembled into the same nested tuple structure.
  EXPECT_THAT(
      root_after,
      op::Tuple(
          op::Tuple(
              op::Copy(op::GetTupleElement(op::GetTupleElement(root_before))),
              op::Copy(op::GetTupleElement(op::GetTupleElement(root_before)))),
          op::Copy(op::GetTupleElement(root_before))));
}
299
TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) {
  // The root of the computation is a tuple-shaped element extracted from a
  // nested tuple-shaped parameter.
  HloComputation::Builder builder(TestName());

  // Parameter shape: ((F32[], S32[1,2,3]), F32[42])
  const Shape inner_tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(S32, {1, 2, 3})});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {inner_tuple_shape, ShapeUtil::MakeShape(F32, {42})});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param0"));

  // The computation returns element 0 of the parameter, which is itself a
  // tuple.
  HloInstruction* inner_tuple =
      builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetSubshape(param->shape(), {0}), param, 0));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(inner_tuple, module->entry_computation()->root_instruction());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 2);

  // Both leaves of the inner tuple are copied and re-tupled as the new root.
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      op::Tuple(op::Copy(op::GetTupleElement(op::GetTupleElement(param))),
                op::Copy(op::GetTupleElement(op::GetTupleElement(param)))));
}
332
333 class WhileCopyInsertionTest : public CopyInsertionTest {
334 protected:
WhileCopyInsertionTest()335 WhileCopyInsertionTest() : module_(CreateNewVerifiedModule()) {}
336
337 // Builds a While condition computation which reads the induction variable
338 // from the tuple parameter, and returns a predicate indicating whether this
339 // value is less than the constant '10'.
340 // The parameter 'nested' specifies the loop state shape from which to
341 // read the induction variable.
BuildConditionComputation(const Shape & loop_state_shape)342 std::unique_ptr<HloComputation> BuildConditionComputation(
343 const Shape& loop_state_shape) {
344 auto builder = HloComputation::Builder(TestName() + ".Condition");
345 auto limit_const = builder.AddInstruction(
346 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(10)));
347 auto loop_state = builder.AddInstruction(
348 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
349 auto induction_variable =
350 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
351 limit_const->shape(), loop_state, 0));
352 builder.AddInstruction(HloInstruction::CreateCompare(
353 condition_result_shape_, induction_variable, limit_const,
354 ComparisonDirection::kLt));
355 return builder.Build();
356 }
357
358 // Builds a While body computation with one output tuple element dependent on
359 // both input tuple elements.
360 // EX:
361 // Body({in0, in1})
362 // out0 = Add(in0, 1)
363 // out1 = Add(BCast(in0), in1)
364 // Tuple(out0, out1)
BuildDependentBodyComputation()365 std::unique_ptr<HloComputation> BuildDependentBodyComputation() {
366 auto builder = HloComputation::Builder(TestName() + ".Body");
367 // Create param instruction to access loop state.
368 auto loop_state = builder.AddInstruction(
369 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
370 // Update the induction variable GTE(0).
371 auto induction_variable =
372 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
373 induction_variable_shape_, loop_state, 0));
374 auto inc = builder.AddInstruction(
375 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
376 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
377 induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
378 // Update data GTE(1).
379 auto data = builder.AddInstruction(
380 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
381 // Use 'induction_variable' in computation with no path to output tuple.
382 Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {});
383 auto convert = builder.AddInstruction(
384 HloInstruction::CreateConvert(f32_scalar_shape, induction_variable));
385 auto update = builder.AddInstruction(
386 HloInstruction::CreateBroadcast(data_shape_, convert, {}));
387 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
388 data_shape_, HloOpcode::kAdd, data, update));
389 // Create output Tuple.
390 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
391 return builder.Build();
392 }
393
394 // Builds a While body computation with two output tuple elements dependent on
395 // both input tuple elements.
396 //
397 // EX: Body({in0, in1, in2})
398 // out0 = Add(in0, 1)
399 // out1 = in1
400 // out2 = in2
401 // Tuple(out0, out1, out2)
BuildDependentBodyComputation2()402 std::unique_ptr<HloComputation> BuildDependentBodyComputation2() {
403 auto builder = HloComputation::Builder(TestName() + ".Body");
404
405 const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
406 {induction_variable_shape_, data_shape_, data_shape_});
407
408 auto loop_state = builder.AddInstruction(
409 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
410
411 // Update the induction variable GTE(0).
412 auto induction_variable =
413 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
414 induction_variable_shape_, loop_state, 0));
415 auto inc = builder.AddInstruction(
416 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
417
418 // add0 = Add(in0, 1)
419 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
420 induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
421 // data1 = GTE(1).
422 HloInstruction* data1 = builder.AddInstruction(
423 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
424
425 // data2 = GTE(2).
426 HloInstruction* data2 = builder.AddInstruction(
427 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2));
428
429 // Create output Tuple.
430 builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2}));
431
432 return builder.Build();
433 }
434
435 // Builds a While body computation with read-only tuple element 0.
436 // EX:
437 // Body({in0, in1})
438 // out0 = in0
439 // out1 = Add(BCast(in0), in1)
440 // Tuple(out0, out1)
BuildDependentBodyOneReadOnlyComputation()441 std::unique_ptr<HloComputation> BuildDependentBodyOneReadOnlyComputation() {
442 auto builder = HloComputation::Builder(TestName() + ".Body");
443 // Create param instruction to access loop state.
444 auto loop_state = builder.AddInstruction(
445 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
446 // Update the induction variable GTE(0).
447 auto induction_variable =
448 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
449 induction_variable_shape_, loop_state, 0));
450 // Update data GTE(1).
451 auto data = builder.AddInstruction(
452 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
453
454 // Use 'induction_variable' in computation with no path to output tuple.
455 Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {});
456 auto convert = builder.AddInstruction(
457 HloInstruction::CreateConvert(f32_scalar_shape, induction_variable));
458 auto update = builder.AddInstruction(
459 HloInstruction::CreateBroadcast(data_shape_, convert, {}));
460 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
461 data_shape_, HloOpcode::kAdd, data, update));
462 // Create output Tuple.
463 builder.AddInstruction(
464 HloInstruction::CreateTuple({induction_variable, add1}));
465 return builder.Build();
466 }
467
468 // Builds a While body computation with independent outputs.
469 // EX:
470 // Body({in0, in1})
471 // out0 = Add(in0, 1)
472 // out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
473 // Tuple(out0, out1)
BuildIndependentBodyComputation(bool nested=false)474 std::unique_ptr<HloComputation> BuildIndependentBodyComputation(
475 bool nested = false) {
476 auto builder = HloComputation::Builder(TestName() + ".Body");
477 // Create param instruction to access loop state.
478 const Shape& loop_state_shape =
479 nested ? nested_loop_state_shape_ : loop_state_shape_;
480
481 auto loop_state = builder.AddInstruction(
482 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
483 // Update the induction variable GTE(0).
484 auto induction_variable =
485 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
486 induction_variable_shape_, loop_state, 0));
487 auto inc = builder.AddInstruction(
488 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
489 // add0 = Add(in0, 1)
490 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
491 induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
492 // Update data GTE(1).
493 HloInstruction* data = nullptr;
494 if (nested) {
495 data = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
496 nested_tuple_shape_, loop_state, 1));
497 data = builder.AddInstruction(
498 HloInstruction::CreateGetTupleElement(data_shape_, data, 0));
499 } else {
500 data = builder.AddInstruction(
501 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
502 }
503 auto update = builder.AddInstruction(
504 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
505 {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
506 // add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
507 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
508 data_shape_, HloOpcode::kAdd, data, update));
509 // Create output Tuple.
510 if (nested) {
511 auto nested_tuple =
512 builder.AddInstruction(HloInstruction::CreateTuple({add1, add1}));
513 builder.AddInstruction(HloInstruction::CreateTuple({add0, nested_tuple}));
514 } else {
515 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
516 }
517 return builder.Build();
518 }
519
520 // Builds a While body computation with the following nested tuple
521 // sub-computation:
522 // |
523 // GTE(loop_state, 1)
524 // / \
525 // GTE(GTE(loop_state, 1), 0) GTE(GTE(loop_state, 1), 1)
526 // | |
527 // Add Reverse
528 // | |
BuildNestedBodyComputation()529 std::unique_ptr<HloComputation> BuildNestedBodyComputation() {
530 auto builder = HloComputation::Builder(TestName() + ".Body");
531 // Create param instruction to access loop state.
532 auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
533 0, nested_loop_state_shape_, "loop_state"));
534 // Update GTE(0).
535 auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
536 induction_variable_shape_, loop_state, 0));
537 auto inc = builder.AddInstruction(
538 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
539 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
540 gte0->shape(), HloOpcode::kAdd, gte0, inc));
541
542 // GTE(loop_state, 1)
543 auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
544 nested_tuple_shape_, loop_state, 1));
545 // GTE(GTE(loop_state, 1), 0) -> Add
546 auto gte10 = builder.AddInstruction(
547 HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0));
548 auto update10 = builder.AddInstruction(
549 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
550 {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
551 auto add10 = builder.AddInstruction(HloInstruction::CreateBinary(
552 data_shape_, HloOpcode::kAdd, gte10, update10));
553
554 // GTE(GTE(loop_state, 1), 1) -> Reverse
555 auto gte11 = builder.AddInstruction(
556 HloInstruction::CreateGetTupleElement(data_shape_, gte1, 1));
557 auto rev11 = builder.AddInstruction(
558 HloInstruction::CreateReverse(data_shape_, gte11, {0}));
559
560 // Create output Tuple.
561 auto inner_tuple =
562 builder.AddInstruction(HloInstruction::CreateTuple({add10, rev11}));
563 builder.AddInstruction(HloInstruction::CreateTuple({add0, inner_tuple}));
564 return builder.Build();
565 }
566
567 // Builds a While instruction using 'condition' and 'body' sub-computations.
568 // Init operand is initialized to zeros of appropriate shape.
BuildWhileInstruction(HloComputation * condition,HloComputation * body,bool nested=false)569 HloInstruction* BuildWhileInstruction(HloComputation* condition,
570 HloComputation* body,
571 bool nested = false) {
572 auto builder = HloComputation::Builder(TestName() + ".While");
573 auto induction_var_init = builder.AddInstruction(
574 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0)));
575
576 auto data_init = builder.AddInstruction(
577 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
578 {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
579
580 if (nested) {
581 auto inner_init = builder.AddInstruction(
582 HloInstruction::CreateTuple({data_init, data_init}));
583 auto loop_state_init = builder.AddInstruction(
584 HloInstruction::CreateTuple({induction_var_init, inner_init}));
585 auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
586 loop_state_init->shape(), condition, body, loop_state_init));
587 module_->AddEntryComputation(builder.Build());
588 return while_hlo;
589 }
590
591 auto loop_state_init = builder.AddInstruction(
592 HloInstruction::CreateTuple({induction_var_init, data_init}));
593 auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
594 loop_state_shape_, condition, body, loop_state_init));
595 module_->AddEntryComputation(builder.Build());
596 return while_hlo;
597 }
598
BuildWhileInstruction_InitPointsToConstant()599 HloInstruction* BuildWhileInstruction_InitPointsToConstant() {
600 auto builder = HloComputation::Builder(TestName() + ".While");
601 auto data_init = builder.AddInstruction(
602 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
603 {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
604 return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
605 &builder);
606 }
607
BuildWhileInstruction_InitPointsToParameter()608 HloInstruction* BuildWhileInstruction_InitPointsToParameter() {
609 auto builder = HloComputation::Builder(TestName() + ".While");
610 auto data_init = builder.AddInstruction(
611 HloInstruction::CreateParameter(0, data_shape_, "data_init"));
612 return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
613 &builder);
614 }
615
BuildWhileInstruction_InitPointsToNonDistinct()616 HloInstruction* BuildWhileInstruction_InitPointsToNonDistinct() {
617 auto builder = HloComputation::Builder(TestName() + ".While");
618
619 auto one = builder.AddInstruction(
620 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
621 auto one_vec = builder.AddInstruction(
622 HloInstruction::CreateBroadcast(data_shape_, one, {}));
623 auto data_init =
624 builder.AddInstruction(HloInstruction::CreateTuple({one_vec, one_vec}));
625
626 return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_,
627 data_init, &builder);
628 }
629
BuildWhileInstruction_InitPointsToInterfering()630 HloInstruction* BuildWhileInstruction_InitPointsToInterfering() {
631 auto builder = HloComputation::Builder(TestName() + ".While");
632 auto one = builder.AddInstruction(
633 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
634 auto data_init = builder.AddInstruction(
635 HloInstruction::CreateBroadcast(data_shape_, one, {}));
636 auto one_vec = builder.AddInstruction(
637 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
638 {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
639 // Take a reference to 'data_init' to make it interfere with while result.
640 auto add = builder.AddInstruction(HloInstruction::CreateBinary(
641 data_shape_, HloOpcode::kAdd, data_init, one_vec));
642
643 auto xla_while = BuildWhileInstructionWithCustomInit(loop_state_shape_,
644 data_init, &builder);
645
646 // Add an additional binary operation operating on the while and the
647 // interfering add so that neither operation is dead.
648 auto gte = xla_while->parent()->AddInstruction(
649 HloInstruction::CreateGetTupleElement(
650 ShapeUtil::GetSubshape(xla_while->shape(), {1}), xla_while, 1));
651 auto sub = xla_while->parent()->AddInstruction(HloInstruction::CreateBinary(
652 data_shape_, HloOpcode::kSubtract, add, gte));
653 auto gte0 = xla_while->parent()->AddInstruction(
654 HloInstruction::CreateGetTupleElement(
655 ShapeUtil::GetSubshape(xla_while->shape(), {0}), xla_while, 0));
656 auto tuple = xla_while->parent()->AddInstruction(
657 HloInstruction::CreateTuple({gte0, sub}));
658
659 xla_while->parent()->set_root_instruction(tuple);
660
661 return xla_while;
662 }
663
BuildWhileInstructionWithCustomInit(const Shape & loop_state_shape,HloInstruction * data_init,HloComputation::Builder * builder)664 HloInstruction* BuildWhileInstructionWithCustomInit(
665 const Shape& loop_state_shape, HloInstruction* data_init,
666 HloComputation::Builder* builder) {
667 const bool nested =
668 ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_);
669 auto induction_var_init = builder->AddInstruction(
670 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0)));
671 auto condition = module_->AddEmbeddedComputation(
672 BuildConditionComputation(loop_state_shape));
673 auto body = module_->AddEmbeddedComputation(
674 BuildIndependentBodyComputation(nested));
675 auto loop_state_init = builder->AddInstruction(
676 HloInstruction::CreateTuple({induction_var_init, data_init}));
677 auto while_hlo = builder->AddInstruction(HloInstruction::CreateWhile(
678 loop_state_shape, condition, body, loop_state_init));
679 module_->AddEntryComputation(builder->Build());
680 return while_hlo;
681 }
682
683 std::unique_ptr<HloModule> module_;
684 Shape induction_variable_shape_ = ShapeUtil::MakeShape(S32, {});
685 Shape data_shape_ = ShapeUtil::MakeShape(F32, {8});
686 Shape loop_state_shape_ =
687 ShapeUtil::MakeTupleShape({induction_variable_shape_, data_shape_});
688 Shape nested_tuple_shape_ =
689 ShapeUtil::MakeTupleShape({data_shape_, data_shape_});
690 Shape nested_loop_state_shape_ = ShapeUtil::MakeTupleShape(
691 {induction_variable_shape_, nested_tuple_shape_});
692 Shape condition_result_shape_ = ShapeUtil::MakeShape(PRED, {});
693 };
694
695 // Tests while body computation with independent tuple elements:
696 //
697 // While.Body({in0, in1})
698 // out0 = Add(in0, 1)
699 // out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
700 // Tuple(out0, out1)
701 //
// CopyInsertion pass should not generate any copies in the while body; only
// the constant init operands of the while are copied.
703 //
TEST_F(WhileCopyInsertionTest, IndependentTupleElements) {
  HloComputation* cond = module_->AddEmbeddedComputation(
      BuildConditionComputation(loop_state_shape_));
  HloComputation* loop_body =
      module_->AddEmbeddedComputation(BuildIndependentBodyComputation());
  HloInstruction* xla_while = BuildWhileInstruction(cond, loop_body);

  InsertCopies(module_.get());

  // The adds in the body can be performed in place, so the body needs no
  // copies, and the module needs no control edges.
  EXPECT_EQ(CountCopies(*loop_body), 0);
  EXPECT_EQ(CountControlEdges(*module_), 0);

  // Both elements of the init tuple are constants, so each one is copied.
  const HloInstruction* init = xla_while->operand(0);
  EXPECT_THAT(init,
              op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
}
721
722 // Tests Copy Insertion when a while feeds another while
723 // PARAMETER
724 // | |
725 // GTE(0) GTE(1)
726 // | |
727 // X = CreateTuple(GTE(0), GTE(1))
728 // |
729 // WHILE(X) (root)
TEST_F(WhileCopyInsertionTest,WhileFeedingWhileThruParameterWithCopies)730 TEST_F(WhileCopyInsertionTest, WhileFeedingWhileThruParameterWithCopies) {
731 const std::string& hlo_string = R"(
732 HloModule DependentTupleElements
733
734 %DependentTupleElements.Body (loop_state.1: (s32[], f32[8])) -> (s32[], f32[8]) {
735 %loop_state.1 = (s32[], f32[8]{0}) parameter(0)
736 %get-tuple-element.1 = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=0
737 %constant.1 = s32[] constant(1)
738 %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
739 %get-tuple-element.2 = f32[8]{0} get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=1
740 %convert = f32[] convert(s32[] %get-tuple-element.1)
741 %broadcast = f32[8]{0} broadcast(f32[] %convert), dimensions={}
742 %add.1 = f32[8]{0} add(f32[8]{0} %get-tuple-element.2, f32[8]{0} %broadcast)
743 ROOT %tuple = (s32[], f32[8]{0}) tuple(s32[] %add, f32[8]{0} %add.1)
744 }
745
746 %DependentTupleElements.Condition (loop_state: (s32[], f32[8])) -> pred[] {
747 %loop_state = (s32[], f32[8]{0}) parameter(0)
748 %get-tuple-element = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state), index=0
749 %constant = s32[] constant(10)
750 ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
751 }
752
753 ENTRY %DependentTupleElements.While () -> (s32[], f32[8]) {
754 %constant.2 = s32[] constant(0)
755 %constant.3 = f32[8]{0} constant({0, 0, 0, 0, 0, 0, 0, 0})
756 %tuple.1 = (s32[], f32[8]{0}) tuple(s32[] %constant.2, f32[8]{0} %constant.3)
757 ROOT %while.1 = (s32[], f32[8]{0}) while((s32[], f32[8]{0}) %tuple.1), condition=%DependentTupleElements.Condition, body=%DependentTupleElements.Body
758 }
759 )";
760 auto module_ = ParseAndReturnVerifiedModule(hlo_string).value();
761 auto while_hlo = module_->entry_computation()->root_instruction();
762 // module_ and while_hlo are the pre-existing module and hlo, the below
763 // code generates a clone of the existing while and replaces that while
764 // with itself. The body of the new while calls the previous while
765 HloComputation* outer_while_condition =
766 module_->AddEmbeddedComputation(while_hlo->while_condition()->Clone());
767 HloComputation* outer_while_body =
768 module_->AddEmbeddedComputation(while_hlo->while_body()->Clone());
769 HloInstruction* outer_while =
770 while_hlo->parent()->AddInstruction(HloInstruction::CreateWhile(
771 while_hlo->shape(), outer_while_condition, outer_while_body,
772 while_hlo->mutable_operand(0)));
773 HloInstruction* outer_param = outer_while_body->parameter_instruction(0);
774 std::vector<HloInstruction*> materialized_gtes;
775 for (int i = 0; i < outer_param->shape().tuple_shapes_size(); ++i) {
776 materialized_gtes.push_back(
777 outer_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
778 outer_param->shape().tuple_shapes(i), outer_param, i)));
779 }
780 HloInstruction* dual_init = outer_while_body->AddInstruction(
781 HloInstruction::CreateTuple(materialized_gtes));
782 HloInstruction* dual_while =
783 outer_while_body->AddInstruction(HloInstruction::CreateWhile(
784 while_hlo->shape(), while_hlo->while_condition(),
785 while_hlo->while_body(), dual_init));
786 TF_CHECK_OK(outer_while_body->ReplaceInstruction(
787 outer_while_body->root_instruction(), dual_while));
788 TF_CHECK_OK(while_hlo->parent()->ReplaceInstruction(while_hlo, outer_while));
789 InsertCopies(module_.get());
790 }
791
792 // Tests Copy Insertion when a while feeds another while
793 // PARAMETER
794 // | |
795 // \ /
796 // WHILE(PARAMETER) (root)
TEST_F(WhileCopyInsertionTest,WhileFeedingWhileThruParameterNoCopies)797 TEST_F(WhileCopyInsertionTest, WhileFeedingWhileThruParameterNoCopies) {
798 const std::string& hlo_string = R"(
799 HloModule DependentTupleElements
800
801 %DependentTupleElements.Body (loop_state.1: (s32[], f32[8])) -> (s32[], f32[8]) {
802 %loop_state.1 = (s32[], f32[8]{0}) parameter(0)
803 %get-tuple-element.1 = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=0
804 %constant.1 = s32[] constant(1)
805 %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
806 %get-tuple-element.2 = f32[8]{0} get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=1
807 %convert = f32[] convert(s32[] %get-tuple-element.1)
808 %broadcast = f32[8]{0} broadcast(f32[] %convert), dimensions={}
809 %add.1 = f32[8]{0} add(f32[8]{0} %get-tuple-element.2, f32[8]{0} %broadcast)
810 ROOT %tuple = (s32[], f32[8]{0}) tuple(s32[] %add, f32[8]{0} %add.1)
811 }
812
813 %DependentTupleElements.Condition (loop_state: (s32[], f32[8])) -> pred[] {
814 %loop_state = (s32[], f32[8]{0}) parameter(0)
815 %get-tuple-element = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state), index=0
816 %constant = s32[] constant(10)
817 ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
818 }
819
820 ENTRY %DependentTupleElements.While () -> (s32[], f32[8]) {
821 %constant.2 = s32[] constant(0)
822 %constant.3 = f32[8]{0} constant({0, 0, 0, 0, 0, 0, 0, 0})
823 %tuple.1 = (s32[], f32[8]{0}) tuple(s32[] %constant.2, f32[8]{0} %constant.3)
824 ROOT %while.1 = (s32[], f32[8]{0}) while((s32[], f32[8]{0}) %tuple.1), condition=%DependentTupleElements.Condition, body=%DependentTupleElements.Body
825 }
826 )";
827 auto module_ = ParseAndReturnVerifiedModule(hlo_string).value();
828 auto while_hlo = module_->entry_computation()->root_instruction();
829 // module_ and while_hlo are the pre-existing module and hlo, the below
830 // code generates a clone of the existing while and replaces that while
831 // with itself. The body of the new while calls the previous while
832 HloComputation* outer_while_condition =
833 module_->AddEmbeddedComputation(while_hlo->while_condition()->Clone());
834 HloComputation* outer_while_body =
835 module_->AddEmbeddedComputation(while_hlo->while_body()->Clone());
836 HloInstruction* outer_while =
837 while_hlo->parent()->AddInstruction(HloInstruction::CreateWhile(
838 while_hlo->shape(), outer_while_condition, outer_while_body,
839 while_hlo->mutable_operand(0)));
840 HloInstruction* outer_param = outer_while_body->parameter_instruction(0);
841 HloInstruction* dual_while =
842 outer_while_body->AddInstruction(HloInstruction::CreateWhile(
843 while_hlo->shape(), while_hlo->while_condition(),
844 while_hlo->while_body(), outer_param));
845 TF_CHECK_OK(outer_while_body->ReplaceInstruction(
846 outer_while_body->root_instruction(), dual_while));
847 TF_CHECK_OK(while_hlo->parent()->ReplaceInstruction(while_hlo, outer_while));
848 InsertCopies(module_.get());
849 }
850
851 // Tests Copy Insertion when a while feeds another while
852 // PARAMETER
853 // | |
854 // \ /
855 // WHILE(PARAMETER) (root)
TEST_F(WhileCopyInsertionTest,WhileFeedingWhileThruParameterBig)856 TEST_F(WhileCopyInsertionTest, WhileFeedingWhileThruParameterBig) {
857 const std::string& hlo_string = R"(
858 HloModule DependentTupleElements
859
860 %DependentTupleElements.Body (loop_state.1: (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0})) -> (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) {
861 %loop_state.1 = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) parameter(0)
862 %get-tuple-element.1 = s32[] get-tuple-element((s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %loop_state.1), index=0
863 %constant.1 = s32[] constant(1)
864 %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
865 %get-tuple-element.2 = f32[8]{0} get-tuple-element((s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %loop_state.1), index=1
866 %convert = f32[] convert(s32[] %get-tuple-element.1)
867 %broadcast = f32[8]{0} broadcast(f32[] %convert), dimensions={}
868 %add.1 = f32[8]{0} add(f32[8]{0} %get-tuple-element.2, f32[8]{0} %broadcast)
869 ROOT %tuple = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) tuple(s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1)
870 }
871
872 %DependentTupleElements.Condition (loop_state: (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0})) -> pred[] {
873 %loop_state = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) parameter(0)
874 %get-tuple-element = s32[] get-tuple-element((s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %loop_state), index=0
875 %constant = s32[] constant(10)
876 ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
877 }
878
879 ENTRY %DependentTupleElements.While () -> (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) {
880 %constant.2 = s32[] constant(0)
881 %constant.3 = f32[8]{0} constant({0, 0, 0, 0, 0, 0, 0, 0})
882 %tuple.1 = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) tuple(s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3)
883 ROOT %while.1 = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) while( (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %tuple.1), condition=%DependentTupleElements.Condition, body=%DependentTupleElements.Body
884 }
885 )";
886 auto module_ = ParseAndReturnVerifiedModule(hlo_string).value();
887 auto while_hlo = module_->entry_computation()->root_instruction();
888 // module_ and while_hlo are the pre-existing module and hlo, the below
889 // code generates a clone of the existing while and replaces that while
890 // with itself. The body of the new while calls the previous while
891 HloComputation* outer_while_condition =
892 module_->AddEmbeddedComputation(while_hlo->while_condition()->Clone());
893 HloComputation* outer_while_body =
894 module_->AddEmbeddedComputation(while_hlo->while_body()->Clone());
895 HloInstruction* outer_while =
896 while_hlo->parent()->AddInstruction(HloInstruction::CreateWhile(
897 while_hlo->shape(), outer_while_condition, outer_while_body,
898 while_hlo->mutable_operand(0)));
899 HloInstruction* outer_param = outer_while_body->parameter_instruction(0);
900 std::vector<HloInstruction*> materialized_gtes;
901 for (int i = 0; i < outer_param->shape().tuple_shapes_size(); ++i) {
902 materialized_gtes.push_back(
903 outer_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
904 outer_param->shape().tuple_shapes(i), outer_param, i)));
905 }
906 HloInstruction* dual_init = outer_while_body->AddInstruction(
907 HloInstruction::CreateTuple(materialized_gtes));
908 HloInstruction* dual_while =
909 outer_while_body->AddInstruction(HloInstruction::CreateWhile(
910 while_hlo->shape(), while_hlo->while_condition(),
911 while_hlo->while_body(), dual_init));
912 TF_CHECK_OK(outer_while_body->ReplaceInstruction(
913 outer_while_body->root_instruction(), dual_while));
914 TF_CHECK_OK(while_hlo->parent()->ReplaceInstruction(while_hlo, outer_while));
915 InsertCopies(module_.get());
916 }
917
918 // Tests while body computation with dependent tuple elements:
919 //
920 // While.Body({in0, in1})
921 // out0 = Add(in0, 1)
922 // out1 = Add(BCast(in0), in1)
923 // Tuple(out0, out1)
924 //
925 // CopyInsertion pass should convert the root instruction to:
926 //
927 // Tuple(Copy(out0), out1)
928 //
TEST_F(WhileCopyInsertionTest,DependentTupleElements)929 TEST_F(WhileCopyInsertionTest, DependentTupleElements) {
930 auto condition = module_->AddEmbeddedComputation(
931 BuildConditionComputation(loop_state_shape_));
932 auto body = module_->AddEmbeddedComputation(BuildDependentBodyComputation());
933 auto while_hlo = BuildWhileInstruction(condition, body);
934
935 InsertCopies(module_.get());
936
937 EXPECT_EQ(CountCopies(*body), 1);
938 EXPECT_EQ(CountControlEdges(*body), 0);
939
940 EXPECT_THAT(
941 body->root_instruction(),
942 op::Tuple(op::Add(), op::Add(op::GetTupleElement(), op::Broadcast())));
943
944 auto add = body->root_instruction()->operand(0);
945 auto bcast = body->root_instruction()->operand(1)->operand(1);
946 ASSERT_EQ(add->opcode(), HloOpcode::kAdd);
947 ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
948
949 EXPECT_THAT(while_hlo->while_body()->root_instruction(),
950 op::Tuple(op::Add(op::Copy(), op::Constant()),
951 op::Add(op::GetTupleElement(),
952 op::Broadcast(op::Convert(op::Copy())))));
953
954 // Both init indices need copies as they are constants.
955 EXPECT_THAT(while_hlo->operand(0),
956 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
957 }
958
959 // Tests while body computation with read-only tuple element 0:
960 //
961 // PARAMETER
962 // / \
963 // GTE(0) GTE(1)
964 // | \ |
965 // | BCAST |
966 // | \ |
967 // | ADD
968 // | |
969 // \ /
970 // TUPLE (root)
971 //
972 // CopyInsertion pass should not generate any copies for the while body.
TEST_F(WhileCopyInsertionTest,DependentTupleElements_OneReadOnly)973 TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) {
974 auto condition = module_->AddEmbeddedComputation(
975 BuildConditionComputation(loop_state_shape_));
976 auto body = module_->AddEmbeddedComputation(
977 BuildDependentBodyOneReadOnlyComputation());
978 BuildWhileInstruction(condition, body);
979
980 InsertCopies(module_.get());
981
982 // No copies or control edges should be inserted. The body is legal as is.
983 EXPECT_EQ(CountCopies(*body), 0);
984 EXPECT_EQ(CountControlEdges(*body), 0);
985 }
986
987 // Same as above, but with two while loops, sharing entry parameters.
TEST_F(WhileCopyInsertionTest,DependentTupleElements_OneReadOnly_TwoLoops_EntryParams)988 TEST_F(WhileCopyInsertionTest,
989 DependentTupleElements_OneReadOnly_TwoLoops_EntryParams) {
990 auto condition1 = module_->AddEmbeddedComputation(
991 BuildConditionComputation(loop_state_shape_));
992 auto condition2 = module_->AddEmbeddedComputation(
993 BuildConditionComputation(loop_state_shape_));
994 auto body1 = module_->AddEmbeddedComputation(
995 BuildDependentBodyOneReadOnlyComputation());
996 auto body2 = module_->AddEmbeddedComputation(
997 BuildDependentBodyOneReadOnlyComputation());
998
999 auto builder = HloComputation::Builder(TestName() + ".While");
1000 auto iter_param = builder.AddInstruction(
1001 HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
1002 auto data_param = builder.AddInstruction(
1003 HloInstruction::CreateParameter(1, data_shape_, "data"));
1004 auto loop_init = builder.AddInstruction(
1005 HloInstruction::CreateTuple({iter_param, data_param}));
1006
1007 auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
1008 loop_state_shape_, condition1, body1, loop_init));
1009 auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
1010 loop_state_shape_, condition2, body2, loop_init));
1011
1012 // Add a couple elements from each of the while so both whiles are live.
1013 auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1014 ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
1015 auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1016 ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0));
1017 builder.AddInstruction(
1018 HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
1019
1020 auto entry = module_->AddEntryComputation(builder.Build());
1021
1022 InsertCopies(module_.get());
1023
1024 // Neither body should have any copies or control edges in them.
1025 EXPECT_EQ(CountCopies(*body1), 0);
1026 EXPECT_EQ(CountCopies(*body2), 0);
1027 EXPECT_EQ(CountControlEdges(*body1), 0);
1028 EXPECT_EQ(CountControlEdges(*body2), 0);
1029
1030 // Only two copies should be necessary. Each of the whiles should have
1031 // a copy of tuple element 1 (init value is a parameter, and the element is
1032 // not non-read-only) so each of the while bodies gets its own buffer to write
1033 // element 1 into.
1034 EXPECT_EQ(CountCopies(*entry), 2);
1035
1036 EXPECT_EQ(while_hlo1->operand(0)->operand(1)->opcode(), HloOpcode::kCopy);
1037 EXPECT_EQ(while_hlo2->operand(0)->operand(1)->opcode(), HloOpcode::kCopy);
1038
1039 // The two copies of element 1 should be different.
1040 EXPECT_NE(while_hlo1->operand(0)->operand(1),
1041 while_hlo2->operand(0)->operand(1));
1042 }
1043
1044 // Same as above, but with two while loops, sharing non-parameters.
TEST_F(WhileCopyInsertionTest,DependentTupleElements_OneReadOnly_TwoLoops_NonParams)1045 TEST_F(WhileCopyInsertionTest,
1046 DependentTupleElements_OneReadOnly_TwoLoops_NonParams) {
1047 auto condition1 = module_->AddEmbeddedComputation(
1048 BuildConditionComputation(loop_state_shape_));
1049 auto condition2 = module_->AddEmbeddedComputation(
1050 BuildConditionComputation(loop_state_shape_));
1051 auto body1 = module_->AddEmbeddedComputation(
1052 BuildDependentBodyOneReadOnlyComputation());
1053 auto body2 = module_->AddEmbeddedComputation(
1054 BuildDependentBodyOneReadOnlyComputation());
1055
1056 auto builder = HloComputation::Builder(TestName() + ".While");
1057 auto iter_param = builder.AddInstruction(
1058 HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
1059 auto data_param = builder.AddInstruction(
1060 HloInstruction::CreateParameter(1, data_shape_, "data"));
1061 // Add dummy ops to ensure loop_init elements aren't entry parameters.
1062 Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {});
1063 auto convert = builder.AddInstruction(
1064 HloInstruction::CreateConvert(f32_scalar_shape, iter_param));
1065 auto iter_value = builder.AddInstruction(
1066 HloInstruction::CreateUnary(convert->shape(), HloOpcode::kExp, convert));
1067 auto convert2 = builder.AddInstruction(
1068 HloInstruction::CreateConvert(induction_variable_shape_, iter_value));
1069 auto data_value = builder.AddInstruction(HloInstruction::CreateUnary(
1070 data_param->shape(), HloOpcode::kExp, data_param));
1071 auto loop_init = builder.AddInstruction(
1072 HloInstruction::CreateTuple({convert2, data_value}));
1073
1074 auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
1075 loop_state_shape_, condition1, body1, loop_init));
1076 auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
1077 loop_state_shape_, condition2, body2, loop_init));
1078
1079 // Add a couple elements from each of the while so both whiles are not dead.
1080 auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1081 ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
1082 auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1083 ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0));
1084 builder.AddInstruction(
1085 HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
1086 auto entry = module_->AddEntryComputation(builder.Build());
1087
1088 InsertCopies(module_.get());
1089
1090 // Ideally only one copy should be necessary. One of the whiles should
1091 // have a copy of tuple element 1 (the non-read-only element) so each of the
1092 // while bodies gets its own buffer to write element 1 into. However, the
1093 // analysis isn't perfect and adds an additional copy of element 0.
1094 EXPECT_EQ(CountCopies(*entry), 2);
1095
1096 EXPECT_THAT(while_hlo1->operand(0),
1097 op::Tuple(op::Convert(op::Exp()), op::Copy(op::Exp())));
1098 EXPECT_THAT(while_hlo2->operand(0),
1099 op::Tuple(op::Convert(op::Exp()), op::Copy(op::Exp())));
1100 }
1101
1102 // Tests while body computation with nested tuple elements:
1103 //
1104 // |
1105 // GTE(loop_state, 1)
1106 // / \
1107 // GTE(GTE(loop_state, 1), 0) GTE(GTE(loop_state, 1), 1)
1108 // | |
1109 // Add Reverse
1110 // | |
1111 //
1112 // CopyInsertion pass will conceptually generate the following, but with the
1113 // actual GTE and Tuple instructions optimized away:
1114 //
1115 // Tuple // old root
1116 // / \
1117 // / \
1118 // GTE(0) GTE(1)
1119 // | / \
1120 // | / \
1121 // | GTE(0) GTE(1)
1122 // | | |
1123 // | | Copy
1124 // | | |
1125 // \ | /
1126 // \ Tuple // "inner" tuple.
1127 // \ /
1128 // \ /
1129 // Tuple // new root
1130 //
TEST_F(WhileCopyInsertionTest,NestedTupleElements)1131 TEST_F(WhileCopyInsertionTest, NestedTupleElements) {
1132 auto condition = module_->AddEmbeddedComputation(
1133 BuildConditionComputation(nested_loop_state_shape_));
1134 auto body = module_->AddEmbeddedComputation(BuildNestedBodyComputation());
1135 BuildWhileInstruction(condition, body, true);
1136
1137 // HloInstruction* old_root = body->root_instruction();
1138 InsertCopies(module_.get());
1139
1140 // The only copy necessary is for the kReverse as it cannot be done
1141 // in-place (instruction can share buffer with operand). The other elements of
1142 // the loop state are kAdd instructions which can be done in-place.
1143 EXPECT_EQ(CountCopies(*body), 1);
1144
1145 // Each element of the init needs a copy as all are constants.
1146 EXPECT_EQ(CountCopies(*module_), 4);
1147
1148 // Either the kReverse itself must be copied or the operand of the kReverse
1149 // must be copied.
1150 if (body->root_instruction()->operand(1)->operand(1)->opcode() ==
1151 HloOpcode::kCopy) {
1152 EXPECT_THAT(
1153 body->root_instruction(),
1154 op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Reverse()))));
1155 } else {
1156 EXPECT_THAT(
1157 body->root_instruction(),
1158 op::Tuple(op::Add(), op::Tuple(op::Add(), op::Reverse(op::Copy()))));
1159 }
1160 }
1161
1162 // Tests while init instruction which points-to a constant.
1163 //
1164 // init = Tuple(Constant(S32, {}), Constant(F32, {8}))
1165 //
1166 // CopyInsertion pass should add copies for both constants.
1167 //
TEST_F(WhileCopyInsertionTest,InitPointsToConstant)1168 TEST_F(WhileCopyInsertionTest, InitPointsToConstant) {
1169 auto while_hlo = BuildWhileInstruction_InitPointsToConstant();
1170
1171 InsertCopies(module_.get());
1172 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0);
1173 EXPECT_EQ(CountCopies(*module_), 2);
1174
1175 EXPECT_THAT(while_hlo->operand(0),
1176 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
1177 }
1178
1179 // Tests while init instruction which points-to a parameter.
1180 //
1181 // init = Tuple(Constant(S32, {}), Parameter(F32, {8}))
1182 //
1183 // CopyInsertion pass should add copies for both the constant and parameter.
1184 //
TEST_F(WhileCopyInsertionTest,InitPointsToParameter)1185 TEST_F(WhileCopyInsertionTest, InitPointsToParameter) {
1186 auto while_hlo = BuildWhileInstruction_InitPointsToParameter();
1187
1188 InsertCopies(module_.get());
1189 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0);
1190 EXPECT_EQ(CountCopies(*module_), 2);
1191
1192 EXPECT_THAT(while_hlo->operand(0),
1193 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Parameter())));
1194 }
1195
1196 // Tests while init instruction which has a non-distinct points-to set.
1197 //
1198 // init = Tuple(Constant(S32, {}), Tuple({vec_one, vec_one}))
1199 //
1200 // CopyInsertion pass will conceptually generate the following, but with some of
1201 // the actual GTE and Tuple instructions optimized away:
1202 //
1203 // Tuple // old init
1204 // / \
1205 // / \
1206 // GTE(0) GTE(1)
1207 // | / \
1208 // | / \
1209 // | GTE(0) GTE(1)
1210 // | | |
1211 // Copy Copy Copy
1212 // | | |
1213 // \ | /
1214 // \ Tuple
1215 // \ /
1216 // \ /
1217 // Tuple // new init
1218 //
TEST_F(WhileCopyInsertionTest,InitPointsToNonDistinct)1219 TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) {
1220 auto while_hlo = BuildWhileInstruction_InitPointsToNonDistinct();
1221
1222 InsertCopies(module_.get());
1223
1224 // The entry computation requires two copies to resolve the non-distinctness
1225 // of two init elements and the constant passed in as one of the init
1226 // elements. Either element can be copied for the distinctness issue.
1227 EXPECT_EQ(CountCopies(*module_->entry_computation()), 2);
1228 if (while_hlo->operand(0)->operand(1)->operand(0)->opcode() ==
1229 HloOpcode::kCopy) {
1230 EXPECT_THAT(
1231 while_hlo->operand(0),
1232 op::Tuple(op::Copy(op::Constant()),
1233 op::Tuple(op::Copy(op::Broadcast()), op::Broadcast())));
1234 } else {
1235 EXPECT_THAT(
1236 while_hlo->operand(0),
1237 op::Tuple(op::Copy(op::Constant()),
1238 op::Tuple(op::Broadcast(), op::Copy(op::Broadcast()))));
1239 }
1240
1241 // The body requires one copy because the buffer set is not distinct: the
1242 // result of one of the adds is written into two elements of the output of the
1243 // loop body. Either element might be copied.
1244 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1);
1245 if (while_hlo->while_body()
1246 ->root_instruction()
1247 ->operand(1)
1248 ->operand(0)
1249 ->opcode() == HloOpcode::kCopy) {
1250 EXPECT_THAT(
1251 while_hlo->while_body()->root_instruction(),
1252 op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add())));
1253 } else {
1254 EXPECT_THAT(
1255 while_hlo->while_body()->root_instruction(),
1256 op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add()))));
1257 }
1258 }
1259
1260 // Tests while init instruction buffer which interferes with while result
1261 // buffer.
1262 //
1263 // init_data = Broadcast(...)
1264 // add_unrelated = Add(init_data) // takes a reference to cause interference
1265 // init = Tuple(Constant(S32, {}), init_data))
1266 //
1267 // CopyInsertion pass should copy both operands.
1268 //
TEST_F(WhileCopyInsertionTest,InitPointsToInterfering)1269 TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
1270 auto while_hlo = BuildWhileInstruction_InitPointsToInterfering();
1271
1272 InsertCopies(module_.get());
1273 EXPECT_EQ(CountCopies(*module_), 2);
1274 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0);
1275
1276 EXPECT_THAT(while_hlo->operand(0),
1277 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Broadcast())));
1278 }
1279
1280 // Tests while init instruction buffer which has a non-distinct points-to set:
1281 //
1282 // init = Tuple(Parameter(S32, {}), Parameter(F32, {8},
1283 // Parameter(F32, {8})))
1284 //
1285 // where the second and third parameters are identical *and* the tuple shared
1286 // by another while instruction.
1287 //
1288 // Verifies that the resulting point-to set is distinct in the resulting Tuple
1289 // (non-identical Copys). In other words, verifies that copy sharing does not
1290 // insert identical copies to the resulting tuple.
TEST_F(WhileCopyInsertionTest,InitPointsToNonDistinctUsedByTwoWhileLoops)1291 TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
1292 // Loop body that outputs tuple comprises two elements dependent on the init
1293 // tuple.
1294 const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
1295 {induction_variable_shape_, data_shape_, data_shape_});
1296
1297 auto condition1 = module_->AddEmbeddedComputation(
1298 BuildConditionComputation(loop_state_shape));
1299 auto condition2 = module_->AddEmbeddedComputation(
1300 BuildConditionComputation(loop_state_shape));
1301 auto body1 =
1302 module_->AddEmbeddedComputation(BuildDependentBodyComputation2());
1303 auto body2 =
1304 module_->AddEmbeddedComputation(BuildDependentBodyComputation2());
1305
1306 auto builder = HloComputation::Builder(TestName() + ".While");
1307
1308 auto iter_param = builder.AddInstruction(
1309 HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
1310 auto data_param = builder.AddInstruction(
1311 HloInstruction::CreateParameter(1, data_shape_, "data"));
1312
1313 // Loop init tuple contains two identical parameter buffers.
1314 auto loop_init = builder.AddInstruction(
1315 HloInstruction::CreateTuple({iter_param, data_param, data_param}));
1316
1317 // Two while loops share the same loop init tuple.
1318 auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
1319 loop_state_shape, condition1, body1, loop_init));
1320 auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
1321 loop_state_shape, condition2, body2, loop_init));
1322
1323 // Add add instruction so neither while is dead.
1324 auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1325 ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
1326 auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1327 ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo2, 0));
1328 builder.AddInstruction(
1329 HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
1330
1331 module_->AddEntryComputation(builder.Build());
1332
1333 InsertCopies(module_.get());
1334
1335 // None of the bodies should have copies or control flow edges.
1336 EXPECT_EQ(CountCopies(*body1), 0);
1337 EXPECT_EQ(CountCopies(*body2), 0);
1338
1339 // The loop bodies pass through elements 1 and 2 in the init tuple, so ideally
1340 // these should not need to be copied before either while. However, copy
1341 // insertion is not able to reason about the transparency of elements through
1342 // while bodies in all circumstances so extra copies are added (b/xxx).
1343 EXPECT_EQ(CountCopies(*module_->entry_computation()), 2);
1344
1345 EXPECT_THAT(while_hlo1->operand(0),
1346 op::Tuple(op::Copy(), op::Parameter(), op::Parameter()));
1347 EXPECT_THAT(while_hlo2->operand(0),
1348 op::Tuple(op::Copy(), op::Parameter(), op::Parameter()));
1349 }
1350
// While loop whose body permutes (swaps) the two elements of its tuple
// parameter.
TEST_F(CopyInsertionTest, SwizzlingWhile) {
  auto module = CreateNewVerifiedModule();
  const Shape state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Body: return the loop state with its two elements interchanged.
  HloComputation::Builder body_builder("body");
  HloInstruction* body_state = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* first_element = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_state, 0));
  HloInstruction* second_element = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_state, 1));
  body_builder.AddInstruction(
      HloInstruction::CreateTuple({second_element, first_element}));
  HloComputation* loop_body =
      module->AddEmbeddedComputation(body_builder.Build());

  // Condition: Not(false), independent of the loop state.
  HloComputation::Builder condition_builder("condition");
  condition_builder.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* false_constant = condition_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  condition_builder.AddInstruction(HloInstruction::CreateUnary(
      false_constant->shape(), HloOpcode::kNot, false_constant));
  HloComputation* loop_condition =
      module->AddEmbeddedComputation(condition_builder.Build());

  // Entry: while over a tuple of two scalar constants.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* one = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* two = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* init =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({one, two}));
  HloInstruction* loop = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(state_shape, loop_condition, loop_body,
                                  init));
  module->AddEntryComputation(entry_builder.Build());

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 6);

  // Each loop state element is copied once at the parameter and once at the
  // root, with a control edge in between (see DeepCopyAndAddControlEdges).
  // That is technically one copy more than strictly necessary, but getting
  // down to three copies would require ordering the copies of different loop
  // state elements with a control edge.
  EXPECT_EQ(CountCopies(*loop_body), 4);
  EXPECT_EQ(CountControlEdges(*loop_body), 2);

  EXPECT_THAT(loop_body->root_instruction(),
              op::Tuple(op::Copy(op::Copy()), op::Copy(op::Copy())));

  // Both constant init elements are copied in the entry computation.
  EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
  EXPECT_THAT(loop->operand(0), op::Tuple(op::Copy(), op::Copy()));
}
1409
TEST_F(CopyInsertionTest, CrossingParameters) {
  // The entry computation returns the two elements of its tuple parameter in
  // swapped order, while the alias config pairs output index {i} with
  // parameter index {i}, so the dataflow of the two elements crosses:
  //
  //  (p0 ,  p1)
  //   | \   /|
  //   |  \ / |
  // alias X  alias
  //   |  / \ |
  //   | /   \|
  //  (p1 ,  p0)
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry_b(TestName());
  HloInstruction* input = entry_b.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "0"));
  HloInstruction* elem0 = entry_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* elem1 = entry_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));
  entry_b.AddInstruction(HloInstruction::CreateTuple({elem1, elem0}));
  module->AddEntryComputation(entry_b.Build());

  // Alias output {0} <-> param {0} and output {1} <-> param {1}.
  auto& aliasing = module->input_output_alias_config();
  ASSERT_IS_OK(aliasing.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  ASSERT_IS_OK(aliasing.SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 4);
}
1442
TEST_F(CopyInsertionTest, ParametersAliasing) {
  // The entry computation forwards both elements of its tuple parameter
  // unchanged and in order, and each output index is aliased with the same
  // parameter index, so the aliased dataflows do not interfere:
  //
  //  (p0 ,  p1)
  //   |      |
  //   |      |
  // alias   alias
  //   |      |
  //   |      |
  //  (p0 ,  p1)
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry_b(TestName());
  HloInstruction* input = entry_b.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* elem0 = entry_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* elem1 = entry_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));
  entry_b.AddInstruction(HloInstruction::CreateTuple({elem0, elem1}));
  module->AddEntryComputation(entry_b.Build());

  // Alias output {0} <-> param {0} and output {1} <-> param {1}.
  auto& aliasing = module->input_output_alias_config();
  ASSERT_IS_OK(aliasing.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  ASSERT_IS_OK(aliasing.SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
  InsertCopies(module.get());

  // Nothing needs to be copied.
  EXPECT_EQ(CountCopies(*module), 0);
}
1475
TEST_F(CopyInsertionTest, ParameterWithNoAliasing) {
  // The entry computation forwards both elements of its tuple parameter, and
  // no input/output alias is configured. Each forwarded element is expected to
  // be copied.
  //
  //  (p0 ,  p1)
  //   |      |
  //   |      |
  //   |      |
  //   |      |
  //   |      |
  //  (p0 ,  p1)
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry_b(TestName());
  HloInstruction* input = entry_b.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* elem0 = entry_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* elem1 = entry_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));
  entry_b.AddInstruction(HloInstruction::CreateTuple({elem0, elem1}));
  module->AddEntryComputation(entry_b.Build());
  InsertCopies(module.get());

  // The root tuple now holds a copy of each get-tuple-element of the
  // parameter.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::Copy(op::GetTupleElement(input, 0)),
                        op::Copy(op::GetTupleElement(input, 1))));

  EXPECT_EQ(CountCopies(*module), 2);
}
1508
TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) {
  // The entry computation forwards both elements of its tuple parameter, but
  // only output {0} is aliased with parameter index {0}; element 1 has no
  // alias.
  //
  //  (p0 ,  p1)
  //   |      |
  //   |      |
  // alias    |
  //   |      |
  //   |      |
  //  (p0 ,  p1)
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry_b(TestName());
  HloInstruction* input = entry_b.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* elem0 = entry_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* elem1 = entry_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));
  entry_b.AddInstruction(HloInstruction::CreateTuple({elem0, elem1}));
  module->AddEntryComputation(entry_b.Build());
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  InsertCopies(module.get());

  // The aliased element is forwarded directly; the unaliased one is copied.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::GetTupleElement(input, 0),
                        op::Copy(op::GetTupleElement(input, 1))));

  EXPECT_EQ(CountCopies(*module), 1);
}
1543
TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) {
  // Each element of the tuple parameter is negated independently before being
  // returned. Only output {0} is aliased with parameter index {0}.
  //
  //   +-- (p0 ,  p1)
  //   |    |      |
  //   |    |      |
  // alias Negate Negate
  //   |    |      |
  //   |    |      |
  //   +-- (p0 ,  p1)
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry_b(TestName());
  HloInstruction* input = entry_b.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* elem0 = entry_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* elem1 = entry_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));

  HloInstruction* neg0 = entry_b.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, elem0));
  HloInstruction* neg1 = entry_b.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, elem1));

  entry_b.AddInstruction(HloInstruction::CreateTuple({neg0, neg1}));
  module->AddEntryComputation(entry_b.Build());
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  InsertCopies(module.get());

  // No copies are expected.
  EXPECT_EQ(CountCopies(*module), 0);
}
1580
TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) {
  // Both elements of the tuple parameter are negated; output 0 is the sum of
  // the two negations and output 1 is the negation of element 1. Only output
  // {0} is aliased with parameter index {0}.
  //
  //   +-- (p0 ,  p1)
  //   |    |      |
  //   |    |      |
  // alias Negate Negate
  //   |    |      |
  //   |    Add----+
  //   |    |      |
  //   +-- (p0 ,  p1)
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry_b(TestName());
  HloInstruction* input = entry_b.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* elem0 = entry_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* elem1 = entry_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));

  HloInstruction* neg0 = entry_b.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, elem0));
  HloInstruction* neg1 = entry_b.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, elem1));

  HloInstruction* sum = entry_b.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape_, HloOpcode::kAdd, neg0, neg1));
  entry_b.AddInstruction(HloInstruction::CreateTuple({sum, neg1}));
  module->AddEntryComputation(entry_b.Build());
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  InsertCopies(module.get());

  // No copies are expected.
  EXPECT_EQ(CountCopies(*module), 0);
}
1621
TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
  // Variant of SwizzlingWhile: the body still exchanges its two loop state
  // elements, but one of them additionally goes through a negate. The extra
  // instruction gives the corresponding input and output elements different
  // live ranges than in the plain swap.
  auto module = CreateNewVerifiedModule();
  const Shape state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Loop body: root = (negate(state[1]), state[0]).
  HloComputation::Builder body_b("body");
  HloInstruction* state = body_b.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* first = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, state, 0));
  HloInstruction* second = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, state, 1));
  HloInstruction* negated_second = body_b.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, second));
  body_b.AddInstruction(HloInstruction::CreateTuple({negated_second, first}));
  HloComputation* body = module->AddEmbeddedComputation(body_b.Build());

  // Loop condition: takes the loop state as parameter but computes its result
  // as not(false) from a constant.
  HloComputation::Builder cond_b("condition");
  cond_b.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* false_pred = cond_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  cond_b.AddInstruction(HloInstruction::CreateUnary(
      false_pred->shape(), HloOpcode::kNot, false_pred));
  HloComputation* condition = module->AddEmbeddedComputation(cond_b.Build());

  // Entry computation: while((1.0, 2.0)).
  HloComputation::Builder entry_b(TestName());
  HloInstruction* one = entry_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* two = entry_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* init =
      entry_b.AddInstruction(HloInstruction::CreateTuple({one, two}));
  HloInstruction* loop = entry_b.AddInstruction(
      HloInstruction::CreateWhile(state_shape, condition, body, init));
  module->AddEntryComputation(entry_b.Build());

  InsertCopies(module.get());

  // Six copies in total: four in the body and two in the entry computation.
  EXPECT_EQ(CountCopies(*module), 6);

  // Inside the body every loop state element is copied once at the parameter
  // and once more at the root, with a control edge between the two (see
  // DeepCopyAndAddControlEdges).
  EXPECT_EQ(CountCopies(*body), 4);
  EXPECT_EQ(CountControlEdges(*body), 2);

  EXPECT_THAT(
      body->root_instruction(),
      op::Tuple(op::Copy(op::Negate(op::Copy())), op::Copy(op::Copy())));

  // Both elements of the while init tuple are copies.
  EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
  EXPECT_THAT(loop->operand(0), op::Tuple(op::Copy(), op::Copy()));
}
1684
TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) {
  // Same swapping body as in SwizzlingWhile, but here both loop state elements
  // are initialized from one and the same constant. Exchanging two identical
  // elements is a no-op, so the body is expected to need no copies.
  auto module = CreateNewVerifiedModule();
  const Shape state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Loop body: root = (state[1], state[0]).
  HloComputation::Builder body_b("body");
  HloInstruction* state = body_b.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* first = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, state, 0));
  HloInstruction* second = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, state, 1));
  body_b.AddInstruction(HloInstruction::CreateTuple({second, first}));
  HloComputation* body = module->AddEmbeddedComputation(body_b.Build());

  // Loop condition: takes the loop state as parameter but computes its result
  // as not(false) from a constant.
  HloComputation::Builder cond_b("condition");
  cond_b.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* false_pred = cond_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  cond_b.AddInstruction(HloInstruction::CreateUnary(
      false_pred->shape(), HloOpcode::kNot, false_pred));
  HloComputation* condition = module->AddEmbeddedComputation(cond_b.Build());

  // Entry computation: while((c, c)) with a single shared constant c.
  HloComputation::Builder entry_b(TestName());
  HloInstruction* shared = entry_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* init =
      entry_b.AddInstruction(HloInstruction::CreateTuple({shared, shared}));
  entry_b.AddInstruction(
      HloInstruction::CreateWhile(state_shape, condition, body, init));
  module->AddEntryComputation(entry_b.Build());

  InsertCopies(module.get());

  // The only two copies in the module live in the entry computation.
  EXPECT_EQ(CountCopies(*module), 2);
  EXPECT_EQ(CountCopies(*body), 0);

  EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::Copy(), op::Copy()));
}
1735
TEST_F(CopyInsertionTest, SequentialWhiles) {
  // Builds a chain of while instructions over a four-element loop state:
  //
  //   element 0: every while receives it directly from entry parameter 0.
  //   element 1: passed through each body untouched and threaded from one
  //              while to the next.
  //   element 2: negated in each body (can be done in place).
  //   element 3: reversed in each body (cannot be done in place).
  //
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {42});
  const Shape state_shape = ShapeUtil::MakeTupleShape(
      {vec_shape, vec_shape, vec_shape, vec_shape});

  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_b(TestName());
  HloInstruction* p0 = entry_b.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "param_0"));
  HloInstruction* p1 = entry_b.AddInstruction(
      HloInstruction::CreateParameter(1, vec_shape, "param_1"));
  HloInstruction* p2 = entry_b.AddInstruction(
      HloInstruction::CreateParameter(2, vec_shape, "param_2"));
  HloInstruction* p3 = entry_b.AddInstruction(
      HloInstruction::CreateParameter(3, vec_shape, "param_3"));

  // Length of the while chain.
  const int kNumWhiles = 3;

  // Values feeding elements 1..3 of the next while's init tuple. They start
  // out as the entry parameters and are replaced by the previous while's
  // outputs as the chain grows.
  HloInstruction* carried_1 = p1;
  HloInstruction* carried_2 = p2;
  HloInstruction* carried_3 = p3;

  // Every while instruction of the chain, in creation order.
  std::vector<const HloInstruction*> loops;
  for (int i = 0; i < kNumWhiles; ++i) {
    // Body: root = (state[0], state[1], negate(state[2]), reverse(state[3])).
    HloComputation::Builder body_b("body");
    HloInstruction* state = body_b.AddInstruction(
        HloInstruction::CreateParameter(0, state_shape, "param"));
    HloInstruction* e0 = body_b.AddInstruction(
        HloInstruction::CreateGetTupleElement(vec_shape, state, 0));
    HloInstruction* e1 = body_b.AddInstruction(
        HloInstruction::CreateGetTupleElement(vec_shape, state, 1));
    HloInstruction* e2 = body_b.AddInstruction(
        HloInstruction::CreateGetTupleElement(vec_shape, state, 2));
    HloInstruction* e3 = body_b.AddInstruction(
        HloInstruction::CreateGetTupleElement(vec_shape, state, 3));
    HloInstruction* negated = body_b.AddInstruction(
        HloInstruction::CreateUnary(vec_shape, HloOpcode::kNegate, e2));
    HloInstruction* reversed = body_b.AddInstruction(
        HloInstruction::CreateReverse(vec_shape, e3, {0}));
    body_b.AddInstruction(
        HloInstruction::CreateTuple({e0, e1, negated, reversed}));
    HloComputation* body = module->AddEmbeddedComputation(body_b.Build());

    // Condition: takes the loop state as parameter but computes its result
    // as not(false) from a constant.
    HloComputation::Builder cond_b("condition");
    cond_b.AddInstruction(
        HloInstruction::CreateParameter(0, state_shape, "param"));
    HloInstruction* false_pred = cond_b.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
    cond_b.AddInstruction(HloInstruction::CreateUnary(
        false_pred->shape(), HloOpcode::kNot, false_pred));
    HloComputation* condition =
        module->AddEmbeddedComputation(cond_b.Build());

    HloInstruction* init = entry_b.AddInstruction(
        HloInstruction::CreateTuple({p0, carried_1, carried_2, carried_3}));

    HloInstruction* loop = entry_b.AddInstruction(
        HloInstruction::CreateWhile(state_shape, condition, body, init));
    loops.push_back(loop);

    // Unpack the outputs for the next while; the last while's result is left
    // as is.
    const bool is_last = (i == kNumWhiles - 1);
    if (!is_last) {
      carried_1 = entry_b.AddInstruction(
          HloInstruction::CreateGetTupleElement(vec_shape, loop, 1));
      carried_2 = entry_b.AddInstruction(
          HloInstruction::CreateGetTupleElement(vec_shape, loop, 2));
      carried_3 = entry_b.AddInstruction(
          HloInstruction::CreateGetTupleElement(vec_shape, loop, 3));
    }
  }

  module->AddEntryComputation(entry_b.Build());

  InsertCopies(module.get());

  // One copy per while body, plus one copy per loop state element in the
  // entry computation.
  EXPECT_EQ(CountCopies(*module), 4 + kNumWhiles);

  // The single copy in each body is for element 3: kReverse cannot be
  // performed in place.
  for (const HloInstruction* loop : loops) {
    EXPECT_EQ(CountCopies(*loop->while_body()), 1);
  }

  EXPECT_THAT(loops[0]->operand(0), op::Tuple(op::Parameter(), op::Parameter(),
                                              op::Copy(), op::Copy()));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::Copy(), op::Copy(), op::GetTupleElement(),
                        op::GetTupleElement()));
}
1838
TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) {
  // Both the while body and the while condition have a plain constant as their
  // root instruction. The constant at the body root is expected to be copied;
  // the one at the condition root is not.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_b(TestName());
  HloInstruction* input = entry_b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param_0"));

  // Body: ignores its parameter and yields the constant 123.0.
  HloComputation::Builder body_b("body");
  body_b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  body_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
  HloComputation* body = module->AddEmbeddedComputation(body_b.Build());

  // Condition: ignores its parameter and yields the constant false.
  HloComputation::Builder cond_b("condition");
  cond_b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  cond_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition = module->AddEmbeddedComputation(cond_b.Build());

  HloInstruction* loop = entry_b.AddInstruction(
      HloInstruction::CreateWhile(scalar_shape_, condition, body, input));

  module->AddEntryComputation(entry_b.Build());

  InsertCopies(module.get());

  // Two copies: the while operand and the body root.
  EXPECT_EQ(CountCopies(*module), 2);

  EXPECT_THAT(loop->operand(0), op::Copy(op::Parameter()));
  EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant()));
  EXPECT_THAT(condition->root_instruction(), op::Constant());
}
1875
TEST_F(CopyInsertionTest, TokensShouldNotBeCopied) {
  // A while loop whose state is an (s32[], token[]) tuple: the body increments
  // the integer and passes the token through after-all.
  const std::string hlo_text = R"(
HloModule TokensShouldNotBeCopied

%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
  %param.1 = (s32[], token[]) parameter(0)
  %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
  %constant.1 = s32[] constant(1)
  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
  %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
  %after-all = token[] after-all(token[] %get-tuple-element.2)
  ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
}

%Cond (param: (s32[], token[])) -> pred[] {
  %param = (s32[], token[]) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
  %constant = s32[] constant(42)
  ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %TokensShouldNotBeCopied () -> s32[] {
  %one = s32[] constant(1)
  %negative_one = s32[] negate(%one)
  %init_token = token[] after-all()
  %init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token)
  %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
  ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(module.get());

  // Tokens should not be copied, and no copies at all are expected in this
  // module.
  EXPECT_EQ(CountCopies(*module), 0);
}
1913
MakeTrivialCondition(const Shape & shape)1914 std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) {
1915 auto builder = HloComputation::Builder("trivial_condition");
1916 builder.AddInstruction(
1917 HloInstruction::CreateParameter(0, shape, "loop_state"));
1918 auto constant = builder.AddInstruction(
1919 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1920 builder.AddInstruction(HloInstruction::CreateUnary(
1921 constant->shape(), HloOpcode::kNot, constant));
1922 return builder.Build();
1923 }
1924
MakeBenchmarkWhileBody()1925 std::unique_ptr<HloComputation> MakeBenchmarkWhileBody() {
1926 auto builder = HloComputation::Builder("benchmark_loop_body");
1927 const Shape element_shape = ShapeUtil::MakeShape(F32, {42});
1928 const Shape loop_state_shape =
1929 ShapeUtil::MakeTupleShape({element_shape, element_shape, element_shape});
1930 HloInstruction* param = builder.AddInstruction(
1931 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
1932 HloInstruction* element_0 = builder.AddInstruction(
1933 HloInstruction::CreateGetTupleElement(element_shape, param, 0));
1934 HloInstruction* element_1 = builder.AddInstruction(
1935 HloInstruction::CreateGetTupleElement(element_shape, param, 1));
1936 HloInstruction* element_2 = builder.AddInstruction(
1937 HloInstruction::CreateGetTupleElement(element_shape, param, 2));
1938
1939 HloInstruction* rev_1 = builder.AddInstruction(
1940 HloInstruction::CreateReverse(element_shape, element_1, {0}));
1941 HloInstruction* add_1_2 = builder.AddInstruction(HloInstruction::CreateBinary(
1942 element_shape, HloOpcode::kAdd, element_1, element_2));
1943
1944 builder.AddInstruction(
1945 HloInstruction::CreateTuple({element_0, rev_1, add_1_2}));
1946 return builder.Build();
1947 }
1948
BM_SequentialWhiles(::testing::benchmark::State & state)1949 void BM_SequentialWhiles(::testing::benchmark::State& state) {
1950 const int num_whiles = state.range(0);
1951
1952 // This benchmark constructs a chain of sequential while instructions.
1953 // Timer starts automatically at the first iteration of this loop
1954 // and ends after the last one.
1955 for (auto s : state) {
1956 state.PauseTiming();
1957 HloModuleConfig config;
1958 config.set_debug_options(GetDebugOptionsFromFlags());
1959 HloModule module("BM_SequentialWhiles", config);
1960
1961 auto builder = HloComputation::Builder("BM_SequentialWhiles");
1962 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
1963 0, ShapeUtil::MakeShape(F32, {42}), "x"));
1964 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
1965 1, ShapeUtil::MakeShape(F32, {42}), "y"));
1966 HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
1967 2, ShapeUtil::MakeShape(F32, {42}), "z"));
1968 HloInstruction* init =
1969 builder.AddInstruction(HloInstruction::CreateTuple({x, y, z}));
1970
1971 HloInstruction* prev_loop_state = init;
1972 for (int w = 0; w < num_whiles; ++w) {
1973 HloComputation* condition =
1974 module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
1975 HloComputation* body =
1976 module.AddEmbeddedComputation(MakeBenchmarkWhileBody());
1977 prev_loop_state = builder.AddInstruction(HloInstruction::CreateWhile(
1978 init->shape(), condition, body, prev_loop_state));
1979 }
1980 module.AddEntryComputation(builder.Build());
1981
1982 CopyInsertion copy_insertion;
1983
1984 state.ResumeTiming();
1985 ASSERT_IS_OK(copy_insertion.Run(&module).status());
1986 state.PauseTiming();
1987
1988 // The entry computation should have three copies, and each body has one.
1989 ASSERT_EQ(CountCopies(module), 3 + num_whiles);
1990 state.ResumeTiming();
1991 }
1992 }
1993
BM_ParallelWhiles(::testing::benchmark::State & state)1994 void BM_ParallelWhiles(::testing::benchmark::State& state) {
1995 const int num_whiles = state.range(0);
1996
1997 // This benchmark constructs a fan-out of parallel while instructions.
1998 for (auto s : state) {
1999 state.PauseTiming();
2000 HloModuleConfig config;
2001 config.set_debug_options(GetDebugOptionsFromFlags());
2002 HloModule module("BM_SequentialWhiles", config);
2003
2004 auto builder = HloComputation::Builder("BM_ParallelWhiles");
2005 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
2006 0, ShapeUtil::MakeShape(F32, {42}), "x"));
2007 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
2008 1, ShapeUtil::MakeShape(F32, {42}), "y"));
2009 HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
2010 2, ShapeUtil::MakeShape(F32, {42}), "z"));
2011 HloInstruction* init =
2012 builder.AddInstruction(HloInstruction::CreateTuple({x, y, z}));
2013
2014 HloInstruction* sum = nullptr;
2015 for (int w = 0; w < num_whiles; ++w) {
2016 HloComputation* condition =
2017 module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
2018 HloComputation* body =
2019 module.AddEmbeddedComputation(MakeBenchmarkWhileBody());
2020
2021 HloInstruction* xla_while = builder.AddInstruction(
2022 HloInstruction::CreateWhile(init->shape(), condition, body, init));
2023
2024 if (sum == nullptr) {
2025 sum = builder.AddInstruction(
2026 HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0));
2027 } else {
2028 HloInstruction* element_0 = builder.AddInstruction(
2029 HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0));
2030 sum = builder.AddInstruction(HloInstruction::CreateBinary(
2031 x->shape(), HloOpcode::kAdd, sum, element_0));
2032 }
2033 }
2034 module.AddEntryComputation(builder.Build());
2035
2036 CopyInsertion copy_insertion;
2037
2038 state.ResumeTiming();
2039 ASSERT_IS_OK(copy_insertion.Run(&module).status());
2040 state.PauseTiming();
2041
2042 // Each body receives of copy of two of the parameters (the corresponding
2043 // elements in the body are modified), and there is one copy in each body.
2044 ASSERT_EQ(CountCopies(module), 3 * num_whiles);
2045 }
2046 }
2047
MakeBenchmarkWhileBody(const int num_tuple_inputs)2048 std::unique_ptr<HloComputation> MakeBenchmarkWhileBody(
2049 const int num_tuple_inputs) {
2050 auto builder = HloComputation::Builder("benchmark_loop_body");
2051 const Shape element_shape = ShapeUtil::MakeShape(F32, {});
2052 std::vector<Shape> input_shape(num_tuple_inputs, element_shape);
2053 const Shape loop_state_shape = ShapeUtil::MakeTupleShape(input_shape);
2054 HloInstruction* param = builder.AddInstruction(
2055 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
2056 std::vector<HloInstruction*> gte_nodes(num_tuple_inputs);
2057 for (int i = 0; i < num_tuple_inputs; ++i) {
2058 gte_nodes[i] = builder.AddInstruction(
2059 HloInstruction::CreateGetTupleElement(element_shape, param, i));
2060 }
2061 builder.AddInstruction(HloInstruction::CreateTuple(gte_nodes));
2062 return builder.Build();
2063 }
2064
BM_ManyElementTuple(::testing::benchmark::State & state)2065 void BM_ManyElementTuple(::testing::benchmark::State& state) {
2066 const int num_tuple_inputs = state.range(0);
2067 HloModuleConfig config;
2068 config.set_debug_options(GetDebugOptionsFromFlags());
2069 CopyInsertion copy_insertion;
2070 const Shape element_shape = ShapeUtil::MakeShape(F32, {});
2071 std::vector<HloInstruction*> tuple_params(num_tuple_inputs);
2072 for (auto s : state) {
2073 state.PauseTiming();
2074 auto builder = HloComputation::Builder("BM_ParallelWhiles");
2075 HloModule module("BM_ManyElementTuple", config);
2076 for (int j = 0; j < num_tuple_inputs; ++j) {
2077 tuple_params[j] = builder.AddInstruction(
2078 HloInstruction::CreateParameter(j, element_shape, ""));
2079 }
2080 HloInstruction* init =
2081 builder.AddInstruction(HloInstruction::CreateTuple(tuple_params));
2082 HloComputation* condition =
2083 module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
2084 HloComputation* body =
2085 module.AddEmbeddedComputation(MakeBenchmarkWhileBody(num_tuple_inputs));
2086 HloInstruction* xla_while = builder.AddInstruction(
2087 HloInstruction::CreateWhile(init->shape(), condition, body, init));
2088 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
2089 ShapeUtil::MakeShape(F32, {}), xla_while, 0));
2090 module.AddEntryComputation(builder.Build());
2091 state.ResumeTiming();
2092 ASSERT_IS_OK(copy_insertion.Run(&module).status());
2093 }
2094 }
2095
// Benchmark registrations. Each Arg(n) is one benchmark run. For
// BM_ManyElementTuple the argument is read via state.range(0) as the number of
// tuple elements; the other two benchmarks are defined earlier in this file
// (NOTE(review): presumably their argument is the number of while loops —
// confirm against their definitions).
BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
BENCHMARK(BM_ManyElementTuple)->Arg(1024)->Arg(12288);
2099
// Smoke test: parses a module with three levels of nested while loops (the
// innermost carrying a nested-tuple loop state) and runs copy insertion on it.
// The test makes no assertion about the number or placement of copies.
TEST_F(CopyInsertionTest, SimpleControlFlowTest) {
  absl::string_view hlo_string = R"(
HloModule TestModule

if-body.v5 {
  constant.3 = s32[] constant(-1)
  p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1
  get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0
  get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1
  add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66)
  tuple.33 = (s32[]) tuple(add.3)
  ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33)
}

if-condition.v4 {
  p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
  constant.4 = s32[] constant(0)
  ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ
}

_functionalize_body_1__.v28 {
  arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.68 = s32[] get-tuple-element(arg_tuple.4), index=0
  constant.7 = s32[] constant(1)
  add.4 = s32[] add(get-tuple-element.68, constant.7)
  get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1
  get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2
  less-than-or-equal-to = pred[] compare(get-tuple-element.69, get-tuple-element.70), direction=LE
  constant.8 = s32[] constant(0)
  select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
  get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3
  tuple.35 = (s32[], s32[], s32[]) tuple(get-tuple-element.69, get-tuple-element.71, get-tuple-element.70)
  tuple.36 = (s32[]) tuple(constant.8)
  tuple.37 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.35, tuple.36)
  while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.37), condition=if-condition.v4, body=if-body.v5
  get-tuple-element.72 = (s32[]) get-tuple-element(while), index=2
  get-tuple-element.73 = s32[] get-tuple-element(get-tuple-element.72), index=0
  ROOT tuple.38 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.69, get-tuple-element.70, get-tuple-element.73)
}

cond_wrapper.v3.1 {
  inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0
  constant.11 = s32[] constant(7)
  ROOT less-than.2 = pred[] compare(get-tuple-element.75, constant.11), direction=LT
}

_functionalize_body_2__.v25 {
  arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.76 = s32[] get-tuple-element(arg_tuple.5), index=0
  get-tuple-element.77 = s32[] get-tuple-element(arg_tuple.5), index=2
  get-tuple-element.78 = s32[] get-tuple-element(arg_tuple.5), index=3
  get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=4
  tuple.39 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.76, get-tuple-element.77, get-tuple-element.78, get-tuple-element.79)
  while.2 = (s32[], s32[], s32[], s32[]) while(tuple.39), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28
  get-tuple-element.80 = s32[] get-tuple-element(while.2), index=0
  get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=1
  constant.12 = s32[] constant(1)
  add.5 = s32[] add(get-tuple-element.81, constant.12)
  get-tuple-element.82 = s32[] get-tuple-element(while.2), index=3
  ROOT tuple.40 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.80, add.5, get-tuple-element.77, get-tuple-element.78, get-tuple-element.82)
}

cond_wrapper.v3.2 {
  inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1
  constant.13 = s32[] constant(5)
  ROOT less-than.3 = pred[] compare(get-tuple-element.83, constant.13), direction=LT
}

ENTRY TestComputation {
  arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
}
)";
  // Fail the test (rather than crash on an unchecked value()) if the HLO text
  // does not parse or verify.
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
}
2180
// Smoke test: like SimpleControlFlowTest, but the middle loop body chains two
// while loops (the second consuming the result of the first). The module is
// parsed and copy insertion is run; no assertion is made about the copies.
TEST_F(CopyInsertionTest, ControlFlowTest) {
  absl::string_view hlo_string = R"(
HloModule TestModule

if-body.v5 {
  constant.3 = s32[] constant(-1)
  p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1
  get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0
  get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1
  add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66)
  tuple.33 = (s32[]) tuple(add.3)
  ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33)
}

if-condition.v4 {
  p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
  constant.4 = s32[] constant(0)
  ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ
}

if-body.v5.1 {
  constant.5 = s32[] constant(-1)
  p.3 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.68 = (s32[], s32[], s32[]) get-tuple-element(p.3), index=1
  get-tuple-element.70 = s32[] get-tuple-element(get-tuple-element.68), index=2
  multiply.1 = s32[] multiply(get-tuple-element.70, get-tuple-element.70)
  tuple.35 = (s32[]) tuple(multiply.1)
  ROOT tuple.36 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.5, get-tuple-element.68, tuple.35)
}

if-condition.v4.1 {
  p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0
  constant.6 = s32[] constant(1)
  ROOT equal-to.1 = pred[] compare(get-tuple-element.71, constant.6), direction=EQ
}

_functionalize_body_1__.v28 {
  arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.72 = s32[] get-tuple-element(arg_tuple.4), index=0
  constant.7 = s32[] constant(1)
  add.4 = s32[] add(get-tuple-element.72, constant.7)
  get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1
  get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2
  less-than-or-equal-to = pred[] compare(get-tuple-element.73, get-tuple-element.74), direction=LE
  constant.8 = s32[] constant(0)
  select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
  get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3
  tuple.37 = (s32[], s32[], s32[]) tuple(get-tuple-element.73, get-tuple-element.75, get-tuple-element.74)
  tuple.38 = (s32[]) tuple(constant.8)
  tuple.39 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.37, tuple.38)
  while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.39), condition=if-condition.v4, body=if-body.v5
  while.1 = (s32[], (s32[], s32[], s32[]), (s32[])) while(while), condition=if-condition.v4.1, body=if-body.v5.1
  get-tuple-element.76 = (s32[]) get-tuple-element(while.1), index=2
  get-tuple-element.77 = s32[] get-tuple-element(get-tuple-element.76), index=0
  ROOT tuple.40 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.73, get-tuple-element.74, get-tuple-element.77)
}

cond_wrapper.v3.1 {
  inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0
  constant.11 = s32[] constant(7)
  ROOT less-than.2 = pred[] compare(get-tuple-element.78, constant.11), direction=LT
}

_functionalize_body_2__.v25 {
  arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=0
  get-tuple-element.80 = s32[] get-tuple-element(arg_tuple.5), index=2
  get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=3
  get-tuple-element.82 = s32[] get-tuple-element(arg_tuple.5), index=4
  tuple.41 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.79, get-tuple-element.80, get-tuple-element.81, get-tuple-element.82)
  while.2 = (s32[], s32[], s32[], s32[]) while(tuple.41), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28
  get-tuple-element.83 = s32[] get-tuple-element(while.2), index=0
  get-tuple-element.84 = s32[] get-tuple-element(arg_tuple.5), index=1
  constant.12 = s32[] constant(1)
  add.5 = s32[] add(get-tuple-element.84, constant.12)
  get-tuple-element.85 = s32[] get-tuple-element(while.2), index=3
  ROOT tuple.42 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.83, add.5, get-tuple-element.80, get-tuple-element.81, get-tuple-element.85)
}

cond_wrapper.v3.2 {
  inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1
  constant.13 = s32[] constant(5)
  ROOT less-than.3 = pred[] compare(get-tuple-element.86, constant.13), direction=LT
}

ENTRY TestComputation {
  arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
}
)";
  // Fail the test (rather than crash on an unchecked value()) if the HLO text
  // does not parse or verify.
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
}
2279
TEST_F(CopyInsertionTest, NestedWhiles) {
  // Verify that no unnecessary copies remain after copy insertion for trivial
  // nested whiles (b/112472605).
  absl::string_view hlo_string = R"(
HloModule TestModule

cond.inner {
  ROOT param.cond.inner = pred[] parameter(0)
}

body.inner {
  param.body.inner = pred[] parameter(0)
  ROOT not = pred[] not(param.body.inner)
}

cond.outer {
  ROOT param.cond.outer = pred[] parameter(0)
}

body.outer {
  param.cond.outer = pred[] parameter(0)
  ROOT while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner
}

ENTRY TestComputation {
  entry_param = pred[] parameter(0)
  ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());

  // Exactly one copy is expected in the whole module, and it must sit in the
  // entry computation between the parameter and the outer while.
  EXPECT_EQ(CountCopies(*module), 1);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::While(op::Copy(op::Parameter())));
}
2319
TEST_F(CopyInsertionTest, NestedWhilesWithParamRoot) {
  // Test that when the root of a computation comes before another
  // side-effecting operation (here the outer while body's parameter is the
  // root, and an inner while plus an outfeed follow it), we introduce an
  // interference edge and copy at the level of this outer loop body and not
  // one level out.
  absl::string_view hlo_string = R"(
HloModule TestModule

cond.inner {
  ROOT param.cond.inner = pred[] parameter(0)
}

body.inner {
  param.body.inner = pred[] parameter(0)
  ROOT not = pred[] not(param.body.inner)
}

cond.outer {
  ROOT param.cond.outer = pred[] parameter(0)
}

body.outer {
  ROOT param.cond.outer = pred[] parameter(0)
  while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner
  after-all = token[] after-all()
  outfeed = token[] outfeed(while, after-all)
}

ENTRY TestComputation {
  entry_param = pred[] parameter(0)
  while = pred[] while(entry_param), condition=cond.outer, body=body.outer
  ROOT not = pred[] not(while)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());

  // Exactly one copy is expected, located in the outer while loop body: the
  // entry while still consumes the parameter directly, while the inner while
  // (feeding the outfeed) consumes a copy of the outer body's parameter.
  EXPECT_EQ(CountCopies(*module), 1);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Not(op::While(op::Parameter())));
  HloInstruction* outfeed = FindInstruction(module.get(), "outfeed");
  EXPECT_THAT(outfeed, op::Outfeed(op::While(op::Copy(op::Parameter(0))),
                                   op::AfterAll()));
}
2367
TEST_F(CopyInsertionTest, NestedWhilesWithParamRoot2) {
  // Tuple-shaped variant of NestedWhilesWithParamRoot: the outer loop body's
  // root tuple forwards both elements of its parameter unchanged, while an
  // inner while consumes the same parameter and its result feeds an outfeed.
  // The copies are expected at the level of the outer loop body (on the inner
  // while's operand) and not one level out.
  absl::string_view hlo_string = R"(
HloModule TestModule

cond.inner {
  param.cond.inner = (pred[], pred[]) parameter(0)
  ROOT gte = pred[] get-tuple-element(param.cond.inner), index=0
}

body.inner {
  param.body.inner = (pred[], pred[]) parameter(0)
  gte.0 = pred[] get-tuple-element(param.body.inner), index=0
  gte.1 = pred[] get-tuple-element(param.body.inner), index=1
  and = pred[] and(gte.0, gte.1)
  not = pred[] not(gte.1)
  ROOT root = (pred[], pred[]) tuple(and, not)
}

cond.outer {
  param.cond.outer = (pred[], pred[]) parameter(0)
  ROOT gte = pred[] get-tuple-element(param.cond.outer), index=0
}

body.outer {
  param.body.outer = (pred[], pred[]) parameter(0)
  gte.0 = pred[] get-tuple-element(param.body.outer), index=0
  gte.1 = pred[] get-tuple-element(param.body.outer), index=1
  while.inner = (pred[], pred[]) while(param.body.outer), condition=cond.inner, body=body.inner
  gte.2 = pred[] get-tuple-element(while.inner), index=0
  after-all = token[] after-all()
  outfeed = token[] outfeed(gte.2, after-all)
  ROOT root = (pred[], pred[]) tuple(gte.0, gte.1)
}

ENTRY TestComputation {
  entry_param.1 = pred[] parameter(0)
  entry_param.2 = pred[] parameter(1)
  tuple = (pred[], pred[]) tuple(entry_param.1, entry_param.2)
  while.outer = (pred[], pred[]) while(tuple), condition=cond.outer, body=body.outer
  gte = pred[] get-tuple-element(while.outer), index=0
  ROOT not = pred[] not(gte)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());

  // The inner while must be fed a fresh tuple whose two elements are copies of
  // the corresponding elements of the outer body's parameter.
  HloInstruction* while_inner = FindInstruction(module.get(), "while.inner");
  EXPECT_THAT(
      while_inner,
      op::While(op::Tuple(op::Copy(op::GetTupleElement(op::Parameter(0))),
                          op::Copy(op::GetTupleElement(op::Parameter(0))))));
}
2425
// A conditional nested in a while body, where both branches return a tuple
// that forwards their parameter alongside a newly computed value.
TEST_F(CopyInsertionTest, NestedWhileAndConditional2) {
  absl::string_view hlo_string = R"(
HloModule TestModule

on_true
 {
  v1 = f32[2] parameter(0)
  v2 = f32[2] add(v1,v1)
  ROOT t1 = (f32[2], f32[2]) tuple(v1,v2)
}

on_false
 {
  v1 = f32[2] parameter(0)
  v2 = f32[2] multiply(v1,v1)
  ROOT t2 = (f32[2], f32[2]) tuple(v1,v2)
}

cond.outer {
  param.1 = (pred[], f32[2], f32[2]) parameter(0)
  ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
}

body.outer {
  param.1 = (pred[], f32[2], f32[2]) parameter(0)
  pred.1 = pred[] get-tuple-element(param.1), index=0
  arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
  if = (f32[2], f32[2]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
  e1 = f32[2] get-tuple-element(if), index=0
  e2 = f32[2] get-tuple-element(if), index=1
  ROOT res = (pred[], f32[2], f32[2]) tuple(pred.1,e1, e2)
}

ENTRY TestComputation {
  entry_param.1 = pred[] parameter(0)
  float_param = f32[2] parameter(1)
  entry_param = (pred[], f32[2], f32[2]) tuple(entry_param.1, float_param, float_param)
  ROOT while = (pred[], f32[2], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());

  // An extra copy must be kept inside the loop due to uses in the conditional.
  EXPECT_EQ(CountCopies(*module), 3);
}
2473
// A conditional nested in a while body, where each branch returns a single
// newly computed array (no pass-through of the branch parameter).
TEST_F(CopyInsertionTest, NestedWhileAndConditional) {
  absl::string_view hlo_string = R"(
HloModule TestModule

on_true
 {
  v1 = f32[2] parameter(0)
  ROOT v2 = f32[2] add(v1,v1)
}

on_false
 {
  v1 = f32[2] parameter(0)
  ROOT v2 = f32[2] multiply(v1,v1)
}

cond.outer {
  param.1 = (pred[], f32[2]) parameter(0)
  ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
}

body.outer {
  param.1 = (pred[], f32[2]) parameter(0)
  pred.1 = pred[] get-tuple-element(param.1), index=0
  arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
  if = f32[2] conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
  ROOT res = (pred[], f32[2]) tuple(pred.1,if)
}

ENTRY TestComputation {
  entry_param.1 = pred[] parameter(0)
  float_param = f32[2] parameter(1)
  entry_param = (pred[], f32[2]) tuple(entry_param.1, float_param)
  ROOT while = (pred[], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
  // Dump the transformed module at verbose log level 2 to aid debugging.
  VLOG(2) << module->ToString() << "\n";

  // There should only be two copies inserted, at the entry and exit of the
  // computation.
  EXPECT_EQ(CountCopies(*module), 2);
}
2519
// Entry parameters are both consumed by other instructions (a fusion and a
// negate) and passed straight through to the root tuple, with the
// pass-through outputs aliased to their parameters. No copies are expected.
TEST_F(CopyInsertionTest, FixpointComputationRequired) {
  absl::string_view hlo_string = R"(
HloModule Module

fused_computation {
  param0 = f32[3,3,96,1] parameter(0)
  param1 = f32[] parameter(1)
  broadcast = f32[3,3,96,1] broadcast(f32[] param1), dimensions={}
  ROOT %add.0 = f32[3,3,96,1] add(f32[3,3,96,1] param0, f32[3,3,96,1] broadcast)
}

ENTRY entry_computation {
  arg0 = f32[3,3,96,1] parameter(0)
  arg1 = f32[] parameter(1)
  fusion = f32[3,3,96,1] fusion(f32[3,3,96,1] arg0, f32[] arg1),
    kind=kLoop, calls=fused_computation
  negate = f32[] negate(f32[] arg1)
  ROOT tuple = (f32[3,3,96,1], f32[3,3,96,1], f32[], f32[]) tuple(
    f32[3,3,96,1] fusion,
    f32[3,3,96,1] arg0,
    f32[] negate,
    f32[] arg1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Set up the aliasing manually which normally would be set by
  // alias_passthrough_params pass: output {1} <-> parameter 0 and
  // output {3} <-> parameter 1.
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{1},
      /*param_number=*/0,
      /*param_index=*/{}));
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{3},
      /*param_number=*/1,
      /*param_index=*/{}));

  InsertCopies(module.get());

  // There should be no copies inserted.
  EXPECT_EQ(CountCopies(*module), 0);
}
2563
// The root tuple contains both a bitcast of the parameter and the parameter
// itself, with output {1} aliased to parameter 0. Exactly one copy is expected
// after copy insertion.
TEST_F(CopyInsertionTest, NoAliasCheckViolation) {
  absl::string_view hlo_string = R"(
HloModule cluster

ENTRY Entry {
  %arg = f32[8,28,28,1] parameter(0)
  %bitcast.2 = f32[8,1,28,28] bitcast(f32[8,28,28,1] %arg)
  ROOT %tuple.1 = (f32[8,1,28,28], f32[8,28,28,1]) tuple(f32[8,1,28,28] %bitcast.2, f32[8,28,28,1] %arg)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Alias the pass-through output element with the entry parameter.
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{1},
      /*param_number=*/0,
      /*param_index=*/{}));
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);
}
2583
// A dynamic-update-slice whose updated operand (negate) has no other user:
// no copy is expected.
TEST_F(CopyInsertionTest, DynamicUpdateSliceNoCopy) {
  const absl::string_view hlo_text = R"(
HloModule Module

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate = f32[1280,1,128] negate(param)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());

  EXPECT_EQ(CountCopies(*hlo_module), 0);
}
2602
// Same scenario as DynamicUpdateSliceNoCopy, but the dynamic-update-slice is
// the root of a loop fusion and the fusion's only operand (negate) has no
// other user: no copy is expected.
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceNoCopy) {
  const absl::string_view hlo_text = R"(
HloModule Module

fused_computation {
  param0 = f32[1280,1,128] parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
}

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate = f32[1280,1,128] negate(param)
  ROOT fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());

  EXPECT_EQ(CountCopies(*hlo_module), 0);
}
2626
// The operand updated by the dynamic-update-slice (negate) is also read by an
// add whose result is part of the root tuple: one copy is expected.
TEST_F(CopyInsertionTest, DynamicUpdateSliceCopy) {
  const absl::string_view hlo_text = R"(
HloModule Module

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate = f32[1280,1,128] negate(param)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  add = f32[1280,1,128] add(negate, negate)
  dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(add, dynamic-update-slice.5)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());

  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2647
// The dynamic-update-slice updates the entry parameter directly: one copy is
// expected.
TEST_F(CopyInsertionTest, DynamicUpdateSliceParameterShareCopy) {
  const absl::string_view hlo_text = R"(
HloModule Module

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3, constant.3)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());

  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2665
// The fusion's operand (negate) is updated by the fused dynamic-update-slice
// and is also returned directly in the root tuple: one copy is expected.
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy) {
  const absl::string_view hlo_text = R"(
HloModule Module

fused_computation {
  param0 = f32[1280,1,128] parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
}

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate = f32[1280,1,128] negate(param)
  add = f32[1280,1,128] add(negate, negate)
  fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation
  ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(negate, fusion)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());

  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2691
// Two chained dynamic-update-slices starting from an element of the entry
// tuple parameter: a single copy is expected for the whole chain.
TEST_F(CopyInsertionTest, ChainDynamicUpdateSliceCopy) {
  const absl::string_view hlo_text = R"(
HloModule Module

ENTRY main {
  state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
  get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
  get-tuple-element.3 = s32[] get-tuple-element(state), index=0
  constant.2 = s32[] constant(128)
  add.5 = s32[] add(get-tuple-element.3, constant.2)
  constant.3 = s32[] constant(0)
  dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
  dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());

  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2715
// Two fused dynamic-update-slices in sequence: fusion1 updates negate, and
// fusion2 updates fusion1's result while also reading negate as its second
// operand. One copy is expected.
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy2) {
  const absl::string_view hlo_text = R"(
HloModule Module

fused_computation.1 {
  param0 = f32[1280,1,128] parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
}

fused_computation.2 {
  param0 = f32[1280,1,128] parameter(0)
  param1 = f32[1280,1,128] parameter(1)
  slice = f32[128,1,128] slice(param1), slice={[0:128], [0:1], [0:128]}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, slice, constant.3, constant.3, constant.3)
}

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate = f32[1280,1,128] negate(param)
  add = f32[1280,1,128] add(negate, negate)
  fusion1 = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation.1
  ROOT fusion2 = f32[1280,1,128] fusion(fusion1, negate), kind=kLoop, calls=fused_computation.2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());

  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2749
TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceCopy) {
  // Multi-output fusion with two dynamic-update-slice outputs. Both updated
  // operands (negate1 and negate2) are read again after the fusion, so two
  // copies are required.
  const absl::string_view hlo_text = R"(
HloModule Module

fused_computation {
  param0 = f32[1280,1,128] parameter(0)
  param1 = f32[1280,1,128] parameter(1)
  param2 = f32[1280,1,128] parameter(2)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  add.1 = f32[1280,1,128] add(param0, param0)
  dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
  dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
}

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate0 = f32[1280,1,128] negate(param)
  negate1 = f32[1280,1,128] negate(param)
  negate2 = f32[1280,1,128] negate(param)
  fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
  gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0
  gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1
  gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2
  add0 = f32[1280,1,128] add(negate0, gte0)
  add1 = f32[1280,1,128] add(negate1, gte1)
  add2 = f32[1280,1,128] add(negate2, gte2)
  ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());

  EXPECT_EQ(CountCopies(*hlo_module), 2);
}
2788
TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceNoCopy) {
  // Same module as MultiOutputFusedDynamicUpdateSliceCopy, except that negate1
  // is not used beyond the fusion (add1 reads gte1 twice instead), so only one
  // copy is expected.
  // NOTE(review): the remaining copy is presumably for negate2, the only
  // updated operand still read after the fusion — the test checks only the
  // count, so confirm if the exact placement matters.
  const absl::string_view hlo_text = R"(
HloModule Module

fused_computation {
  param0 = f32[1280,1,128] parameter(0)
  param1 = f32[1280,1,128] parameter(1)
  param2 = f32[1280,1,128] parameter(2)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  add.1 = f32[1280,1,128] add(param0, param0)
  dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
  dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
}

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate0 = f32[1280,1,128] negate(param)
  negate1 = f32[1280,1,128] negate(param)
  negate2 = f32[1280,1,128] negate(param)
  fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
  gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0
  gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1
  gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2
  add0 = f32[1280,1,128] add(negate0, gte0)
  add1 = f32[1280,1,128] add(gte1, gte1)
  add2 = f32[1280,1,128] add(negate2, gte2)
  ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());

  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2828
TEST_F(CopyInsertionTest, ScatterSharedOperand) {
  // When an in-place op receives, as an additional operand, the very same value
  // as its in-place buffer, a copy must be inserted on only one of the two
  // uses. Here iota.2 is both operand 0 and operand 2 of the scatter.
  absl::string_view module_text = R"(
HloModule Module

update_s32 {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}

fused_computation {
iota.1 = s32[73729]{0} iota(), iota_dimension=0
ROOT indices.1 = s32[73729]{0} reverse(iota.1), dimensions={0}
}

ENTRY main {
iota.2 = s32[73729]{0} iota(), iota_dimension=0
fusion = s32[73729]{0} fusion(), kind=kLoop, calls=fused_computation
ROOT scatter = s32[73729]{0} scatter(iota.2, fusion, iota.2), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=update_s32
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(module_text));
  InsertCopies(parsed_module.get());

  // One copy in total ...
  EXPECT_EQ(CountCopies(*parsed_module), 1);

  // ... and it sits on the scatter's first operand; the third operand still
  // reads the iota directly.
  const HloInstruction* scatter_root =
      parsed_module->entry_computation()->root_instruction();
  EXPECT_THAT(scatter_root,
              op::Scatter(op::Copy(op::Iota()), op::Fusion(), op::Iota()));
}
2858
TEST_F(CopyInsertionTest, HorizontalLoopFusionNoCopy) {
  // A multi-output fusion whose two results are bitcast and returned from the
  // entry computation, with output {0} aliased to parameter 0 and output {1}
  // aliased to parameter 3. The test pins that copy insertion adds no copies.
  //
  // The HLO text is held as a plain string literal rather than a
  // `const std::string&` bound to a temporary, matching the other tests in this
  // file and avoiding a needless std::string construction.
  const char* const kModuleString = R"(
HloModule test

fused_computation {
p0 = f32[10,20] parameter(0)
p1 = f32[10,20] parameter(1)
p2 = f32[10,10] parameter(2)
p3 = f32[10,10] parameter(3)
add0 = f32[10, 20] add(p0, p1)
sub0 = f32[10, 10] subtract(p2, p3)
reshape0 = f32[200] reshape(add0)
reshape1 = f32[100] reshape(sub0)
concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0}
slice0 = f32[200] slice(concat0), slice={[0:200]}
slice1 = f32[100] slice(concat0), slice={[200:300]}
ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1)
}

ENTRY test {
p0 = f32[10,20] parameter(0)
p1 = f32[10,20] parameter(1)
p2 = f32[10,10] parameter(2)
p3 = f32[10,10] parameter(3)
fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation
gte0 = f32[200] get-tuple-element(fusion), index=0
gte1 = f32[100] get-tuple-element(fusion), index=1
bitcast0 = f32[10,20] bitcast(gte0)
bitcast1 = f32[10,10] bitcast(gte1)
ROOT tuple = (f32[10,20], f32[10,10]) tuple(bitcast0, bitcast1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleString));
  // Alias each entry output with a same-shaped entry parameter.
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0},
      /*param_number=*/0,
      /*param_index=*/{}));
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{1},
      /*param_number=*/3,
      /*param_index=*/{}));

  InsertCopies(module.get());

  // There should be no copies inserted.
  EXPECT_EQ(CountCopies(*module), 0);
}
2908
TEST_F(CopyInsertionTest, NestedWhileAndConditional3) {
  // A while loop whose body contains a conditional; the conditional's false
  // branch itself contains another conditional. Both branch operands of the
  // outer conditional are the same loop-carried element (arg_tuple.11).
  //
  // The HLO text is held as a plain string literal rather than a
  // `const std::string&` bound to a temporary, matching the other tests in this
  // file and avoiding a needless std::string construction.
  const char* const kModuleString = R"(
HloModule TestModule

on_true.1
{
ROOT v1 = f32[2] parameter(0)
}

on_false.1
{
v1 = f32[2] parameter(0)
ROOT v2 = f32[2] multiply(v1,v1)
}

on_true
{
v1 = f32[2] parameter(0)
v2 = f32[2] add(v1,v1)
v3 = (f32[2],f32[2]) tuple(v1,v2)
v4 = f32[2] get-tuple-element(v3), index=1
v5 = f32[2] multiply(v4,v2)
ROOT t1 = (f32[2], f32[2]) tuple(v5,v2)
}

on_false
{
v1 = f32[2] parameter(0)
v2 = f32[2] multiply(v1,v1)
pred.1 = pred[] constant(true)
v4 = f32[2] conditional(pred.1, v1, v2), true_computation=on_true.1, false_computation=on_false.1
v5 = f32[2] multiply(v4,v2)
ROOT t2 = (f32[2], f32[2]) tuple(v2,v5)

}

cond.outer {
param.1 = (pred[], f32[2], f32[2]) parameter(0)
ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
}

body.outer {
param.1 = (pred[], f32[2], f32[2]) parameter(0)
pred.1 = pred[] get-tuple-element(param.1), index=0
arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
if = (f32[2], f32[2]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
e1 = f32[2] get-tuple-element(if), index=0
e2 = f32[2] get-tuple-element(if), index=1
ROOT res = (pred[], f32[2], f32[2]) tuple(pred.1,e1, e2)
}

ENTRY TestComputation {
entry_param.1 = pred[] parameter(0)
float_param = f32[2] parameter(1)
entry_param = (pred[], f32[2], f32[2]) tuple(entry_param.1, float_param, float_param)
ROOT while = (pred[], f32[2], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleString));
  InsertCopies(module.get());
  // An extra copy must be kept inside the loop due to uses in the conditional.
  EXPECT_EQ(CountCopies(*module), 4);
}
2973
TEST_F(CopyInsertionTest, ConditionalBranchMustCopy1) {
  // Branch 1 of the conditional forwards its parameter element to its root
  // through copy.1, while branch 0 negates its input. The entry root returns
  // both parameter.2 and the conditional's result. The test verifies that
  // copy.1 is still the operand of branch 1's root tuple after copy insertion.
  //
  // Failures are reported through gtest assertions (not CHECK_*), so a
  // regression fails this test instead of aborting the whole test binary.
  const char* const kModuleString = R"(
HloModule TestModule

branch_0_comp.5.clone {
%parameter.0 = (s32[2]{0:T(128)}) parameter(0)
%get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0
%negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element)
%copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate)
ROOT tuple.5 = (s32[2]{0:T(128)}) tuple(%copy)
}

branch_1_comp.12.clone {
%parameter.4 = (s32[2]{0:T(128)}) parameter(0)
%get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0
%copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5)
ROOT tuple.6 = (s32[2]{0:T(128)}) tuple(%copy.1)
}

ENTRY TestComputation {
%parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
%parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
%parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
%tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3)
%tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2)
%conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}, metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
%gte.1 = s32[2]{0:T(128)} get-tuple-element(conditional.18), index=0
ROOT tuple.4 = (s32[2]{0:T(128)},s32[2]{0:T(128)}) tuple(parameter.2, gte.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleString));
  // The fixture helper runs first, then an explicit CopyInsertion pass is run
  // on the resulting module.
  InsertCopies(module.get());
  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/-1);
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  VLOG(3) << module->ToString();
  // The copy.1 must be kept due to modification in the other branch.
  const HloInstruction* conditional18 =
      FindInstruction(module.get(), "conditional.18");
  ASSERT_NE(conditional18, nullptr);
  const HloInstruction* tuple6 =
      conditional18->branch_computation(1)->root_instruction();
  // Fatal: operand(0) below is only meaningful if the root is still a tuple.
  ASSERT_EQ(tuple6->opcode(), HloOpcode::kTuple);
  const HloInstruction* copy1 = tuple6->operand(0);
  EXPECT_EQ(copy1->opcode(), HloOpcode::kCopy);
}
3019
TEST_F(CopyInsertionTest, ConditionalBranchMustCopy2) {
  // Like ConditionalBranchMustCopy1, but branch 1 additionally feeds copy.1
  // into a dynamic-update-slice and then adds the result to copy.1. The test
  // verifies that, after copy insertion, operand 0 of the add's first operand
  // is still a copy.
  //
  // Failures are reported through gtest assertions (not CHECK_*), so a
  // regression fails this test instead of aborting the whole test binary.
  const char* const kModuleString = R"(
HloModule TestModule

branch_0_comp.5.clone {
%parameter.0 = (s32[2]{0:T(128)}) parameter(0)
%get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0
%negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element)
%copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate)
ROOT tuple.5 = (s32[2]{0:T(128)}) tuple(%copy)
}

branch_1_comp.12.clone {
%parameter.4 = (s32[2]{0:T(128)}) parameter(0)
%get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0
%copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5)
%constant.1 = s32[] constant(0)
%broadcast.6 = s32[2] broadcast(constant.1), dimensions={}
dynamic-update-slice.5 = s32[2]{0:T(128)} dynamic-update-slice(%copy.1, %broadcast.6, %constant.1)
%add.1 = s32[2]{0:T(128)} add(dynamic-update-slice.5, %copy.1)
ROOT tuple.6 = (s32[2]{0:T(128)}) tuple(%add.1)
}

ENTRY TestComputation {
%parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
%parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
%parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
%tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3)
%tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2)
%conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}
%gte.1 = s32[2]{0:T(128)} get-tuple-element(conditional.18), index=0
ROOT tuple.4 = (s32[2]{0:T(128)},s32[2]{0:T(128)}) tuple(parameter.2, gte.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleString));
  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/-1);
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  // The copy.1 must be kept due to modification in the other branch.
  const HloInstruction* conditional18 =
      FindInstruction(module.get(), "conditional.18");
  ASSERT_NE(conditional18, nullptr);
  const HloInstruction* tuple6 =
      conditional18->branch_computation(1)->root_instruction();
  // Fatal checks guard the operand() walks that follow.
  ASSERT_EQ(tuple6->opcode(), HloOpcode::kTuple);
  const HloInstruction* add1 = tuple6->operand(0);
  ASSERT_EQ(add1->opcode(), HloOpcode::kAdd);
  // Walk add.1 -> its first operand -> that instruction's first operand.
  const HloInstruction* dus = add1->operand(0);
  const HloInstruction* copy1 = dus->operand(0);
  EXPECT_EQ(copy1->opcode(), HloOpcode::kCopy);
}
3070
TEST_F(CopyInsertionTest, ConditionalBranchMustCopy3) {
  // Here the conditional itself is the entry root. Branch 1 forwards its
  // parameter element to its root through copy.1; the test verifies that copy
  // is still the operand of branch 1's root tuple after copy insertion.
  //
  // Failures are reported through gtest assertions (not CHECK_*), so a
  // regression fails this test instead of aborting the whole test binary.
  const char* const kModuleString = R"(
HloModule primitive_computation_cond.19
%branch_0_comp.5.clone (parameter.0: (s32[2])) -> (s32[2]) {
%parameter.0 = (s32[2]{0:T(128)}) parameter(0)
%get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0
%negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element)
%copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate)
ROOT %tuple.5 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %copy)
}

%branch_1_comp.12.clone (parameter.4: (s32[2])) -> (s32[2]) {
%parameter.4 = (s32[2]{0:T(128)}) parameter(0)
%get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0
%copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5)
ROOT %tuple.6 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %copy.1)
}

ENTRY %primitive_computation_cond.19 (parameter.1: s32[], parameter.2: s32[2], parameter.3: s32[2]) -> (s32[2]) {
%parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
%parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
%tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3)
%parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
%tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2)
ROOT %conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}, metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleString));
  // The fixture helper runs first, then an explicit CopyInsertion pass is run
  // on the resulting module.
  InsertCopies(module.get());
  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/-1);
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  VLOG(3) << module->ToString();
  // The copy.1 must be kept b/c aliasing of parameter and root is not allowed.
  const HloInstruction* conditional18 =
      FindInstruction(module.get(), "conditional.18");
  ASSERT_NE(conditional18, nullptr);
  const HloInstruction* tuple6 =
      conditional18->branch_computation(1)->root_instruction();
  // Fatal: operand(0) below is only meaningful if the root is still a tuple.
  ASSERT_EQ(tuple6->opcode(), HloOpcode::kTuple);
  const HloInstruction* copy1 = tuple6->operand(0);
  EXPECT_EQ(copy1->opcode(), HloOpcode::kCopy);
}
3113
TEST_F(CopyInsertionTest, ConditionalBranchDoNotCopy1) {
  // Same branch computations as ConditionalBranchMustCopy1, but the entry root
  // only returns the conditional's result (gte.1 twice) and no longer returns
  // parameter.2 directly. The test verifies that, after copy insertion, the
  // root of branch 1 is a parameter instruction, i.e. neither tuple.6 nor
  // copy.1 remains at the root of that branch.
  //
  // Failures are reported through gtest assertions (not CHECK_*), so a
  // regression fails this test instead of aborting the whole test binary.
  const char* const kModuleString = R"(
HloModule TestModule

branch_0_comp.5.clone {
%parameter.0 = (s32[2]{0:T(128)}) parameter(0)
%get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0
%negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element)
%copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate)
ROOT tuple.5 = (s32[2]{0:T(128)}) tuple(%copy)
}

branch_1_comp.12.clone {
%parameter.4 = (s32[2]{0:T(128)}) parameter(0)
%get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0
%copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5)
ROOT tuple.6 = (s32[2]{0:T(128)}) tuple(%copy.1)
}

ENTRY TestComputation {
%parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
%parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
%parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
%tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3)
%tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2)
%conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}, metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
%gte.1 = s32[2]{0:T(128)} get-tuple-element(conditional.18), index=0
ROOT tuple.4 = (s32[2]{0:T(128)},s32[2]{0:T(128)}) tuple(gte.1, gte.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleString));
  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/-1);
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  VLOG(3) << module->ToString() << "\n";

  // Unlike the MustCopy tests, copy.1 is expected to be gone here: branch 1's
  // root must be the branch parameter itself.
  const HloInstruction* conditional18 =
      FindInstruction(module.get(), "conditional.18");
  ASSERT_NE(conditional18, nullptr);
  const HloInstruction* branch1_root =
      conditional18->branch_computation(1)->root_instruction();
  EXPECT_EQ(branch1_root->opcode(), HloOpcode::kParameter);
}
3157
TEST_F(CopyInsertionTest, RootInstructionNotLast) {
  // This is a test for b/189219227. When the root instruction is scheduled not
  // as the last instruction, it still lives out. So, we make sure that the copy
  // after the root cannot be removed.
  //
  // In `body` below, the ROOT tuple is listed before the copy, the tuple that
  // wraps it, and the inner while loop, and the module is marked
  // is_scheduled=true.
  //
  // The HLO text is held as a plain string literal rather than a
  // `const std::string&` bound to a temporary, matching the other tests in this
  // file and avoiding a needless std::string construction.
  const char* const kModuleString = R"(
HloModule module, is_scheduled=true

body2 {
p_body2 = (f32[2]{0}) parameter(0)
p_body2.1 = f32[2]{0} get-tuple-element(p_body2), index=0
add.3 = f32[2]{0} add(p_body2.1, p_body2.1)
ROOT root2 = (f32[2]{0}) tuple(add.3)
}

condition2 {
p_cond2 = (f32[2]{0}) parameter(0)
ROOT result = pred[] constant(true)
}

body {
p_body = (f32[2]{0}) parameter(0)
p_body.1 = f32[2]{0} get-tuple-element(p_body), index=0
ROOT root = (f32[2]{0}) tuple(p_body.1)
copy = f32[2]{0} copy(p_body.1)
tuple = (f32[2]{0}) tuple(copy)
while.1 = (f32[2]{0}) while(tuple), condition=condition2, body=body2
}

condition {
p_cond = (f32[2]{0}) parameter(0)
ROOT result = pred[] constant(true)
}

ENTRY entry {
const0 = f32[2]{0} constant({1, 2})
while_init = (f32[2]{0}) tuple(const0)
ROOT while.0 = (f32[2]{0}) while(while_init), condition=condition, body=body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleString));
  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/-1);
  // Only the copy-removal step is exercised, using the ordering given by the
  // module's own schedule.
  SequentialHloOrdering ordering(module->schedule());
  ASSERT_IS_OK(copy_insertion.RemoveUnnecessaryCopies(&ordering, module.get()));
  // The inner while must still be fed by tuple(copy(...)).
  auto while_1 = FindInstruction(module.get(), "while.1");
  EXPECT_THAT(while_1, op::While(op::Tuple(op::Copy())));
}
3206
TEST_F(CopyInsertionTest, InPlaceCollectivePermuteCopy) {
  // Two collective-permute instructions take the same tuple.input and
  // tuple.output operands (they differ only in their last offsets tuple). The
  // test pins the total number of copies inserted for this module to 4.
  absl::string_view module_text = R"(
HloModule hlo_runner_test_0.1
ENTRY hlo_runner_test_0.1 {
replica_id = u32[] replica-id()
broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
constant.1 = u32[] constant(1000)
broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
constant.2 = s32[] constant(0)
constant.3 = s32[] constant(1)
tuple.input = (u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, u32[2,8,128]{2,1,0:T(2,128)} broadcast.0)
tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
constant.4 = s32[] constant(2)
tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
tuple.6 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.5)
tuple.7 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.2)
tuple.8 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.7)
tuple.9 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.6)
tuple.10 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.7)
collective-permute.0 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute((u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple.input, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.8, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.9), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
collective-permute.1 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute((u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple.input, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.8, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.10), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
ROOT tuple = ((u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}), (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)})) tuple(collective-permute.0, collective-permute.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(module_text));
  InsertCopies(parsed_module.get());

  const auto copy_count = CountCopies(*parsed_module);
  EXPECT_EQ(copy_count, 4);
}
3240
TEST_F(CopyInsertionTest, KeepCopyOfBroadcast) {
  // Two dynamic-update-slice instructions both take broadcast.6 as the operand
  // being updated (with broadcast.7 as the update). After running the
  // CopyInsertion pass, the module is expected to contain exactly 2 copies.
  absl::string_view module_text = R"(
HloModule Module

ENTRY main {
param = f32[128,1,128] parameter(0)
negate = f32[128,1,128] negate(param)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
broadcast.7 = f32[128,1,128] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
dynamic-update-slice.5 = f32[128,1,128] dynamic-update-slice(broadcast.6, broadcast.7, constant.3, constant.3, constant.3)
add1 = f32[128,1,128] add(dynamic-update-slice.5, dynamic-update-slice.5)
dynamic-update-slice.4 = f32[128,1,128] dynamic-update-slice(broadcast.6, broadcast.7, constant.3, constant.3, constant.3)
add2 = f32[128,1,128] add(dynamic-update-slice.4, dynamic-update-slice.4)
tuple = (f32[128,1,128], f32[128,1,128]) tuple(add1, add2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(module_text));

  // Run the pass directly rather than through the fixture helper.
  CopyInsertion pass(nullptr,
                     /*use_region_based_live_range_analysis=*/-1);
  ASSERT_IS_OK(pass.Run(parsed_module.get()).status());

  const auto copy_count = CountCopies(*parsed_module);
  EXPECT_EQ(copy_count, 2);
}
3266
TEST_F(CopyInsertionTest, CustomCallAliasingCopyInsertedAliasedParam) {
  // The custom call declares its output as aliasing operand 0, and that operand
  // is an input of the computation. The call does not own that buffer, so a
  // precautionary copy has to be placed between the parameter and the call.
  const char* const kHloText = R"(
HloModule xla_computation_f

ENTRY xla_computation_f {
parameter.1 = f32[2,3,4,5] parameter(0)
parameter.2 = f32[2,3,4,5] parameter(1)
ROOT custom-call = f32[2,3,4,5] custom-call(parameter.1, parameter.2), custom_call_target="dm_softmax", operand_layout_constraints={f32[2,3,4,5], f32[2,3,4,5]}, output_to_operand_aliasing={{}: (0, {})}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  InsertCopies(parsed_module.get());

  // The aliased operand of the root custom call must now be copy(parameter 0).
  const HloInstruction* aliasing_call =
      parsed_module->entry_computation()->root_instruction();
  const HloInstruction* aliased_operand = aliasing_call->operand(0);
  EXPECT_THAT(aliased_operand, op::Copy(op::Parameter(0)));
}
3287
TEST_F(CopyInsertionTest, CustomCallAliasingCopyInsertedAliasedReuse) {
  // The custom call specifies aliasing for an operand that is later re-used
  // by a different instruction (add.2). A copy must be inserted so the correct
  // HloValue is passed to the add, and not the result of the aliased call.
  //
  // The lookup result is checked with ASSERT_NE (not CHECK_NE), so a missing
  // instruction fails this test instead of aborting the whole test binary.
  const char* const kModuleString = R"(
HloModule xla_computation_f

ENTRY xla_computation_f {
parameter.1 = f32[2,3,4,5] parameter(0)
parameter.2 = f32[2,3,4,5] parameter(1)
add.1 = f32[2,3,4,5] add(parameter.1, parameter.2)
custom-call = f32[2,3,4,5] custom-call(add.1, parameter.2), custom_call_target="dm_softmax", operand_layout_constraints={f32[2,3,4,5], f32[2,3,4,5]}, output_to_operand_aliasing={{}: (0, {})}
ROOT add.2 = f32[2,3,4,5] add(custom-call, add.1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
                          ParseAndReturnUnverifiedModule(kModuleString));

  InsertCopies(module.get());
  HloInstruction* custom_call = FindInstruction(module.get(), "custom-call");
  ASSERT_NE(custom_call, nullptr);
  // The aliased operand must be a copy of add.1, not add.1 itself.
  EXPECT_THAT(custom_call->operand(0), op::Copy(op::Add()));
}
3312
TEST_F(CopyInsertionTest, CustomCallAliasingCopyRemoved) {
  // The custom call's aliased operand is an intermediate result (`add`) that
  // nothing else reads afterwards, so no copy is needed in front of the call.
  const char* const kHloText = R"(
HloModule xla_computation_f__1
ENTRY xla_computation_f {
parameter.1 = f32[2,3,4,5] parameter(0)
parameter.2 = f32[2,3,4,5] parameter(1)
add = f32[2,3,4,5] add(parameter.1, parameter.2)
ROOT custom-call = f32[2,3,4,5] custom-call(add, parameter.2), custom_call_target="dm_softmax", operand_layout_constraints={f32[2,3,4,5], f32[2,3,4,5]}, output_to_operand_aliasing={{}: (0, {})}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  InsertCopies(parsed_module.get());

  // Operand 0 of the root custom call is still the add itself.
  const HloInstruction* aliasing_call =
      parsed_module->entry_computation()->root_instruction();
  const HloInstruction* aliased_operand = aliasing_call->operand(0);
  EXPECT_THAT(aliased_operand, op::Add());
}
3333
TEST_F(CopyInsertionTest, ReverseInConditional) {
  // Both branches of the conditional receive the same entry parameter
  // (Arg_1.2); branch 0 returns it unchanged and branch 1 returns its reverse.
  // The test verifies that, after copy insertion, the reverse reads from a
  // copy.
  const char* const kModuleString = R"(
HloModule jit_f.0

%region_0.4 (Arg_.5: u8[300,451,3]) -> (u8[300,451,3]) {
%Arg_.5 = u8[300,451,3]{1,0,2:T(8,128)(4,1)} parameter(0)
ROOT %tuple = (u8[300,451,3]{1,0,2:T(8,128)(4,1)}) tuple(u8[300,451,3]{1,0,2:T(8,128)(4,1)} %Arg_.5)
}

%region_1.9 (Arg_.10: u8[300,451,3]) -> (u8[300,451,3]) {
%Arg_.10 = u8[300,451,3]{1,0,2:T(8,128)(4,1)} parameter(0)
%reverse = u8[300,451,3]{1,0,2:T(8,128)(4,1)} reverse(u8[300,451,3]{1,0,2:T(8,128)(4,1)} %Arg_.10), dimensions={0}
ROOT %tuple.1 = (u8[300,451,3]{1,0,2:T(8,128)(4,1)}) tuple(u8[300,451,3]{1,0,2:T(8,128)(4,1)} %reverse)
}

ENTRY %main.13 (Arg_0.1: pred[], Arg_1.2: u8[300,451,3]) -> u8[300,451,3] {
%Arg_0.1 = pred[]{:T(1024)} parameter(0)
%convert.3 = s32[]{:T(256)} convert(pred[]{:T(1024)} %Arg_0.1)
%Arg_1.2 = u8[300,451,3]{1,0,2:T(8,128)(4,1)} parameter(1)
%conditional.12.clone = (u8[300,451,3]{1,0,2:T(8,128)(4,1)}) conditional(s32[]{:T(256)} %convert.3, u8[300,451,3]{1,0,2:T(8,128)(4,1)} %Arg_1.2, u8[300,451,3]{1,0,2:T(8,128)(4,1)} %Arg_1.2), branch_computations={%region_0.4, %region_1.9}
ROOT %get-tuple-element = u8[300,451,3]{1,0,2:T(8,128)(4,1)} get-tuple-element((u8[300,451,3]{1,0,2:T(8,128)(4,1)}) %conditional.12.clone), index=0
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
                          ParseAndReturnUnverifiedModule(kModuleString));

  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/-1);
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  VLOG(2) << module->ToString();
  HloInstruction* reverse = FindInstruction(module.get(), "reverse");
  // Fail the test cleanly instead of dereferencing a null pointer if the
  // instruction cannot be found by name.
  ASSERT_NE(reverse, nullptr);
  EXPECT_THAT(reverse->operand(0), op::Copy());
}
3368
TEST_F(CopyInsertionTest, InputOutputAliasCopy) {
  // The entry computation's root is its own tuple parameter, and the module
  // declares a may-alias between output {0} and parameter 0 at index {1}. The
  // test only requires that the CopyInsertion pass completes successfully on
  // this module; no property of the result is asserted.
  const char* const kHloText = R"(
HloModule main_tf2xla.11, input_output_alias={ {0}: (0, {1}, may-alias) }

ENTRY %main_tf2xla.11 (arg_tuple.1: (f32[], f32[])) -> (f32[], f32[]) {
ROOT %arg_tuple.1 = (f32[]{:T(256)}, f32[]{:T(256)}) parameter(0), parameter_replication={false,false}, sharding={{replicated}, {replicated}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  CopyInsertion pass(nullptr,
                     /*use_region_based_live_range_analysis=*/-1);
  ASSERT_IS_OK(pass.Run(parsed_module.get()).status());
  VLOG(2) << parsed_module->ToString();
}
3386
3387 } // namespace
3388 } // namespace xla
3389