xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/copy_insertion_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/copy_insertion.h"
17 
18 #include <set>
19 
20 #include "tensorflow/compiler/xla/debug_options_flags.h"
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/hlo_runner.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/platform/test_benchmark.h"
34 
35 namespace op = xla::testing::opcode_matchers;
36 
37 namespace xla {
38 namespace {
39 
40 using ::testing::UnorderedElementsAre;
41 
CountCopies(const HloComputation & computation)42 int64_t CountCopies(const HloComputation& computation) {
43   int64_t count = 0;
44   for (const auto& instruction : computation.instructions()) {
45     if (instruction->opcode() == HloOpcode::kCopy) {
46       count++;
47     }
48   }
49   return count;
50 }
51 
CountCopies(const HloModule & module)52 int64_t CountCopies(const HloModule& module) {
53   int64_t count = 0;
54   for (const auto& computation : module.computations()) {
55     count += CountCopies(*computation);
56   }
57   return count;
58 }
59 
CountControlEdges(const HloComputation & computation)60 int64_t CountControlEdges(const HloComputation& computation) {
61   int64_t count = 0;
62   for (const auto& instruction : computation.instructions()) {
63     count += instruction->control_successors().size();
64   }
65   return count;
66 }
67 
CountControlEdges(const HloModule & module)68 int64_t CountControlEdges(const HloModule& module) {
69   int64_t count = 0;
70   for (const auto& computation : module.computations()) {
71     count += CountControlEdges(*computation);
72   }
73   return count;
74 }
75 
76 class CopyInsertionTest : public HloTestBase {
77  protected:
InsertCopies(HloModule * module)78   void InsertCopies(HloModule* module) {
79     CopyInsertion copy_insertion;
80     VLOG(3) << "Before copy inser: " << module->ToString();
81     ASSERT_IS_OK(copy_insertion.Run(module).status());
82     VLOG(2) << "After copy inser: " << module->ToString();
83   }
84 
85   const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
86 };
87 
TEST_F(CopyInsertionTest, SingleParameter) {
  // Computation is a single parameter passed into a tuple. The parameter should
  // be copied before entering the tuple, and no other copies should be added.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* x = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x"));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({x}));

  EXPECT_THAT(x->users(), UnorderedElementsAre(tuple));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());
  // Exactly one copy (of the parameter) is expected; this mirrors the check in
  // SingleConstant and guards against spurious extra copies.
  EXPECT_EQ(CountCopies(*module), 1);

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::Copy(x)));
}
107 
TEST_F(CopyInsertionTest, SingleConstant) {
  // Computation is a single constant passed into a tuple. The constant should
  // be copied before entering the tuple.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({constant}));

  EXPECT_THAT(constant->users(), UnorderedElementsAre(tuple));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::Copy(constant)));
}
128 
TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) {
  // Layout-changing kCopy instructions that exist before copy insertion must
  // survive the pass. The trailing copy of the root add is expected to go away
  // (3 copies before, 2 after, and the add becomes the root).
  auto module = CreateNewVerifiedModule();

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* constant =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}})));

  // Reinterpreting the constant's minor-to-major order as a major-to-minor
  // order yields the reversed layout.
  Shape copy_shape = constant->shape();
  *copy_shape.mutable_layout() = LayoutUtil::MakeLayoutFromMajorToMinor(
      LayoutUtil::MinorToMajor(constant->shape()));

  HloInstruction* layout_copy_a = builder.AddInstruction(
      HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant));
  HloInstruction* layout_copy_b = builder.AddInstruction(
      HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      constant->shape(), HloOpcode::kAdd, layout_copy_a, layout_copy_b));
  builder.AddInstruction(
      HloInstruction::CreateUnary(sum->shape(), HloOpcode::kCopy, sum));

  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(CountCopies(*module), 3);

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 2);

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root, sum);
  EXPECT_THAT(root,
              op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())));
}
164 
TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) {
  // Two constants and two parameters, of which only one constant and one
  // parameter feed the output tuple directly. Only those two values should be
  // copied; the add of the other constant/parameter pair is left alone.
  auto builder = HloComputation::Builder(TestName());

  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));

  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "x"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "y"));

  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, c1, p1));

  builder.AddInstruction(HloInstruction::CreateTuple({c2, p0, sum}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 2);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Copy(c2), op::Copy(p0), op::Add(c1, p1)));
}
196 
TEST_F(CopyInsertionTest, BitcastParameter) {
  // A bitcast shares its operand's buffer, so a root bitcast of an entry
  // parameter has to be followed by a copy.
  const Shape param_shape = ShapeUtil::MakeShape(F32, {4});
  const Shape bitcast_shape = ShapeUtil::MakeShape(F32, {2, 2});

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "x"));
  HloInstruction* bitcast = builder.AddInstruction(
      HloInstruction::CreateBitcast(bitcast_shape, param));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_THAT(param->users(), UnorderedElementsAre(bitcast));

  HloInstruction* root_before = module->entry_computation()->root_instruction();
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  HloInstruction* root_after = module->entry_computation()->root_instruction();
  EXPECT_THAT(root_after, op::Copy(root_before));
}
218 
TEST_F(CopyInsertionTest, BitcastConstant) {
  // A bitcast shares its operand's buffer, so a root bitcast of a constant has
  // to be followed by a copy.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* literal_const =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1.0, 42.0})));
  const Shape bitcast_shape = ShapeUtil::MakeShape(F32, {2});
  HloInstruction* bitcast = builder.AddInstruction(
      HloInstruction::CreateBitcast(bitcast_shape, literal_const));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_THAT(literal_const->users(), UnorderedElementsAre(bitcast));

  HloInstruction* root_before = module->entry_computation()->root_instruction();
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  HloInstruction* root_after = module->entry_computation()->root_instruction();
  EXPECT_THAT(root_after, op::Copy(root_before));
}
241 
TEST_F(CopyInsertionTest, BitcastTupleElementParameter) {
  // Like BitcastParameter, except the bitcast of the parameter is wrapped in
  // the root tuple; the copy should be inserted between bitcast and tuple.
  const Shape param_shape = ShapeUtil::MakeShape(F32, {4});
  const Shape bitcast_shape = ShapeUtil::MakeShape(F32, {2, 2});

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "x"));
  HloInstruction* bitcast = builder.AddInstruction(
      HloInstruction::CreateBitcast(bitcast_shape, param));
  builder.AddInstruction(HloInstruction::CreateTuple({bitcast}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_THAT(param->users(), UnorderedElementsAre(bitcast));

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Copy(bitcast)));
}
262 
TEST_F(CopyInsertionTest, NestedTupleParameter) {
  // The entry root is a nested tuple-shaped parameter. Copy insertion must deep
  // copy it: the new root is a tuple tree whose three array leaves are copies
  // of the corresponding parameter elements.
  const Shape inner_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(S32, {1, 2, 3})});
  // Param shape is: ((F32[], S32[1,2,3]), F32[42])
  const Shape param_shape =
      ShapeUtil::MakeTupleShape({inner_shape, ShapeUtil::MakeShape(F32, {42})});

  auto builder = HloComputation::Builder(TestName());
  builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param0"));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  HloInstruction* param = module->entry_computation()->root_instruction();
  EXPECT_EQ(HloOpcode::kParameter, param->opcode());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 3);

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_NE(param, root);

  auto copied_inner_leaf =
      op::Copy(op::GetTupleElement(op::GetTupleElement(param)));
  EXPECT_THAT(root, op::Tuple(op::Tuple(copied_inner_leaf, copied_inner_leaf),
                              op::Copy(op::GetTupleElement(param))));
}
299 
TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) {
  // The entry root is element 0 of a nested tuple-shaped parameter; that
  // element is itself a tuple, so its two array leaves must be copied.
  const Shape inner_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(S32, {1, 2, 3})});
  // Param shape is: ((F32[], S32[1,2,3]), F32[42])
  const Shape param_shape =
      ShapeUtil::MakeTupleShape({inner_shape, ShapeUtil::MakeShape(F32, {42})});

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param0"));
  HloInstruction* element0 =
      builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetSubshape(param->shape(), {0}), param, 0));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(element0, module->entry_computation()->root_instruction());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 2);

  auto copied_leaf = op::Copy(op::GetTupleElement(op::GetTupleElement(param)));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(copied_leaf, copied_leaf));
}
332 
333 class WhileCopyInsertionTest : public CopyInsertionTest {
334  protected:
WhileCopyInsertionTest()335   WhileCopyInsertionTest() : module_(CreateNewVerifiedModule()) {}
336 
337   // Builds a While condition computation which reads the induction variable
338   // from the tuple parameter, and returns a predicate indicating whether this
339   // value is less than the constant '10'.
340   // The parameter 'nested' specifies the loop state shape from which to
341   // read the induction variable.
BuildConditionComputation(const Shape & loop_state_shape)342   std::unique_ptr<HloComputation> BuildConditionComputation(
343       const Shape& loop_state_shape) {
344     auto builder = HloComputation::Builder(TestName() + ".Condition");
345     auto limit_const = builder.AddInstruction(
346         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(10)));
347     auto loop_state = builder.AddInstruction(
348         HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
349     auto induction_variable =
350         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
351             limit_const->shape(), loop_state, 0));
352     builder.AddInstruction(HloInstruction::CreateCompare(
353         condition_result_shape_, induction_variable, limit_const,
354         ComparisonDirection::kLt));
355     return builder.Build();
356   }
357 
358   // Builds a While body computation with one output tuple element dependent on
359   // both input tuple elements.
360   // EX:
361   // Body({in0, in1})
362   //   out0 = Add(in0, 1)
363   //   out1 = Add(BCast(in0), in1)
364   //   Tuple(out0, out1)
BuildDependentBodyComputation()365   std::unique_ptr<HloComputation> BuildDependentBodyComputation() {
366     auto builder = HloComputation::Builder(TestName() + ".Body");
367     // Create param instruction to access loop state.
368     auto loop_state = builder.AddInstruction(
369         HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
370     // Update the induction variable GTE(0).
371     auto induction_variable =
372         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
373             induction_variable_shape_, loop_state, 0));
374     auto inc = builder.AddInstruction(
375         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
376     auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
377         induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
378     // Update data GTE(1).
379     auto data = builder.AddInstruction(
380         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
381     // Use 'induction_variable' in computation with no path to output tuple.
382     Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {});
383     auto convert = builder.AddInstruction(
384         HloInstruction::CreateConvert(f32_scalar_shape, induction_variable));
385     auto update = builder.AddInstruction(
386         HloInstruction::CreateBroadcast(data_shape_, convert, {}));
387     auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
388         data_shape_, HloOpcode::kAdd, data, update));
389     // Create output Tuple.
390     builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
391     return builder.Build();
392   }
393 
394   // Builds a While body computation with two output tuple elements dependent on
395   // both input tuple elements.
396   //
397   // EX: Body({in0, in1, in2})
398   //   out0 = Add(in0, 1)
399   //   out1 = in1
400   //   out2 = in2
401   //   Tuple(out0, out1, out2)
BuildDependentBodyComputation2()402   std::unique_ptr<HloComputation> BuildDependentBodyComputation2() {
403     auto builder = HloComputation::Builder(TestName() + ".Body");
404 
405     const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
406         {induction_variable_shape_, data_shape_, data_shape_});
407 
408     auto loop_state = builder.AddInstruction(
409         HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
410 
411     // Update the induction variable GTE(0).
412     auto induction_variable =
413         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
414             induction_variable_shape_, loop_state, 0));
415     auto inc = builder.AddInstruction(
416         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
417 
418     // add0 = Add(in0, 1)
419     auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
420         induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
421     // data1 = GTE(1).
422     HloInstruction* data1 = builder.AddInstruction(
423         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
424 
425     // data2 = GTE(2).
426     HloInstruction* data2 = builder.AddInstruction(
427         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2));
428 
429     // Create output Tuple.
430     builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2}));
431 
432     return builder.Build();
433   }
434 
435   // Builds a While body computation with read-only tuple element 0.
436   // EX:
437   // Body({in0, in1})
438   //   out0 = in0
439   //   out1 = Add(BCast(in0), in1)
440   //   Tuple(out0, out1)
BuildDependentBodyOneReadOnlyComputation()441   std::unique_ptr<HloComputation> BuildDependentBodyOneReadOnlyComputation() {
442     auto builder = HloComputation::Builder(TestName() + ".Body");
443     // Create param instruction to access loop state.
444     auto loop_state = builder.AddInstruction(
445         HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
446     // Update the induction variable GTE(0).
447     auto induction_variable =
448         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
449             induction_variable_shape_, loop_state, 0));
450     // Update data GTE(1).
451     auto data = builder.AddInstruction(
452         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
453 
454     // Use 'induction_variable' in computation with no path to output tuple.
455     Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {});
456     auto convert = builder.AddInstruction(
457         HloInstruction::CreateConvert(f32_scalar_shape, induction_variable));
458     auto update = builder.AddInstruction(
459         HloInstruction::CreateBroadcast(data_shape_, convert, {}));
460     auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
461         data_shape_, HloOpcode::kAdd, data, update));
462     // Create output Tuple.
463     builder.AddInstruction(
464         HloInstruction::CreateTuple({induction_variable, add1}));
465     return builder.Build();
466   }
467 
468   // Builds a While body computation with independent outputs.
469   // EX:
470   // Body({in0, in1})
471   //   out0 = Add(in0, 1)
472   //   out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
473   //   Tuple(out0, out1)
BuildIndependentBodyComputation(bool nested=false)474   std::unique_ptr<HloComputation> BuildIndependentBodyComputation(
475       bool nested = false) {
476     auto builder = HloComputation::Builder(TestName() + ".Body");
477     // Create param instruction to access loop state.
478     const Shape& loop_state_shape =
479         nested ? nested_loop_state_shape_ : loop_state_shape_;
480 
481     auto loop_state = builder.AddInstruction(
482         HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
483     // Update the induction variable GTE(0).
484     auto induction_variable =
485         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
486             induction_variable_shape_, loop_state, 0));
487     auto inc = builder.AddInstruction(
488         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
489     // add0 = Add(in0, 1)
490     auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
491         induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
492     // Update data GTE(1).
493     HloInstruction* data = nullptr;
494     if (nested) {
495       data = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
496           nested_tuple_shape_, loop_state, 1));
497       data = builder.AddInstruction(
498           HloInstruction::CreateGetTupleElement(data_shape_, data, 0));
499     } else {
500       data = builder.AddInstruction(
501           HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
502     }
503     auto update = builder.AddInstruction(
504         HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
505             {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
506     // add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
507     auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
508         data_shape_, HloOpcode::kAdd, data, update));
509     // Create output Tuple.
510     if (nested) {
511       auto nested_tuple =
512           builder.AddInstruction(HloInstruction::CreateTuple({add1, add1}));
513       builder.AddInstruction(HloInstruction::CreateTuple({add0, nested_tuple}));
514     } else {
515       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
516     }
517     return builder.Build();
518   }
519 
520   // Builds a While body computation with the following nested tuple
521   // sub-computation:
522   //                            |
523   //                    GTE(loop_state, 1)
524   //                       /           \
525   // GTE(GTE(loop_state, 1), 0)     GTE(GTE(loop_state, 1), 1)
526   //           |                              |
527   //          Add                           Reverse
528   //           |                              |
BuildNestedBodyComputation()529   std::unique_ptr<HloComputation> BuildNestedBodyComputation() {
530     auto builder = HloComputation::Builder(TestName() + ".Body");
531     // Create param instruction to access loop state.
532     auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
533         0, nested_loop_state_shape_, "loop_state"));
534     // Update GTE(0).
535     auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
536         induction_variable_shape_, loop_state, 0));
537     auto inc = builder.AddInstruction(
538         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
539     auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
540         gte0->shape(), HloOpcode::kAdd, gte0, inc));
541 
542     // GTE(loop_state, 1)
543     auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
544         nested_tuple_shape_, loop_state, 1));
545     // GTE(GTE(loop_state, 1), 0) -> Add
546     auto gte10 = builder.AddInstruction(
547         HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0));
548     auto update10 = builder.AddInstruction(
549         HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
550             {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
551     auto add10 = builder.AddInstruction(HloInstruction::CreateBinary(
552         data_shape_, HloOpcode::kAdd, gte10, update10));
553 
554     // GTE(GTE(loop_state, 1), 1) -> Reverse
555     auto gte11 = builder.AddInstruction(
556         HloInstruction::CreateGetTupleElement(data_shape_, gte1, 1));
557     auto rev11 = builder.AddInstruction(
558         HloInstruction::CreateReverse(data_shape_, gte11, {0}));
559 
560     // Create output Tuple.
561     auto inner_tuple =
562         builder.AddInstruction(HloInstruction::CreateTuple({add10, rev11}));
563     builder.AddInstruction(HloInstruction::CreateTuple({add0, inner_tuple}));
564     return builder.Build();
565   }
566 
567   // Builds a While instruction using 'condition' and 'body' sub-computations.
568   // Init operand is initialized to zeros of appropriate shape.
BuildWhileInstruction(HloComputation * condition,HloComputation * body,bool nested=false)569   HloInstruction* BuildWhileInstruction(HloComputation* condition,
570                                         HloComputation* body,
571                                         bool nested = false) {
572     auto builder = HloComputation::Builder(TestName() + ".While");
573     auto induction_var_init = builder.AddInstruction(
574         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0)));
575 
576     auto data_init = builder.AddInstruction(
577         HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
578             {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
579 
580     if (nested) {
581       auto inner_init = builder.AddInstruction(
582           HloInstruction::CreateTuple({data_init, data_init}));
583       auto loop_state_init = builder.AddInstruction(
584           HloInstruction::CreateTuple({induction_var_init, inner_init}));
585       auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
586           loop_state_init->shape(), condition, body, loop_state_init));
587       module_->AddEntryComputation(builder.Build());
588       return while_hlo;
589     }
590 
591     auto loop_state_init = builder.AddInstruction(
592         HloInstruction::CreateTuple({induction_var_init, data_init}));
593     auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
594         loop_state_shape_, condition, body, loop_state_init));
595     module_->AddEntryComputation(builder.Build());
596     return while_hlo;
597   }
598 
BuildWhileInstruction_InitPointsToConstant()599   HloInstruction* BuildWhileInstruction_InitPointsToConstant() {
600     auto builder = HloComputation::Builder(TestName() + ".While");
601     auto data_init = builder.AddInstruction(
602         HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
603             {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
604     return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
605                                                &builder);
606   }
607 
BuildWhileInstruction_InitPointsToParameter()608   HloInstruction* BuildWhileInstruction_InitPointsToParameter() {
609     auto builder = HloComputation::Builder(TestName() + ".While");
610     auto data_init = builder.AddInstruction(
611         HloInstruction::CreateParameter(0, data_shape_, "data_init"));
612     return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
613                                                &builder);
614   }
615 
BuildWhileInstruction_InitPointsToNonDistinct()616   HloInstruction* BuildWhileInstruction_InitPointsToNonDistinct() {
617     auto builder = HloComputation::Builder(TestName() + ".While");
618 
619     auto one = builder.AddInstruction(
620         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
621     auto one_vec = builder.AddInstruction(
622         HloInstruction::CreateBroadcast(data_shape_, one, {}));
623     auto data_init =
624         builder.AddInstruction(HloInstruction::CreateTuple({one_vec, one_vec}));
625 
626     return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_,
627                                                data_init, &builder);
628   }
629 
BuildWhileInstruction_InitPointsToInterfering()630   HloInstruction* BuildWhileInstruction_InitPointsToInterfering() {
631     auto builder = HloComputation::Builder(TestName() + ".While");
632     auto one = builder.AddInstruction(
633         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
634     auto data_init = builder.AddInstruction(
635         HloInstruction::CreateBroadcast(data_shape_, one, {}));
636     auto one_vec = builder.AddInstruction(
637         HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
638             {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
639     // Take a reference to 'data_init' to make it interfere with while result.
640     auto add = builder.AddInstruction(HloInstruction::CreateBinary(
641         data_shape_, HloOpcode::kAdd, data_init, one_vec));
642 
643     auto xla_while = BuildWhileInstructionWithCustomInit(loop_state_shape_,
644                                                          data_init, &builder);
645 
646     // Add an additional binary operation operating on the while and the
647     // interfering add so that neither operation is dead.
648     auto gte = xla_while->parent()->AddInstruction(
649         HloInstruction::CreateGetTupleElement(
650             ShapeUtil::GetSubshape(xla_while->shape(), {1}), xla_while, 1));
651     auto sub = xla_while->parent()->AddInstruction(HloInstruction::CreateBinary(
652         data_shape_, HloOpcode::kSubtract, add, gte));
653     auto gte0 = xla_while->parent()->AddInstruction(
654         HloInstruction::CreateGetTupleElement(
655             ShapeUtil::GetSubshape(xla_while->shape(), {0}), xla_while, 0));
656     auto tuple = xla_while->parent()->AddInstruction(
657         HloInstruction::CreateTuple({gte0, sub}));
658 
659     xla_while->parent()->set_root_instruction(tuple);
660 
661     return xla_while;
662   }
663 
BuildWhileInstructionWithCustomInit(const Shape & loop_state_shape,HloInstruction * data_init,HloComputation::Builder * builder)664   HloInstruction* BuildWhileInstructionWithCustomInit(
665       const Shape& loop_state_shape, HloInstruction* data_init,
666       HloComputation::Builder* builder) {
667     const bool nested =
668         ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_);
669     auto induction_var_init = builder->AddInstruction(
670         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0)));
671     auto condition = module_->AddEmbeddedComputation(
672         BuildConditionComputation(loop_state_shape));
673     auto body = module_->AddEmbeddedComputation(
674         BuildIndependentBodyComputation(nested));
675     auto loop_state_init = builder->AddInstruction(
676         HloInstruction::CreateTuple({induction_var_init, data_init}));
677     auto while_hlo = builder->AddInstruction(HloInstruction::CreateWhile(
678         loop_state_shape, condition, body, loop_state_init));
679     module_->AddEntryComputation(builder->Build());
680     return while_hlo;
681   }
682 
  // Module populated by the fixture helpers and the tests, and passed to
  // InsertCopies().
  std::unique_ptr<HloModule> module_;
  // NOTE: declaration order matters below; later shapes are initialized from
  // the earlier members.
  // Scalar s32 induction variable; element 0 of both loop-state shapes below.
  Shape induction_variable_shape_ = ShapeUtil::MakeShape(S32, {});
  // f32[8] data value carried through the loop.
  Shape data_shape_ = ShapeUtil::MakeShape(F32, {8});
  // Flat loop state: (s32[], f32[8]).
  Shape loop_state_shape_ =
      ShapeUtil::MakeTupleShape({induction_variable_shape_, data_shape_});
  // Pair of data values: (f32[8], f32[8]).
  Shape nested_tuple_shape_ =
      ShapeUtil::MakeTupleShape({data_shape_, data_shape_});
  // Nested loop state: (s32[], (f32[8], f32[8])).
  Shape nested_loop_state_shape_ = ShapeUtil::MakeTupleShape(
      {induction_variable_shape_, nested_tuple_shape_});
  // Scalar PRED shape; presumably the result shape of the loop condition
  // computations -- TODO confirm against BuildConditionComputation().
  Shape condition_result_shape_ = ShapeUtil::MakeShape(PRED, {});
693 };
694 
695 // Tests while body computation with independent tuple elements:
696 //
697 //   While.Body({in0, in1})
698 //     out0 = Add(in0, 1)
699 //     out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
700 //     Tuple(out0, out1)
701 //
// CopyInsertion pass should not generate any copies in the while body; only
// the two constant init elements are copied.
703 //
TEST_F(WhileCopyInsertionTest, IndependentTupleElements) {
  HloComputation* const cond_computation = module_->AddEmbeddedComputation(
      BuildConditionComputation(loop_state_shape_));
  HloComputation* const body_computation =
      module_->AddEmbeddedComputation(BuildIndependentBodyComputation());
  auto loop = BuildWhileInstruction(cond_computation, body_computation);

  InsertCopies(module_.get());

  // Each add in the body can update its loop-state element in place, so the
  // body needs no copies, and the module needs no control edges.
  EXPECT_EQ(CountCopies(*body_computation), 0);
  EXPECT_EQ(CountControlEdges(*module_), 0);

  // Both init elements are constants, so each of them gets a copy.
  EXPECT_THAT(loop->operand(0),
              op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
}
721 
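// Tests copy insertion for a while nested in the body of another while, where
// the inner while is fed the outer body's parameter through explicit GTEs and
// a freshly built tuple. The outer body looks like:
//                         PARAMETER
//                        |        |
//                        GTE(0)   GTE(1)
//                        |        |
//                        X = CreateTuple(GTE(0), GTE(1))
//                                 |
//                        WHILE(X) (root)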
TEST_F(WhileCopyInsertionTest, WhileFeedingWhileThruParameterWithCopies) {
  // Held by value: binding a const reference to the temporary only worked via
  // lifetime extension.
  const std::string hlo_string = R"(
HloModule DependentTupleElements

%DependentTupleElements.Body (loop_state.1: (s32[], f32[8])) -> (s32[], f32[8]) {
  %loop_state.1 = (s32[], f32[8]{0}) parameter(0)
  %get-tuple-element.1 = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=0
  %constant.1 = s32[] constant(1)
  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
  %get-tuple-element.2 = f32[8]{0} get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=1
  %convert = f32[] convert(s32[] %get-tuple-element.1)
  %broadcast = f32[8]{0} broadcast(f32[] %convert), dimensions={}
  %add.1 = f32[8]{0} add(f32[8]{0} %get-tuple-element.2, f32[8]{0} %broadcast)
  ROOT %tuple = (s32[], f32[8]{0}) tuple(s32[] %add, f32[8]{0} %add.1)
}

%DependentTupleElements.Condition (loop_state: (s32[], f32[8])) -> pred[] {
  %loop_state = (s32[], f32[8]{0}) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state), index=0
  %constant = s32[] constant(10)
  ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %DependentTupleElements.While () -> (s32[], f32[8]) {
  %constant.2 = s32[] constant(0)
  %constant.3 = f32[8]{0} constant({0, 0, 0, 0, 0, 0, 0, 0})
  %tuple.1 = (s32[], f32[8]{0}) tuple(s32[] %constant.2, f32[8]{0} %constant.3)
  ROOT %while.1 = (s32[], f32[8]{0}) while((s32[], f32[8]{0}) %tuple.1), condition=%DependentTupleElements.Condition, body=%DependentTupleElements.Body
}
)";
  // Named `module` (not `module_`) so it does not shadow the fixture member.
  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
  auto* while_hlo = module->entry_computation()->root_instruction();

  // Wrap the parsed while in a new outer while. The outer loop uses clones of
  // the original condition and body. The cloned body's root is then replaced
  // by an inner while that runs the ORIGINAL condition and body computations.
  HloComputation* outer_while_condition =
      module->AddEmbeddedComputation(while_hlo->while_condition()->Clone());
  HloComputation* outer_while_body =
      module->AddEmbeddedComputation(while_hlo->while_body()->Clone());
  HloInstruction* outer_while =
      while_hlo->parent()->AddInstruction(HloInstruction::CreateWhile(
          while_hlo->shape(), outer_while_condition, outer_while_body,
          while_hlo->mutable_operand(0)));

  // Rebuild the outer loop parameter element by element (one GTE per tuple
  // element, then a new tuple) and use that tuple as the inner while's init.
  HloInstruction* outer_param = outer_while_body->parameter_instruction(0);
  std::vector<HloInstruction*> materialized_gtes;
  materialized_gtes.reserve(outer_param->shape().tuple_shapes_size());
  for (int i = 0; i < outer_param->shape().tuple_shapes_size(); ++i) {
    materialized_gtes.push_back(
        outer_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
            outer_param->shape().tuple_shapes(i), outer_param, i)));
  }
  HloInstruction* dual_init = outer_while_body->AddInstruction(
      HloInstruction::CreateTuple(materialized_gtes));
  HloInstruction* dual_while =
      outer_while_body->AddInstruction(HloInstruction::CreateWhile(
          while_hlo->shape(), while_hlo->while_condition(),
          while_hlo->while_body(), dual_init));
  TF_CHECK_OK(outer_while_body->ReplaceInstruction(
      outer_while_body->root_instruction(), dual_while));
  // The outer while takes the original while's place in the entry computation.
  TF_CHECK_OK(while_hlo->parent()->ReplaceInstruction(while_hlo, outer_while));
  InsertCopies(module.get());
}
791 
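// Tests copy insertion for a while nested in the body of another while, where
// the outer body's parameter is passed directly as the inner while's operand.
// The outer body looks like:
//                         PARAMETER
//                        |        |
//                         \      /
//                           WHILE(PARAMETER) (root)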
TEST_F(WhileCopyInsertionTest, WhileFeedingWhileThruParameterNoCopies) {
  // Held by value: binding a const reference to the temporary only worked via
  // lifetime extension.
  const std::string hlo_string = R"(
HloModule DependentTupleElements

%DependentTupleElements.Body (loop_state.1: (s32[], f32[8])) -> (s32[], f32[8]) {
  %loop_state.1 = (s32[], f32[8]{0}) parameter(0)
  %get-tuple-element.1 = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=0
  %constant.1 = s32[] constant(1)
  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
  %get-tuple-element.2 = f32[8]{0} get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=1
  %convert = f32[] convert(s32[] %get-tuple-element.1)
  %broadcast = f32[8]{0} broadcast(f32[] %convert), dimensions={}
  %add.1 = f32[8]{0} add(f32[8]{0} %get-tuple-element.2, f32[8]{0} %broadcast)
  ROOT %tuple = (s32[], f32[8]{0}) tuple(s32[] %add, f32[8]{0} %add.1)
}

%DependentTupleElements.Condition (loop_state: (s32[], f32[8])) -> pred[] {
  %loop_state = (s32[], f32[8]{0}) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state), index=0
  %constant = s32[] constant(10)
  ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %DependentTupleElements.While () -> (s32[], f32[8]) {
  %constant.2 = s32[] constant(0)
  %constant.3 = f32[8]{0} constant({0, 0, 0, 0, 0, 0, 0, 0})
  %tuple.1 = (s32[], f32[8]{0}) tuple(s32[] %constant.2, f32[8]{0} %constant.3)
  ROOT %while.1 = (s32[], f32[8]{0}) while((s32[], f32[8]{0}) %tuple.1), condition=%DependentTupleElements.Condition, body=%DependentTupleElements.Body
}
)";
  // Named `module` (not `module_`) so it does not shadow the fixture member.
  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
  auto* while_hlo = module->entry_computation()->root_instruction();

  // Wrap the parsed while in a new outer while. The outer loop uses clones of
  // the original condition and body. The cloned body's root is then replaced
  // by an inner while that runs the ORIGINAL condition and body computations.
  HloComputation* outer_while_condition =
      module->AddEmbeddedComputation(while_hlo->while_condition()->Clone());
  HloComputation* outer_while_body =
      module->AddEmbeddedComputation(while_hlo->while_body()->Clone());
  HloInstruction* outer_while =
      while_hlo->parent()->AddInstruction(HloInstruction::CreateWhile(
          while_hlo->shape(), outer_while_condition, outer_while_body,
          while_hlo->mutable_operand(0)));

  // Feed the outer loop parameter directly to the inner while.
  HloInstruction* outer_param = outer_while_body->parameter_instruction(0);
  HloInstruction* dual_while =
      outer_while_body->AddInstruction(HloInstruction::CreateWhile(
          while_hlo->shape(), while_hlo->while_condition(),
          while_hlo->while_body(), outer_param));
  TF_CHECK_OK(outer_while_body->ReplaceInstruction(
      outer_while_body->root_instruction(), dual_while));
  // The outer while takes the original while's place in the entry computation.
  TF_CHECK_OK(while_hlo->parent()->ReplaceInstruction(while_hlo, outer_while));
  InsertCopies(module.get());
}
850 
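// Tests copy insertion for a while nested in the body of another while, with
// a 20-element loop state. As in WhileFeedingWhileThruParameterWithCopies,
// the inner while is fed the outer body's parameter through explicit GTEs and
// a freshly built tuple. The outer body looks like:
//                         PARAMETER
//                        /    ...    \
//                   GTE(0)    ...    GTE(19)
//                        \    ...    /
//                        X = CreateTuple(GTE(0), ..., GTE(19))
//                                 |
//                        WHILE(X) (root)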
TEST_F(WhileCopyInsertionTest, WhileFeedingWhileThruParameterBig) {
  // Held by value: binding a const reference to the temporary only worked via
  // lifetime extension.
  const std::string hlo_string = R"(
HloModule DependentTupleElements

%DependentTupleElements.Body (loop_state.1: (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0})) -> (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) {
  %loop_state.1 = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) parameter(0)
  %get-tuple-element.1 = s32[] get-tuple-element((s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %loop_state.1), index=0
  %constant.1 = s32[] constant(1)
  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
  %get-tuple-element.2 = f32[8]{0} get-tuple-element((s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %loop_state.1), index=1
  %convert = f32[] convert(s32[] %get-tuple-element.1)
  %broadcast = f32[8]{0} broadcast(f32[] %convert), dimensions={}
  %add.1 = f32[8]{0} add(f32[8]{0} %get-tuple-element.2, f32[8]{0} %broadcast)
  ROOT %tuple = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) tuple(s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1)
}

%DependentTupleElements.Condition (loop_state: (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0})) -> pred[] {
  %loop_state = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %loop_state), index=0
  %constant = s32[] constant(10)
  ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %DependentTupleElements.While () -> (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) {
  %constant.2 = s32[] constant(0)
  %constant.3 = f32[8]{0} constant({0, 0, 0, 0, 0, 0, 0, 0})
  %tuple.1 = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) tuple(s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3)
  ROOT %while.1 = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) while( (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %tuple.1), condition=%DependentTupleElements.Condition, body=%DependentTupleElements.Body
}
)";
  // Named `module` (not `module_`) so it does not shadow the fixture member.
  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
  auto* while_hlo = module->entry_computation()->root_instruction();

  // Wrap the parsed while in a new outer while. The outer loop uses clones of
  // the original condition and body. The cloned body's root is then replaced
  // by an inner while that runs the ORIGINAL condition and body computations.
  HloComputation* outer_while_condition =
      module->AddEmbeddedComputation(while_hlo->while_condition()->Clone());
  HloComputation* outer_while_body =
      module->AddEmbeddedComputation(while_hlo->while_body()->Clone());
  HloInstruction* outer_while =
      while_hlo->parent()->AddInstruction(HloInstruction::CreateWhile(
          while_hlo->shape(), outer_while_condition, outer_while_body,
          while_hlo->mutable_operand(0)));

  // Rebuild the outer loop parameter element by element (one GTE per tuple
  // element, then a new tuple) and use that tuple as the inner while's init.
  HloInstruction* outer_param = outer_while_body->parameter_instruction(0);
  std::vector<HloInstruction*> materialized_gtes;
  materialized_gtes.reserve(outer_param->shape().tuple_shapes_size());
  for (int i = 0; i < outer_param->shape().tuple_shapes_size(); ++i) {
    materialized_gtes.push_back(
        outer_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
            outer_param->shape().tuple_shapes(i), outer_param, i)));
  }
  HloInstruction* dual_init = outer_while_body->AddInstruction(
      HloInstruction::CreateTuple(materialized_gtes));
  HloInstruction* dual_while =
      outer_while_body->AddInstruction(HloInstruction::CreateWhile(
          while_hlo->shape(), while_hlo->while_condition(),
          while_hlo->while_body(), dual_init));
  TF_CHECK_OK(outer_while_body->ReplaceInstruction(
      outer_while_body->root_instruction(), dual_while));
  // The outer while takes the original while's place in the entry computation.
  TF_CHECK_OK(while_hlo->parent()->ReplaceInstruction(while_hlo, outer_while));
  InsertCopies(module.get());
}
917 
918 // Tests while body computation with dependent tuple elements:
919 //
920 //   While.Body({in0, in1})
921 //     out0 = Add(in0, 1)
922 //     out1 = Add(BCast(in0), in1)
923 //     Tuple(out0, out1)
924 //
925 // CopyInsertion pass should convert the root instruction to:
926 //
927 //     Tuple(Copy(out0), out1)
928 //
TEST_F(WhileCopyInsertionTest, DependentTupleElements) {
  HloComputation* const cond_computation = module_->AddEmbeddedComputation(
      BuildConditionComputation(loop_state_shape_));
  HloComputation* const body_computation =
      module_->AddEmbeddedComputation(BuildDependentBodyComputation());
  auto loop = BuildWhileInstruction(cond_computation, body_computation);

  InsertCopies(module_.get());

  // Exactly one copy, and no control edges, should appear in the body.
  EXPECT_EQ(CountCopies(*body_computation), 1);
  EXPECT_EQ(CountControlEdges(*body_computation), 0);

  const HloInstruction* body_root = body_computation->root_instruction();
  EXPECT_THAT(body_root, op::Tuple(op::Add(), op::Add(op::GetTupleElement(),
                                                      op::Broadcast())));

  // Fatal opcode checks on the two root operands inspected above.
  const HloInstruction* out0 = body_root->operand(0);
  const HloInstruction* out1_bcast = body_root->operand(1)->operand(1);
  ASSERT_EQ(out0->opcode(), HloOpcode::kAdd);
  ASSERT_EQ(out1_bcast->opcode(), HloOpcode::kBroadcast);

  // The copy is read both by the increment of element 0 and, through the
  // convert/broadcast, by the update of element 1.
  EXPECT_THAT(loop->while_body()->root_instruction(),
              op::Tuple(op::Add(op::Copy(), op::Constant()),
                        op::Add(op::GetTupleElement(),
                                op::Broadcast(op::Convert(op::Copy())))));

  // Both init elements are constants, so each of them gets a copy.
  EXPECT_THAT(loop->operand(0),
              op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
}
958 
959 // Tests while body computation with read-only tuple element 0:
960 //
961 //                         PARAMETER
962 //                         /       \
963 //                      GTE(0)     GTE(1)
964 //                        |  \      |
965 //                        |   BCAST |
966 //                        |      \  |
967 //                        |       ADD
968 //                        |        |
969 //                         \      /
970 //                           TUPLE (root)
971 //
972 // CopyInsertion pass should not generate any copies for the while body.
TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) {
  HloComputation* const cond_computation = module_->AddEmbeddedComputation(
      BuildConditionComputation(loop_state_shape_));
  HloComputation* const body_computation = module_->AddEmbeddedComputation(
      BuildDependentBodyOneReadOnlyComputation());
  BuildWhileInstruction(cond_computation, body_computation);

  InsertCopies(module_.get());

  // Element 0 is only read and element 1 is updated in place, so the body is
  // already legal: it should gain neither copies nor control edges.
  EXPECT_EQ(CountCopies(*body_computation), 0);
  EXPECT_EQ(CountControlEdges(*body_computation), 0);
}
986 
987 // Same as above, but with two while loops, sharing entry parameters.
TEST_F(WhileCopyInsertionTest,DependentTupleElements_OneReadOnly_TwoLoops_EntryParams)988 TEST_F(WhileCopyInsertionTest,
989        DependentTupleElements_OneReadOnly_TwoLoops_EntryParams) {
990   auto condition1 = module_->AddEmbeddedComputation(
991       BuildConditionComputation(loop_state_shape_));
992   auto condition2 = module_->AddEmbeddedComputation(
993       BuildConditionComputation(loop_state_shape_));
994   auto body1 = module_->AddEmbeddedComputation(
995       BuildDependentBodyOneReadOnlyComputation());
996   auto body2 = module_->AddEmbeddedComputation(
997       BuildDependentBodyOneReadOnlyComputation());
998 
999   auto builder = HloComputation::Builder(TestName() + ".While");
1000   auto iter_param = builder.AddInstruction(
1001       HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
1002   auto data_param = builder.AddInstruction(
1003       HloInstruction::CreateParameter(1, data_shape_, "data"));
1004   auto loop_init = builder.AddInstruction(
1005       HloInstruction::CreateTuple({iter_param, data_param}));
1006 
1007   auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
1008       loop_state_shape_, condition1, body1, loop_init));
1009   auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
1010       loop_state_shape_, condition2, body2, loop_init));
1011 
1012   // Add a couple elements from each of the while so both whiles are live.
1013   auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1014       ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
1015   auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1016       ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0));
1017   builder.AddInstruction(
1018       HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
1019 
1020   auto entry = module_->AddEntryComputation(builder.Build());
1021 
1022   InsertCopies(module_.get());
1023 
1024   // Neither body should have any copies or control edges in them.
1025   EXPECT_EQ(CountCopies(*body1), 0);
1026   EXPECT_EQ(CountCopies(*body2), 0);
1027   EXPECT_EQ(CountControlEdges(*body1), 0);
1028   EXPECT_EQ(CountControlEdges(*body2), 0);
1029 
1030   // Only two copies should be necessary. Each of the whiles should have
1031   // a copy of tuple element 1 (init value is a parameter, and the element is
1032   // not non-read-only) so each of the while bodies gets its own buffer to write
1033   // element 1 into.
1034   EXPECT_EQ(CountCopies(*entry), 2);
1035 
1036   EXPECT_EQ(while_hlo1->operand(0)->operand(1)->opcode(), HloOpcode::kCopy);
1037   EXPECT_EQ(while_hlo2->operand(0)->operand(1)->opcode(), HloOpcode::kCopy);
1038 
1039   // The two copies of element 1 should be different.
1040   EXPECT_NE(while_hlo1->operand(0)->operand(1),
1041             while_hlo2->operand(0)->operand(1));
1042 }
1043 
1044 // Same as above, but with two while loops, sharing non-parameters.
TEST_F(WhileCopyInsertionTest,DependentTupleElements_OneReadOnly_TwoLoops_NonParams)1045 TEST_F(WhileCopyInsertionTest,
1046        DependentTupleElements_OneReadOnly_TwoLoops_NonParams) {
1047   auto condition1 = module_->AddEmbeddedComputation(
1048       BuildConditionComputation(loop_state_shape_));
1049   auto condition2 = module_->AddEmbeddedComputation(
1050       BuildConditionComputation(loop_state_shape_));
1051   auto body1 = module_->AddEmbeddedComputation(
1052       BuildDependentBodyOneReadOnlyComputation());
1053   auto body2 = module_->AddEmbeddedComputation(
1054       BuildDependentBodyOneReadOnlyComputation());
1055 
1056   auto builder = HloComputation::Builder(TestName() + ".While");
1057   auto iter_param = builder.AddInstruction(
1058       HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
1059   auto data_param = builder.AddInstruction(
1060       HloInstruction::CreateParameter(1, data_shape_, "data"));
1061   // Add dummy ops to ensure loop_init elements aren't entry parameters.
1062   Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {});
1063   auto convert = builder.AddInstruction(
1064       HloInstruction::CreateConvert(f32_scalar_shape, iter_param));
1065   auto iter_value = builder.AddInstruction(
1066       HloInstruction::CreateUnary(convert->shape(), HloOpcode::kExp, convert));
1067   auto convert2 = builder.AddInstruction(
1068       HloInstruction::CreateConvert(induction_variable_shape_, iter_value));
1069   auto data_value = builder.AddInstruction(HloInstruction::CreateUnary(
1070       data_param->shape(), HloOpcode::kExp, data_param));
1071   auto loop_init = builder.AddInstruction(
1072       HloInstruction::CreateTuple({convert2, data_value}));
1073 
1074   auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
1075       loop_state_shape_, condition1, body1, loop_init));
1076   auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
1077       loop_state_shape_, condition2, body2, loop_init));
1078 
1079   // Add a couple elements from each of the while so both whiles are not dead.
1080   auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1081       ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
1082   auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1083       ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0));
1084   builder.AddInstruction(
1085       HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
1086   auto entry = module_->AddEntryComputation(builder.Build());
1087 
1088   InsertCopies(module_.get());
1089 
1090   // Ideally only one copy should be necessary. One of the whiles should
1091   // have a copy of tuple element 1 (the non-read-only element) so each of the
1092   // while bodies gets its own buffer to write element 1 into. However, the
1093   // analysis isn't perfect and adds an additional copy of element 0.
1094   EXPECT_EQ(CountCopies(*entry), 2);
1095 
1096   EXPECT_THAT(while_hlo1->operand(0),
1097               op::Tuple(op::Convert(op::Exp()), op::Copy(op::Exp())));
1098   EXPECT_THAT(while_hlo2->operand(0),
1099               op::Tuple(op::Convert(op::Exp()), op::Copy(op::Exp())));
1100 }
1101 
1102 // Tests while body computation with nested tuple elements:
1103 //
1104 //                            |
1105 //                    GTE(loop_state, 1)
1106 //                       /          \
1107 // GTE(GTE(loop_state, 1), 0)     GTE(GTE(loop_state, 1), 1)
1108 //           |                              |
1109 //          Add                           Reverse
1110 //           |                              |
1111 //
1112 // CopyInsertion pass will conceptually generate the following, but with the
1113 // actual GTE and Tuple instructions optimized away:
1114 //
1115 //                    Tuple  // old root
1116 //                   /     \
1117 //                  /       \
1118 //                GTE(0)   GTE(1)
1119 //                  |       /  \
1120 //                  |      /    \
1121 //                  |    GTE(0) GTE(1)
1122 //                  |       |    |
1123 //                  |       |   Copy
1124 //                  |       |    |
1125 //                   \      |   /
1126 //                    \    Tuple  // "inner" tuple.
1127 //                     \    /
1128 //                      \  /
1129 //                     Tuple  // new root
1130 //
TEST_F(WhileCopyInsertionTest, NestedTupleElements) {
  auto condition = module_->AddEmbeddedComputation(
      BuildConditionComputation(nested_loop_state_shape_));
  auto body = module_->AddEmbeddedComputation(BuildNestedBodyComputation());
  // The trailing `true` presumably requests the nested loop-state variant,
  // matching the nested condition built above.
  BuildWhileInstruction(condition, body, true);

  InsertCopies(module_.get());

  // The only copy necessary in the body is for the kReverse, which cannot be
  // done in place (it cannot share a buffer with its operand). The other
  // elements of the loop state are kAdd instructions, which can be done in
  // place.
  EXPECT_EQ(CountCopies(*body), 1);

  // The remaining three copies are on the init: each of its elements is a
  // constant and needs a copy.
  EXPECT_EQ(CountCopies(*module_), 4);

  // Either the kReverse itself must be copied or the operand of the kReverse
  // must be copied.
  if (body->root_instruction()->operand(1)->operand(1)->opcode() ==
      HloOpcode::kCopy) {
    EXPECT_THAT(
        body->root_instruction(),
        op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Reverse()))));
  } else {
    EXPECT_THAT(
        body->root_instruction(),
        op::Tuple(op::Add(), op::Tuple(op::Add(), op::Reverse(op::Copy()))));
  }
}
1161 
1162 // Tests while init instruction which points-to a constant.
1163 //
1164 //     init = Tuple(Constant(S32, {}), Constant(F32, {8}))
1165 //
1166 // CopyInsertion pass should add copies for both constants.
1167 //
TEST_F(WhileCopyInsertionTest, InitPointsToConstant) {
  auto loop = BuildWhileInstruction_InitPointsToConstant();

  InsertCopies(module_.get());

  // No copies inside the body; the only two copies in the module are the ones
  // made for the constant init elements.
  EXPECT_EQ(CountCopies(*loop->while_body()), 0);
  EXPECT_EQ(CountCopies(*module_), 2);

  const HloInstruction* const init = loop->operand(0);
  EXPECT_THAT(init,
              op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
}
1178 
1179 // Tests while init instruction which points-to a parameter.
1180 //
1181 //     init = Tuple(Constant(S32, {}), Parameter(F32, {8}))
1182 //
1183 // CopyInsertion pass should add copies for both the constant and parameter.
1184 //
TEST_F(WhileCopyInsertionTest, InitPointsToParameter) {
  auto loop = BuildWhileInstruction_InitPointsToParameter();

  InsertCopies(module_.get());

  // No copies inside the body; one copy each for the constant and the entry
  // parameter that make up the init tuple.
  EXPECT_EQ(CountCopies(*loop->while_body()), 0);
  EXPECT_EQ(CountCopies(*module_), 2);

  const HloInstruction* const init = loop->operand(0);
  EXPECT_THAT(init,
              op::Tuple(op::Copy(op::Constant()), op::Copy(op::Parameter())));
}
1195 
1196 // Tests while init instruction which has a non-distinct points-to set.
1197 //
1198 //     init = Tuple(Constant(S32, {}), Tuple({vec_one, vec_one}))
1199 //
1200 // CopyInsertion pass will conceptually generate the following, but with some of
1201 // the actual GTE and Tuple instructions optimized away:
1202 //
1203 //                    Tuple  // old init
1204 //                   /     \
1205 //                  /       \
1206 //                GTE(0)   GTE(1)
1207 //                  |       /  \
1208 //                  |      /    \
1209 //                  |    GTE(0) GTE(1)
1210 //                  |       |    |
1211 //                Copy   Copy   Copy
1212 //                  |       |    |
1213 //                   \      |   /
1214 //                    \    Tuple
1215 //                     \    /
1216 //                      \  /
1217 //                     Tuple  // new init
1218 //
TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) {
  auto loop = BuildWhileInstruction_InitPointsToNonDistinct();

  InsertCopies(module_.get());

  // Entry computation: one copy for the constant init element and one copy to
  // make the two aliased inner init elements distinct. Which of the two inner
  // elements gets copied is not fixed, so accept either.
  EXPECT_EQ(CountCopies(*module_->entry_computation()), 2);
  const HloInstruction* init = loop->operand(0);
  const bool init_copies_first =
      init->operand(1)->operand(0)->opcode() == HloOpcode::kCopy;
  if (init_copies_first) {
    EXPECT_THAT(init,
                op::Tuple(op::Copy(op::Constant()),
                          op::Tuple(op::Copy(op::Broadcast()), op::Broadcast())));
  } else {
    EXPECT_THAT(init,
                op::Tuple(op::Copy(op::Constant()),
                          op::Tuple(op::Broadcast(), op::Copy(op::Broadcast()))));
  }

  // Body: the result of one add is written into two output elements, so one
  // copy is needed to make the output buffers distinct. Either element may be
  // the copied one.
  const HloComputation* body = loop->while_body();
  EXPECT_EQ(CountCopies(*body), 1);
  const HloInstruction* body_root = body->root_instruction();
  const bool body_copies_first =
      body_root->operand(1)->operand(0)->opcode() == HloOpcode::kCopy;
  if (body_copies_first) {
    EXPECT_THAT(body_root,
                op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add())));
  } else {
    EXPECT_THAT(body_root,
                op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add()))));
  }
}
1259 
// Tests a while init instruction whose buffer interferes with the while
// result buffer.
//
//     init_data = Broadcast(...)
//     add_unrelated = Add(init_data) // takes a reference to cause interference
//     init = Tuple(Constant(S32, {}), init_data)
//
// The CopyInsertion pass should copy both operands of the init tuple.
//
TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
  auto loop = BuildWhileInstruction_InitPointsToInterfering();

  InsertCopies(module_.get());

  // Both init operands (the constant and the broadcast) are copied in front of
  // the while; nothing needs to be copied inside the loop body.
  EXPECT_EQ(CountCopies(*module_), 2);
  EXPECT_EQ(CountCopies(*loop->while_body()), 0);
  EXPECT_THAT(loop->operand(0),
              op::Tuple(op::Copy(op::Constant()), op::Copy(op::Broadcast())));
}
1279 
// Tests a while init instruction whose buffer has a non-distinct points-to
// set:
//
//     init = Tuple(Parameter(S32, {}), Parameter(F32, {8}),
//                  Parameter(F32, {8}))
//
// where the second and third parameters are identical *and* the tuple is
// shared by another while instruction.
//
// Verifies that the points-to set of the resulting Tuple is distinct
// (non-identical copies). In other words, verifies that copy sharing does not
// insert identical copies into the resulting tuple.
TEST_F(WhileCopyInsertionTest,InitPointsToNonDistinctUsedByTwoWhileLoops)1291 TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
1292   // Loop body that outputs tuple comprises two elements dependent on the init
1293   // tuple.
1294   const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
1295       {induction_variable_shape_, data_shape_, data_shape_});
1296 
1297   auto condition1 = module_->AddEmbeddedComputation(
1298       BuildConditionComputation(loop_state_shape));
1299   auto condition2 = module_->AddEmbeddedComputation(
1300       BuildConditionComputation(loop_state_shape));
1301   auto body1 =
1302       module_->AddEmbeddedComputation(BuildDependentBodyComputation2());
1303   auto body2 =
1304       module_->AddEmbeddedComputation(BuildDependentBodyComputation2());
1305 
1306   auto builder = HloComputation::Builder(TestName() + ".While");
1307 
1308   auto iter_param = builder.AddInstruction(
1309       HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
1310   auto data_param = builder.AddInstruction(
1311       HloInstruction::CreateParameter(1, data_shape_, "data"));
1312 
1313   // Loop init tuple contains two identical parameter buffers.
1314   auto loop_init = builder.AddInstruction(
1315       HloInstruction::CreateTuple({iter_param, data_param, data_param}));
1316 
1317   // Two while loops share the same loop init tuple.
1318   auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
1319       loop_state_shape, condition1, body1, loop_init));
1320   auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
1321       loop_state_shape, condition2, body2, loop_init));
1322 
1323   // Add add instruction so neither while is dead.
1324   auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1325       ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
1326   auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1327       ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo2, 0));
1328   builder.AddInstruction(
1329       HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
1330 
1331   module_->AddEntryComputation(builder.Build());
1332 
1333   InsertCopies(module_.get());
1334 
1335   // None of the bodies should have copies or control flow edges.
1336   EXPECT_EQ(CountCopies(*body1), 0);
1337   EXPECT_EQ(CountCopies(*body2), 0);
1338 
1339   // The loop bodies pass through elements 1 and 2 in the init tuple, so ideally
1340   // these should not need to be copied before either while. However, copy
1341   // insertion is not able to reason about the transparency of elements through
1342   // while bodies in all circumstances so extra copies are added (b/xxx).
1343   EXPECT_EQ(CountCopies(*module_->entry_computation()), 2);
1344 
1345   EXPECT_THAT(while_hlo1->operand(0),
1346               op::Tuple(op::Copy(), op::Parameter(), op::Parameter()));
1347   EXPECT_THAT(while_hlo2->operand(0),
1348               op::Tuple(op::Copy(), op::Parameter(), op::Parameter()));
1349 }
1350 
TEST_F(CopyInsertionTest,SwizzlingWhile)1351 TEST_F(CopyInsertionTest, SwizzlingWhile) {
1352   // Test a while instruction with a body which permutes its tuple parameter
1353   // elements.
1354   auto module = CreateNewVerifiedModule();
1355   const Shape loop_state_shape =
1356       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
1357 
1358   // Body simply interchanges the two tuple elements in the loop state.
1359   auto body_builder = HloComputation::Builder("body");
1360   auto body_param = body_builder.AddInstruction(
1361       HloInstruction::CreateParameter(0, loop_state_shape, "param"));
1362   auto body_element_0 = body_builder.AddInstruction(
1363       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
1364   auto body_element_1 = body_builder.AddInstruction(
1365       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
1366   body_builder.AddInstruction(
1367       HloInstruction::CreateTuple({body_element_1, body_element_0}));
1368   HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
1369 
1370   auto cond_builder = HloComputation::Builder("condition");
1371   cond_builder.AddInstruction(
1372       HloInstruction::CreateParameter(0, loop_state_shape, "param"));
1373   auto cond_constant = cond_builder.AddInstruction(
1374       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1375   cond_builder.AddInstruction(HloInstruction::CreateUnary(
1376       cond_constant->shape(), HloOpcode::kNot, cond_constant));
1377   HloComputation* condition =
1378       module->AddEmbeddedComputation(cond_builder.Build());
1379 
1380   auto builder = HloComputation::Builder(TestName());
1381   auto constant1 = builder.AddInstruction(
1382       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
1383   auto constant2 = builder.AddInstruction(
1384       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
1385   auto tuple = builder.AddInstruction(
1386       HloInstruction::CreateTuple({constant1, constant2}));
1387   auto xla_while = builder.AddInstruction(
1388       HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple));
1389   module->AddEntryComputation(builder.Build());
1390 
1391   InsertCopies(module.get());
1392 
1393   EXPECT_EQ(CountCopies(*module), 6);
1394 
1395   // The loop state elements should be copied at the parameter and at the root
1396   // with a control edge in between (see DeepCopyAndAddControlEdges). This is
1397   // technically one more copy than is strictly necessary, but in order to have
1398   // only three copies the copies of different loop state elements must be
1399   // ordered with a control edge.
1400   EXPECT_EQ(CountCopies(*body), 4);
1401   EXPECT_EQ(CountControlEdges(*body), 2);
1402 
1403   EXPECT_THAT(body->root_instruction(),
1404               op::Tuple(op::Copy(op::Copy()), op::Copy(op::Copy())));
1405 
1406   EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
1407   EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
1408 }
1409 
TEST_F(CopyInsertionTest, CrossingParameters) {
  // The entry computation swaps the two elements of its tuple parameter, yet
  // each output index is aliased with the parameter index of the same number,
  // so the two dataflows cross:
  //
  //  (p0 ,  p1)
  //   | \   /|
  //   |  \ / |
  // alias X  alias
  //   |  / \ |
  //   | /   \|
  //  (p1  ,  p0)
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry(TestName());
  HloInstruction* input = entry.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "0"));
  HloInstruction* lhs = entry.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* rhs = entry.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));
  entry.AddInstruction(HloInstruction::CreateTuple({rhs, lhs}));
  module->AddEntryComputation(entry.Build());

  // Alias output {i} with parameter 0 at index {i}, for i = 0 then i = 1.
  for (int64_t i = 0; i < 2; ++i) {
    ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
        /*output_index=*/{i}, /*param_number=*/0, /*param_index=*/{i}));
  }
  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 4);
}
1442 
TEST_F(CopyInsertionTest, ParametersAliasing) {
  // Each output element is aliased with the parameter element at the same
  // index and simply forwards it, so the two dataflows never interfere:
  //
  //  (p0 ,  p1)
  //   |      |
  //   |      |
  // alias   alias
  //   |      |
  //   |      |
  //  (p0 ,  p1)
  //
  // No copies are expected.
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry(TestName());
  HloInstruction* input = entry.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* lhs = entry.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* rhs = entry.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));
  entry.AddInstruction(HloInstruction::CreateTuple({lhs, rhs}));
  module->AddEntryComputation(entry.Build());

  // Alias output {i} with parameter 0 at index {i}, for i = 0 then i = 1.
  for (int64_t i = 0; i < 2; ++i) {
    ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
        /*output_index=*/{i}, /*param_number=*/0, /*param_index=*/{i}));
  }
  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 0);
}
1475 
TEST_F(CopyInsertionTest, ParameterWithNoAliasing) {
  // No output element is aliased with the parameter, so each forwarded
  // parameter element is expected to be copied on its way to the output:
  //
  //  (p0 ,  p1)
  //   |      |
  //   |      |
  //   |      |
  //   |      |
  //   |      |
  //  (p0 ,  p1)
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry(TestName());
  HloInstruction* input = entry.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* lhs = entry.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* rhs = entry.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));
  entry.AddInstruction(HloInstruction::CreateTuple({lhs, rhs}));
  module->AddEntryComputation(entry.Build());
  InsertCopies(module.get());

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Copy(op::GetTupleElement(input, 0)),
                              op::Copy(op::GetTupleElement(input, 1))));

  EXPECT_EQ(CountCopies(*module), 2);
}
1508 
TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) {
  // Only output {0} is aliased with the parameter; output {1} is not. The
  // aliased element should be forwarded as-is and the other one copied:
  //
  //  (p0 ,  p1)
  //   |      |
  //   |      |
  // alias    |
  //   |      |
  //   |      |
  //  (p0 ,  p1)
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry(TestName());
  HloInstruction* input = entry.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* lhs = entry.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* rhs = entry.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));
  entry.AddInstruction(HloInstruction::CreateTuple({lhs, rhs}));
  module->AddEntryComputation(entry.Build());
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  InsertCopies(module.get());

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::GetTupleElement(input, 0),
                              op::Copy(op::GetTupleElement(input, 1))));

  EXPECT_EQ(CountCopies(*module), 1);
}
1543 
TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) {
  // Each parameter element is negated independently before reaching the
  // output, and only output {0} is aliased with the parameter:
  //
  //   +-- (p0 ,  p1)
  //   |    |      |
  //   |    |      |
  // alias Negate  Negate
  //   |    |      |
  //   |    |      |
  //   +-- (p0 ,  p1)
  //
  // No copies are expected.
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry(TestName());
  HloInstruction* input = entry.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* lhs = entry.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* rhs = entry.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));

  HloInstruction* neg_lhs = entry.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, lhs));
  HloInstruction* neg_rhs = entry.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, rhs));
  entry.AddInstruction(HloInstruction::CreateTuple({neg_lhs, neg_rhs}));
  module->AddEntryComputation(entry.Build());
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 0);
}
1580 
TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) {
  // Both parameter elements are negated; output {0} is the sum of the two
  // negations and output {1} is the second negation. Only output {0} is
  // aliased with the parameter:
  //
  //   +-- (p0 ,  p1)
  //   |    |      |
  //   |    |      |
  // alias Negate  Negate
  //   |    |      |
  //   |    Add----+
  //   |    |      |
  //   +-- (p0 ,  p1)
  //
  // No copies are expected.
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry(TestName());
  HloInstruction* input = entry.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* lhs = entry.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* rhs = entry.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));

  HloInstruction* neg_lhs = entry.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, lhs));
  HloInstruction* neg_rhs = entry.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, rhs));

  HloInstruction* sum = entry.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape_, HloOpcode::kAdd, neg_lhs, neg_rhs));
  entry.AddInstruction(HloInstruction::CreateTuple({sum, neg_rhs}));
  module->AddEntryComputation(entry.Build());
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 0);
}
1621 
TEST_F(CopyInsertionTest,SwizzlingWhileWithOneOp)1622 TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
1623   // Test a while instruction with a body which permutes its tuple parameter
1624   // elements and applies one operation to one of the elements. The addition of
1625   // the operation (instruction) on the element makes the live range of the
1626   // respective input and output elements different than if the instruction were
1627   // not there (as in the SwizzlingWhile test above).
1628   auto module = CreateNewVerifiedModule();
1629   const Shape loop_state_shape =
1630       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
1631 
1632   // Body interchanges the two tuple elements in the loop state and negates one
1633   // of them.
1634   auto body_builder = HloComputation::Builder("body");
1635   auto body_param = body_builder.AddInstruction(
1636       HloInstruction::CreateParameter(0, loop_state_shape, "param"));
1637   auto body_element_0 = body_builder.AddInstruction(
1638       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
1639   auto body_element_1 = body_builder.AddInstruction(
1640       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
1641   auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
1642       scalar_shape_, HloOpcode::kNegate, body_element_1));
1643   body_builder.AddInstruction(
1644       HloInstruction::CreateTuple({negate, body_element_0}));
1645   HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
1646 
1647   auto cond_builder = HloComputation::Builder("condition");
1648   cond_builder.AddInstruction(
1649       HloInstruction::CreateParameter(0, loop_state_shape, "param"));
1650   auto cond_constant = cond_builder.AddInstruction(
1651       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1652   cond_builder.AddInstruction(HloInstruction::CreateUnary(
1653       cond_constant->shape(), HloOpcode::kNot, cond_constant));
1654   HloComputation* condition =
1655       module->AddEmbeddedComputation(cond_builder.Build());
1656 
1657   auto builder = HloComputation::Builder(TestName());
1658   auto constant1 = builder.AddInstruction(
1659       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
1660   auto constant2 = builder.AddInstruction(
1661       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
1662   auto tuple = builder.AddInstruction(
1663       HloInstruction::CreateTuple({constant1, constant2}));
1664   auto xla_while = builder.AddInstruction(
1665       HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple));
1666   module->AddEntryComputation(builder.Build());
1667 
1668   InsertCopies(module.get());
1669 
1670   EXPECT_EQ(CountCopies(*module), 6);
1671 
1672   // The loop state elements should be copied at the parameter and at the root
1673   // with a control edge in between (see DeepCopyAndAddControlEdges).
1674   EXPECT_EQ(CountCopies(*body), 4);
1675   EXPECT_EQ(CountControlEdges(*body), 2);
1676 
1677   EXPECT_THAT(
1678       body->root_instruction(),
1679       op::Tuple(op::Copy(op::Negate(op::Copy())), op::Copy(op::Copy())));
1680 
1681   EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
1682   EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
1683 }
1684 
TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) {
  // Test a while instruction whose body permutes its tuple parameter elements,
  // similar to SwizzlingWhile above. In this test, however, the init value of
  // the while is a tuple holding one constant twice, so both loop state
  // elements are the same constant. Interchanging them is therefore a no-op,
  // and no copies are needed in the while body. The only copies expected are
  // the two at the root of the entry computation.
  auto module = CreateNewVerifiedModule();
  const Shape loop_state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Body simply interchanges the two tuple elements in the loop state.
  auto body_builder = HloComputation::Builder("body");
  auto body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  auto body_element_0 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
  auto body_element_1 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
  body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_element_1, body_element_0}));
  HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());

  // Condition ignores the loop state and evaluates Not(false).
  auto cond_builder = HloComputation::Builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  auto cond_constant = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  cond_builder.AddInstruction(HloInstruction::CreateUnary(
      cond_constant->shape(), HloOpcode::kNot, cond_constant));
  HloComputation* condition =
      module->AddEmbeddedComputation(cond_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  // Both loop state elements point at the same constant buffer.
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
  builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple));
  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 2);
  EXPECT_EQ(CountCopies(*body), 0);

  // Both copies are at the root of the entry computation.
  EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::Copy(), op::Copy()));
}
1735 
TEST_F(CopyInsertionTest,SequentialWhiles)1736 TEST_F(CopyInsertionTest, SequentialWhiles) {
1737   // Construct a computation with a series of sequential while instructions
1738   // containing four loop state elements:
1739   //
1740   //   element 0 is passed to each while directly from an entry parameter.
1741   //
1742   //   element 1 is passed transparently in series through all the while bodies.
1743   //
1744   //   element 2 is negated in each while body. (in-place possible)
1745   //
1746   //   element 3 is reversed in each while body. (in-place not possible)
1747   //
1748   const Shape element_shape = ShapeUtil::MakeShape(F32, {42});
1749   const Shape loop_state_shape = ShapeUtil::MakeTupleShape(
1750       {element_shape, element_shape, element_shape, element_shape});
1751 
1752   auto module = CreateNewVerifiedModule();
1753   auto builder = HloComputation::Builder(TestName());
1754   auto param_0 = builder.AddInstruction(
1755       HloInstruction::CreateParameter(0, element_shape, "param_0"));
1756   auto param_1 = builder.AddInstruction(
1757       HloInstruction::CreateParameter(1, element_shape, "param_1"));
1758   auto param_2 = builder.AddInstruction(
1759       HloInstruction::CreateParameter(2, element_shape, "param_2"));
1760   auto param_3 = builder.AddInstruction(
1761       HloInstruction::CreateParameter(3, element_shape, "param_3"));
1762 
1763   // The number of sequential kWhile instructions.
1764   const int kNumWhiles = 3;
1765 
1766   HloInstruction* prev_element_1 = param_1;
1767   HloInstruction* prev_element_2 = param_2;
1768   HloInstruction* prev_element_3 = param_3;
1769 
1770   // Vector containing all of the while instructions.
1771   std::vector<const HloInstruction*> whiles;
1772   for (int i = 0; i < kNumWhiles; ++i) {
1773     auto body_builder = HloComputation::Builder("body");
1774     auto body_param = body_builder.AddInstruction(
1775         HloInstruction::CreateParameter(0, loop_state_shape, "param"));
1776     auto body_element_0 = body_builder.AddInstruction(
1777         HloInstruction::CreateGetTupleElement(element_shape, body_param, 0));
1778     auto body_element_1 = body_builder.AddInstruction(
1779         HloInstruction::CreateGetTupleElement(element_shape, body_param, 1));
1780     auto body_element_2 = body_builder.AddInstruction(
1781         HloInstruction::CreateGetTupleElement(element_shape, body_param, 2));
1782     auto body_element_3 = body_builder.AddInstruction(
1783         HloInstruction::CreateGetTupleElement(element_shape, body_param, 3));
1784     auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
1785         element_shape, HloOpcode::kNegate, body_element_2));
1786     auto reverse = body_builder.AddInstruction(
1787         HloInstruction::CreateReverse(element_shape, body_element_3, {0}));
1788     body_builder.AddInstruction(HloInstruction::CreateTuple(
1789         {body_element_0, body_element_1, negate, reverse}));
1790     HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
1791 
1792     auto cond_builder = HloComputation::Builder("condition");
1793     cond_builder.AddInstruction(
1794         HloInstruction::CreateParameter(0, loop_state_shape, "param"));
1795     auto cond_constant = cond_builder.AddInstruction(
1796         HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1797     cond_builder.AddInstruction(HloInstruction::CreateUnary(
1798         cond_constant->shape(), HloOpcode::kNot, cond_constant));
1799     HloComputation* condition =
1800         module->AddEmbeddedComputation(cond_builder.Build());
1801 
1802     auto while_init = builder.AddInstruction(HloInstruction::CreateTuple(
1803         {param_0, prev_element_1, prev_element_2, prev_element_3}));
1804 
1805     auto xla_while = builder.AddInstruction(HloInstruction::CreateWhile(
1806         loop_state_shape, condition, body, while_init));
1807     whiles.push_back(xla_while);
1808     if (i != kNumWhiles - 1) {
1809       prev_element_1 = builder.AddInstruction(
1810           HloInstruction::CreateGetTupleElement(element_shape, xla_while, 1));
1811       prev_element_2 = builder.AddInstruction(
1812           HloInstruction::CreateGetTupleElement(element_shape, xla_while, 2));
1813       prev_element_3 = builder.AddInstruction(
1814           HloInstruction::CreateGetTupleElement(element_shape, xla_while, 3));
1815     }
1816   }
1817 
1818   module->AddEntryComputation(builder.Build());
1819 
1820   InsertCopies(module.get());
1821 
1822   // Each while body has one copy. And each loop state element is copied once in
1823   // the entry computation.
1824   EXPECT_EQ(CountCopies(*module), 4 + kNumWhiles);
1825 
1826   // Each while body should have exactly one copy for element three which is an
1827   // op (kReverse) which cannot be done in place.
1828   for (const HloInstruction* xla_while : whiles) {
1829     EXPECT_EQ(CountCopies(*xla_while->while_body()), 1);
1830   }
1831 
1832   EXPECT_THAT(whiles[0]->operand(0), op::Tuple(op::Parameter(), op::Parameter(),
1833                                                op::Copy(), op::Copy()));
1834   EXPECT_THAT(module->entry_computation()->root_instruction(),
1835               op::Tuple(op::Copy(), op::Copy(), op::GetTupleElement(),
1836                         op::GetTupleElement()));
1837 }
1838 
TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) {
  // Both the body and the condition of the while return a constant as their
  // root. The body's constant is expected to be copied, and the while's
  // operand (an entry parameter) as well; the condition is left alone.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry(TestName());
  HloInstruction* input = entry.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param_0"));

  // Body: ignores its parameter and returns 123.0f.
  HloComputation::Builder body_builder("body");
  body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
  HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());

  // Condition: ignores its parameter and returns false.
  HloComputation::Builder cond_builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module->AddEmbeddedComputation(cond_builder.Build());

  HloInstruction* loop = entry.AddInstruction(
      HloInstruction::CreateWhile(scalar_shape_, condition, body, input));

  module->AddEntryComputation(entry.Build());

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 2);

  EXPECT_THAT(loop->operand(0), op::Copy(op::Parameter()));
  EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant()));
  EXPECT_THAT(condition->root_instruction(), op::Constant());
}
1875 
TEST_F(CopyInsertionTest, TokensShouldNotBeCopied) {
  // A while loop whose state is (s32 counter, token). The body increments the
  // counter and chains the token through an after-all; the condition loops
  // while the counter is below 42. Copy insertion must not add any copies,
  // token elements included.
  std::string module_string = R"(
HloModule TokensShouldNotBeCopied

%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
  %param.1 = (s32[], token[]) parameter(0)
  %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
  %constant.1 = s32[] constant(1)
  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
  %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
  %after-all = token[] after-all(token[] %get-tuple-element.2)
  ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
}

%Cond (param: (s32[], token[])) -> pred[] {
  %param = (s32[], token[]) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
  %constant = s32[] constant(42)
  ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %TokensShouldNotBeCopied () -> s32[] {
  %one = s32[] constant(1)
  %negative_one = s32[] negate(%one)
  %init_token = token[] after-all()
  %init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token)
  %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
  ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_string));
  InsertCopies(module.get());

  // There should be no copies added because tokens should not be copied.
  EXPECT_EQ(CountCopies(*module), 0);
}
1913 
MakeTrivialCondition(const Shape & shape)1914 std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) {
1915   auto builder = HloComputation::Builder("trivial_condition");
1916   builder.AddInstruction(
1917       HloInstruction::CreateParameter(0, shape, "loop_state"));
1918   auto constant = builder.AddInstruction(
1919       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1920   builder.AddInstruction(HloInstruction::CreateUnary(
1921       constant->shape(), HloOpcode::kNot, constant));
1922   return builder.Build();
1923 }
1924 
MakeBenchmarkWhileBody()1925 std::unique_ptr<HloComputation> MakeBenchmarkWhileBody() {
1926   auto builder = HloComputation::Builder("benchmark_loop_body");
1927   const Shape element_shape = ShapeUtil::MakeShape(F32, {42});
1928   const Shape loop_state_shape =
1929       ShapeUtil::MakeTupleShape({element_shape, element_shape, element_shape});
1930   HloInstruction* param = builder.AddInstruction(
1931       HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
1932   HloInstruction* element_0 = builder.AddInstruction(
1933       HloInstruction::CreateGetTupleElement(element_shape, param, 0));
1934   HloInstruction* element_1 = builder.AddInstruction(
1935       HloInstruction::CreateGetTupleElement(element_shape, param, 1));
1936   HloInstruction* element_2 = builder.AddInstruction(
1937       HloInstruction::CreateGetTupleElement(element_shape, param, 2));
1938 
1939   HloInstruction* rev_1 = builder.AddInstruction(
1940       HloInstruction::CreateReverse(element_shape, element_1, {0}));
1941   HloInstruction* add_1_2 = builder.AddInstruction(HloInstruction::CreateBinary(
1942       element_shape, HloOpcode::kAdd, element_1, element_2));
1943 
1944   builder.AddInstruction(
1945       HloInstruction::CreateTuple({element_0, rev_1, add_1_2}));
1946   return builder.Build();
1947 }
1948 
BM_SequentialWhiles(::testing::benchmark::State & state)1949 void BM_SequentialWhiles(::testing::benchmark::State& state) {
1950   const int num_whiles = state.range(0);
1951 
1952   // This benchmark constructs a chain of sequential while instructions.
1953   // Timer starts automatically at the first iteration of this loop
1954   // and ends after the last one.
1955   for (auto s : state) {
1956     state.PauseTiming();
1957     HloModuleConfig config;
1958     config.set_debug_options(GetDebugOptionsFromFlags());
1959     HloModule module("BM_SequentialWhiles", config);
1960 
1961     auto builder = HloComputation::Builder("BM_SequentialWhiles");
1962     HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
1963         0, ShapeUtil::MakeShape(F32, {42}), "x"));
1964     HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
1965         1, ShapeUtil::MakeShape(F32, {42}), "y"));
1966     HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
1967         2, ShapeUtil::MakeShape(F32, {42}), "z"));
1968     HloInstruction* init =
1969         builder.AddInstruction(HloInstruction::CreateTuple({x, y, z}));
1970 
1971     HloInstruction* prev_loop_state = init;
1972     for (int w = 0; w < num_whiles; ++w) {
1973       HloComputation* condition =
1974           module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
1975       HloComputation* body =
1976           module.AddEmbeddedComputation(MakeBenchmarkWhileBody());
1977       prev_loop_state = builder.AddInstruction(HloInstruction::CreateWhile(
1978           init->shape(), condition, body, prev_loop_state));
1979     }
1980     module.AddEntryComputation(builder.Build());
1981 
1982     CopyInsertion copy_insertion;
1983 
1984     state.ResumeTiming();
1985     ASSERT_IS_OK(copy_insertion.Run(&module).status());
1986     state.PauseTiming();
1987 
1988     // The entry computation should have three copies, and each body has one.
1989     ASSERT_EQ(CountCopies(module), 3 + num_whiles);
1990     state.ResumeTiming();
1991   }
1992 }
1993 
BM_ParallelWhiles(::testing::benchmark::State & state)1994 void BM_ParallelWhiles(::testing::benchmark::State& state) {
1995   const int num_whiles = state.range(0);
1996 
1997   // This benchmark constructs a fan-out of parallel while instructions.
1998   for (auto s : state) {
1999     state.PauseTiming();
2000     HloModuleConfig config;
2001     config.set_debug_options(GetDebugOptionsFromFlags());
2002     HloModule module("BM_SequentialWhiles", config);
2003 
2004     auto builder = HloComputation::Builder("BM_ParallelWhiles");
2005     HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
2006         0, ShapeUtil::MakeShape(F32, {42}), "x"));
2007     HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
2008         1, ShapeUtil::MakeShape(F32, {42}), "y"));
2009     HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
2010         2, ShapeUtil::MakeShape(F32, {42}), "z"));
2011     HloInstruction* init =
2012         builder.AddInstruction(HloInstruction::CreateTuple({x, y, z}));
2013 
2014     HloInstruction* sum = nullptr;
2015     for (int w = 0; w < num_whiles; ++w) {
2016       HloComputation* condition =
2017           module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
2018       HloComputation* body =
2019           module.AddEmbeddedComputation(MakeBenchmarkWhileBody());
2020 
2021       HloInstruction* xla_while = builder.AddInstruction(
2022           HloInstruction::CreateWhile(init->shape(), condition, body, init));
2023 
2024       if (sum == nullptr) {
2025         sum = builder.AddInstruction(
2026             HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0));
2027       } else {
2028         HloInstruction* element_0 = builder.AddInstruction(
2029             HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0));
2030         sum = builder.AddInstruction(HloInstruction::CreateBinary(
2031             x->shape(), HloOpcode::kAdd, sum, element_0));
2032       }
2033     }
2034     module.AddEntryComputation(builder.Build());
2035 
2036     CopyInsertion copy_insertion;
2037 
2038     state.ResumeTiming();
2039     ASSERT_IS_OK(copy_insertion.Run(&module).status());
2040     state.PauseTiming();
2041 
2042     // Each body receives of copy of two of the parameters (the corresponding
2043     // elements in the body are modified), and there is one copy in each body.
2044     ASSERT_EQ(CountCopies(module), 3 * num_whiles);
2045   }
2046 }
2047 
MakeBenchmarkWhileBody(const int num_tuple_inputs)2048 std::unique_ptr<HloComputation> MakeBenchmarkWhileBody(
2049     const int num_tuple_inputs) {
2050   auto builder = HloComputation::Builder("benchmark_loop_body");
2051   const Shape element_shape = ShapeUtil::MakeShape(F32, {});
2052   std::vector<Shape> input_shape(num_tuple_inputs, element_shape);
2053   const Shape loop_state_shape = ShapeUtil::MakeTupleShape(input_shape);
2054   HloInstruction* param = builder.AddInstruction(
2055       HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
2056   std::vector<HloInstruction*> gte_nodes(num_tuple_inputs);
2057   for (int i = 0; i < num_tuple_inputs; ++i) {
2058     gte_nodes[i] = builder.AddInstruction(
2059         HloInstruction::CreateGetTupleElement(element_shape, param, i));
2060   }
2061   builder.AddInstruction(HloInstruction::CreateTuple(gte_nodes));
2062   return builder.Build();
2063 }
2064 
BM_ManyElementTuple(::testing::benchmark::State & state)2065 void BM_ManyElementTuple(::testing::benchmark::State& state) {
2066   const int num_tuple_inputs = state.range(0);
2067   HloModuleConfig config;
2068   config.set_debug_options(GetDebugOptionsFromFlags());
2069   CopyInsertion copy_insertion;
2070   const Shape element_shape = ShapeUtil::MakeShape(F32, {});
2071   std::vector<HloInstruction*> tuple_params(num_tuple_inputs);
2072   for (auto s : state) {
2073     state.PauseTiming();
2074     auto builder = HloComputation::Builder("BM_ParallelWhiles");
2075     HloModule module("BM_ManyElementTuple", config);
2076     for (int j = 0; j < num_tuple_inputs; ++j) {
2077       tuple_params[j] = builder.AddInstruction(
2078           HloInstruction::CreateParameter(j, element_shape, ""));
2079     }
2080     HloInstruction* init =
2081         builder.AddInstruction(HloInstruction::CreateTuple(tuple_params));
2082     HloComputation* condition =
2083         module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
2084     HloComputation* body =
2085         module.AddEmbeddedComputation(MakeBenchmarkWhileBody(num_tuple_inputs));
2086     HloInstruction* xla_while = builder.AddInstruction(
2087         HloInstruction::CreateWhile(init->shape(), condition, body, init));
2088     builder.AddInstruction(HloInstruction::CreateGetTupleElement(
2089         ShapeUtil::MakeShape(F32, {}), xla_while, 0));
2090     module.AddEntryComputation(builder.Build());
2091     state.ResumeTiming();
2092     ASSERT_IS_OK(copy_insertion.Run(&module).status());
2093   }
2094 }
2095 
2096 BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
2097 BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
2098 BENCHMARK(BM_ManyElementTuple)->Arg(1024)->Arg(12288);
2099 
// Smoke test: runs copy insertion over three levels of nested while loops.
// Going by the computation names, the HLO presumably comes from TF
// control-flow functionalization (TODO confirm). The test only checks that
// parsing/verification succeeds and that InsertCopies completes; it makes no
// assertions about the copies added.
TEST_F(CopyInsertionTest, SimpleControlFlowTest) {
  const std::string hlo_string = R"(
HloModule TestModule

if-body.v5 {
  constant.3 = s32[] constant(-1)
  p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1
  get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0
  get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1
  add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66)
  tuple.33 = (s32[]) tuple(add.3)
  ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33)
}

if-condition.v4 {
  p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
  constant.4 = s32[] constant(0)
  ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ
}

_functionalize_body_1__.v28 {
  arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.68 = s32[] get-tuple-element(arg_tuple.4), index=0
  constant.7 = s32[] constant(1)
  add.4 = s32[] add(get-tuple-element.68, constant.7)
  get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1
  get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2
  less-than-or-equal-to = pred[] compare(get-tuple-element.69, get-tuple-element.70), direction=LE
  constant.8 = s32[] constant(0)
  select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
  get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3
  tuple.35 = (s32[], s32[], s32[]) tuple(get-tuple-element.69, get-tuple-element.71, get-tuple-element.70)
  tuple.36 = (s32[]) tuple(constant.8)
  tuple.37 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.35, tuple.36)
  while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.37), condition=if-condition.v4, body=if-body.v5
  get-tuple-element.72 = (s32[]) get-tuple-element(while), index=2
  get-tuple-element.73 = s32[] get-tuple-element(get-tuple-element.72), index=0
  ROOT tuple.38 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.69, get-tuple-element.70, get-tuple-element.73)
}

cond_wrapper.v3.1 {
  inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0
  constant.11 = s32[] constant(7)
  ROOT less-than.2 = pred[] compare(get-tuple-element.75, constant.11), direction=LT
}

_functionalize_body_2__.v25 {
  arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.76 = s32[] get-tuple-element(arg_tuple.5), index=0
  get-tuple-element.77 = s32[] get-tuple-element(arg_tuple.5), index=2
  get-tuple-element.78 = s32[] get-tuple-element(arg_tuple.5), index=3
  get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=4
  tuple.39 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.76, get-tuple-element.77, get-tuple-element.78, get-tuple-element.79)
  while.2 = (s32[], s32[], s32[], s32[]) while(tuple.39), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28
  get-tuple-element.80 = s32[] get-tuple-element(while.2), index=0
  get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=1
  constant.12 = s32[] constant(1)
  add.5 = s32[] add(get-tuple-element.81, constant.12)
  get-tuple-element.82 = s32[] get-tuple-element(while.2), index=3
  ROOT tuple.40 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.80, add.5, get-tuple-element.77, get-tuple-element.78, get-tuple-element.82)
}

cond_wrapper.v3.2 {
  inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1
  constant.13 = s32[] constant(5)
  ROOT less-than.3 = pred[] compare(get-tuple-element.83, constant.13), direction=LT
}

ENTRY TestComputation {
  arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
}
)";
  // Report a parse/verification failure as a test failure instead of aborting
  // the whole binary via StatusOr::value().
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
}
2180 
// Smoke test: like SimpleControlFlowTest, but the innermost level chains two
// while loops (`while` feeds `while.1`). The test only checks that
// parsing/verification succeeds and that InsertCopies completes; it makes no
// assertions about the copies added.
TEST_F(CopyInsertionTest, ControlFlowTest) {
  const std::string hlo_string = R"(
HloModule TestModule

if-body.v5 {
  constant.3 = s32[] constant(-1)
  p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1
  get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0
  get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1
  add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66)
  tuple.33 = (s32[]) tuple(add.3)
  ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33)
}

if-condition.v4 {
  p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
  constant.4 = s32[] constant(0)
  ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ
}

if-body.v5.1 {
  constant.5 = s32[] constant(-1)
  p.3 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.68 = (s32[], s32[], s32[]) get-tuple-element(p.3), index=1
  get-tuple-element.70 = s32[] get-tuple-element(get-tuple-element.68), index=2
  multiply.1 = s32[] multiply(get-tuple-element.70, get-tuple-element.70)
  tuple.35 = (s32[]) tuple(multiply.1)
  ROOT tuple.36 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.5, get-tuple-element.68, tuple.35)
}

if-condition.v4.1 {
  p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0
  constant.6 = s32[] constant(1)
  ROOT equal-to.1 = pred[] compare(get-tuple-element.71, constant.6), direction=EQ
}

_functionalize_body_1__.v28 {
  arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.72 = s32[] get-tuple-element(arg_tuple.4), index=0
  constant.7 = s32[] constant(1)
  add.4 = s32[] add(get-tuple-element.72, constant.7)
  get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1
  get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2
  less-than-or-equal-to = pred[] compare(get-tuple-element.73, get-tuple-element.74), direction=LE
  constant.8 = s32[] constant(0)
  select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
  get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3
  tuple.37 = (s32[], s32[], s32[]) tuple(get-tuple-element.73, get-tuple-element.75, get-tuple-element.74)
  tuple.38 = (s32[]) tuple(constant.8)
  tuple.39 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.37, tuple.38)
  while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.39), condition=if-condition.v4, body=if-body.v5
  while.1 = (s32[], (s32[], s32[], s32[]), (s32[])) while(while), condition=if-condition.v4.1, body=if-body.v5.1
  get-tuple-element.76 = (s32[]) get-tuple-element(while.1), index=2
  get-tuple-element.77 = s32[] get-tuple-element(get-tuple-element.76), index=0
  ROOT tuple.40 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.73, get-tuple-element.74, get-tuple-element.77)
}

cond_wrapper.v3.1 {
  inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0
  constant.11 = s32[] constant(7)
  ROOT less-than.2 = pred[] compare(get-tuple-element.78, constant.11), direction=LT
}

_functionalize_body_2__.v25 {
  arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=0
  get-tuple-element.80 = s32[] get-tuple-element(arg_tuple.5), index=2
  get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=3
  get-tuple-element.82 = s32[] get-tuple-element(arg_tuple.5), index=4
  tuple.41 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.79, get-tuple-element.80, get-tuple-element.81, get-tuple-element.82)
  while.2 = (s32[], s32[], s32[], s32[]) while(tuple.41), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28
  get-tuple-element.83 = s32[] get-tuple-element(while.2), index=0
  get-tuple-element.84 = s32[] get-tuple-element(arg_tuple.5), index=1
  constant.12 = s32[] constant(1)
  add.5 = s32[] add(get-tuple-element.84, constant.12)
  get-tuple-element.85 = s32[] get-tuple-element(while.2), index=3
  ROOT tuple.42 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.83, add.5, get-tuple-element.80, get-tuple-element.81, get-tuple-element.85)
}

cond_wrapper.v3.2 {
  inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1
  constant.13 = s32[] constant(5)
  ROOT less-than.3 = pred[] compare(get-tuple-element.86, constant.13), direction=LT
}

ENTRY TestComputation {
  arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
}
)";
  // Report a parse/verification failure as a test failure instead of aborting
  // the whole binary via StatusOr::value().
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
}
2279 
TEST_F(CopyInsertionTest, NestedWhiles) {
  // Verify that no unnecessary copies remain after copy insertion for trivial
  // nested whiles (b/112472605): the only copy should be on the outer while's
  // operand in the entry computation.
  const std::string hlo_string = R"(
HloModule TestModule

cond.inner {
  ROOT param.cond.inner = pred[] parameter(0)
}

body.inner {
  param.body.inner = pred[] parameter(0)
  ROOT not = pred[] not(param.body.inner)
}

cond.outer {
  ROOT param.cond.outer = pred[] parameter(0)
}

body.outer {
  param.cond.outer = pred[] parameter(0)
  ROOT while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner
}

ENTRY TestComputation {
  entry_param = pred[] parameter(0)
  ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());

  // Exactly one copy is inserted, and it is in the entry computation.
  EXPECT_EQ(CountCopies(*module), 1);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::While(op::Copy(op::Parameter())));
}
2319 
TEST_F(CopyInsertionTest, NestedWhilesWithParamRoot) {
  // The outer loop body's root is its own parameter. That parameter is also
  // consumed by an inner while, whose result is outfed after the root is
  // defined. Copy insertion must break this interference with a copy inside
  // the outer body (on the inner while's operand), not one level further out
  // in the entry computation.
  const std::string hlo_text = R"(
HloModule TestModule

cond.inner {
  ROOT param.cond.inner = pred[] parameter(0)
}

body.inner {
  param.body.inner = pred[] parameter(0)
  ROOT not = pred[] not(param.body.inner)
}

cond.outer {
  ROOT param.cond.outer = pred[] parameter(0)
}

body.outer {
  ROOT param.cond.outer = pred[] parameter(0)
  while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner
  after-all = token[] after-all()
  outfeed = token[] outfeed(while, after-all)
}

ENTRY TestComputation {
  entry_param = pred[] parameter(0)
  while = pred[] while(entry_param), condition=cond.outer, body=body.outer
  ROOT not = pred[] not(while)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(module.get());

  // A single copy is expected, placed inside the outer while body.
  EXPECT_EQ(CountCopies(*module), 1);

  HloInstruction* entry_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, op::Not(op::While(op::Parameter())));

  HloInstruction* outfeed = FindInstruction(module.get(), "outfeed");
  EXPECT_THAT(outfeed, op::Outfeed(op::While(op::Copy(op::Parameter(0))),
                                   op::AfterAll()));
}
2367 
TEST_F(CopyInsertionTest, NestedWhilesWithParamRoot2) {
  // Tuple-shaped variant of NestedWhilesWithParamRoot. The outer body's root
  // re-packs the parameter's two elements unchanged, while the same parameter
  // also feeds an inner while whose result is outfed. Both elements of the
  // inner while's operand are expected to be copied inside the outer body.
  const std::string hlo_text = R"(
HloModule TestModule

cond.inner {
  param.cond.inner = (pred[], pred[]) parameter(0)
  ROOT gte = pred[] get-tuple-element(param.cond.inner), index=0
}

body.inner {
  param.body.inner = (pred[], pred[]) parameter(0)
  gte.0 = pred[] get-tuple-element(param.body.inner), index=0
  gte.1 = pred[] get-tuple-element(param.body.inner), index=1
  and = pred[] and(gte.0, gte.1)
  not = pred[] not(gte.1)
  ROOT root = (pred[], pred[]) tuple(and, not)
}

cond.outer {
  param.cond.outer = (pred[], pred[]) parameter(0)
  ROOT gte = pred[] get-tuple-element(param.cond.outer), index=0
}

body.outer {
  param.body.outer = (pred[], pred[]) parameter(0)
  gte.0 = pred[] get-tuple-element(param.body.outer), index=0
  gte.1 = pred[] get-tuple-element(param.body.outer), index=1
  while.inner = (pred[], pred[]) while(param.body.outer), condition=cond.inner, body=body.inner
  gte.2 = pred[] get-tuple-element(while.inner), index=0
  after-all = token[] after-all()
  outfeed = token[] outfeed(gte.2, after-all)
  ROOT root = (pred[], pred[]) tuple(gte.0, gte.1)
}

ENTRY TestComputation {
  entry_param.1 = pred[] parameter(0)
  entry_param.2 = pred[] parameter(1)
  tuple = (pred[], pred[]) tuple(entry_param.1, entry_param.2)
  while.outer = (pred[], pred[]) while(tuple), condition=cond.outer, body=body.outer
  gte = pred[] get-tuple-element(while.outer), index=0
  ROOT not = pred[] not(gte)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(module.get());

  // Each element fed to the inner while should be copied out of the outer
  // body's parameter.
  auto copied_param_element =
      op::Copy(op::GetTupleElement(op::Parameter(0)));
  HloInstruction* inner_while = FindInstruction(module.get(), "while.inner");
  EXPECT_THAT(inner_while,
              op::While(op::Tuple(copied_param_element, copied_param_element)));
}
2425 
// A while body passes the same loop-state element as both branch operands of
// a conditional. Each branch returns a tuple whose element 0 is its unmodified
// parameter. The entry computation also feeds the same parameter into two
// tuple slots. Exactly three copies are expected.
TEST_F(CopyInsertionTest, NestedWhileAndConditional2) {
  const std::string hlo_text = R"(
HloModule TestModule

on_true
 {
  v1 = f32[2] parameter(0)
  v2 = f32[2] add(v1,v1)
  ROOT t1 = (f32[2], f32[2]) tuple(v1,v2)
}

on_false
 {
  v1 = f32[2] parameter(0)
  v2 = f32[2] multiply(v1,v1)
  ROOT t2 = (f32[2], f32[2]) tuple(v1,v2)
}

cond.outer {
  param.1 = (pred[], f32[2], f32[2]) parameter(0)
  ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
}

body.outer {
  param.1 = (pred[], f32[2], f32[2]) parameter(0)
  pred.1 = pred[] get-tuple-element(param.1), index=0
  arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
  if = (f32[2], f32[2]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
  e1 = f32[2] get-tuple-element(if), index=0
  e2 = f32[2] get-tuple-element(if), index=1
  ROOT res = (pred[], f32[2], f32[2]) tuple(pred.1,e1, e2)
}

ENTRY TestComputation {
  entry_param.1 = pred[] parameter(0)
  float_param = f32[2] parameter(1)
  entry_param = (pred[], f32[2], f32[2]) tuple(entry_param.1, float_param, float_param)
  ROOT while = (pred[], f32[2], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(module.get());

  // The uses inside the conditional force one extra copy to stay in the loop.
  EXPECT_EQ(CountCopies(*module), 3);
}
2473 
// A while body applies a conditional (add vs. multiply) to one loop-state
// element and passes the same element as both branch operands. Only two
// copies are expected overall.
TEST_F(CopyInsertionTest, NestedWhileAndConditional) {
  const std::string hlo_text = R"(
HloModule TestModule

on_true
 {
  v1 = f32[2] parameter(0)
  ROOT v2 = f32[2] add(v1,v1)
}

on_false
 {
  v1 = f32[2] parameter(0)
  ROOT v2 = f32[2] multiply(v1,v1)
}

cond.outer {
  param.1 = (pred[], f32[2]) parameter(0)
  ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
}

body.outer {
  param.1 = (pred[], f32[2]) parameter(0)
  pred.1 = pred[] get-tuple-element(param.1), index=0
  arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
  if = f32[2] conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
  ROOT res = (pred[], f32[2]) tuple(pred.1,if)
}

ENTRY TestComputation {
  entry_param.1 = pred[] parameter(0)
  float_param = f32[2] parameter(1)
  entry_param = (pred[], f32[2]) tuple(entry_param.1, float_param)
  ROOT while = (pred[], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(module.get());
  VLOG(2) << module->ToString() << "\n";

  // Exactly two copies should be inserted: one at the entry and one at the
  // exit of the computation.
  EXPECT_EQ(CountCopies(*module), 2);
}
2519 
// Both entry parameters are returned directly (aliased with outputs {1} and
// {3}) and are also consumed by a fusion and a negate. With those aliases in
// place, copy insertion is expected to add no copies.
TEST_F(CopyInsertionTest, FixpointComputationRequired) {
  const std::string hlo_text = R"(
HloModule Module

fused_computation {
  param0 = f32[3,3,96,1] parameter(0)
  param1 = f32[] parameter(1)
  broadcast = f32[3,3,96,1] broadcast(f32[] param1), dimensions={}
  ROOT %add.0 = f32[3,3,96,1] add(f32[3,3,96,1] param0, f32[3,3,96,1] broadcast)
}

ENTRY entry_computation {
  arg0 = f32[3,3,96,1] parameter(0)
  arg1 = f32[] parameter(1)
  fusion = f32[3,3,96,1] fusion(f32[3,3,96,1] arg0, f32[] arg1),
    kind=kLoop, calls=fused_computation
  negate = f32[] negate(f32[] arg1)
  ROOT tuple = (f32[3,3,96,1], f32[3,3,96,1], f32[], f32[]) tuple(
    f32[3,3,96,1] fusion,
    f32[3,3,96,1] arg0,
    f32[] negate,
    f32[] arg1)
}
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // These aliases would normally be established by the
  // alias_passthrough_params pass; configure them by hand here.
  auto& alias_config = module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(/*output_index=*/{1},
                                       /*param_number=*/0,
                                       /*param_index=*/{}));
  ASSERT_IS_OK(alias_config.SetUpAlias(/*output_index=*/{3},
                                       /*param_number=*/1,
                                       /*param_index=*/{}));

  InsertCopies(module.get());

  // No copies are expected.
  EXPECT_EQ(CountCopies(*module), 0);
}
2563 
// The entry parameter is returned both directly (aliased with output {1}) and
// through a bitcast at output {0}. Exactly one copy is expected.
TEST_F(CopyInsertionTest, NoAliasCheckViolation) {
  const std::string hlo_text = R"(
HloModule cluster

ENTRY Entry {
  %arg = f32[8,28,28,1] parameter(0)
  %bitcast.2 = f32[8,1,28,28] bitcast(f32[8,28,28,1] %arg)
  ROOT %tuple.1 = (f32[8,1,28,28], f32[8,28,28,1]) tuple(f32[8,1,28,28] %bitcast.2, f32[8,28,28,1] %arg)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  auto& alias_config = module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(/*output_index=*/{1},
                                       /*param_number=*/0,
                                       /*param_index=*/{}));

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 1);
}
2583 
// An unfused dynamic-update-slice writes into `negate`, which has no other
// users. Copy insertion is expected to add no copies.
TEST_F(CopyInsertionTest, DynamicUpdateSliceNoCopy) {
  absl::string_view hlo_text = R"(
HloModule Module

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate = f32[1280,1,128] negate(param)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  InsertCopies(module.get());

  const int64_t copy_count = CountCopies(*module);
  EXPECT_EQ(copy_count, 0);
}
2602 
// Same as DynamicUpdateSliceNoCopy but with the dynamic-update-slice inside a
// loop fusion. `negate` has no other user, so no copies are expected.
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceNoCopy) {
  constexpr char kModuleText[] = R"(
HloModule Module

fused_computation {
  param0 = f32[1280,1,128] parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
}

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate = f32[1280,1,128] negate(param)
  ROOT fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));
  HloModule* const m = module.get();
  InsertCopies(m);
  EXPECT_EQ(CountCopies(*m), 0);
}
2626 
// `negate` is read by `add` and is also the buffer operand of the
// dynamic-update-slice. Exactly one copy is expected.
TEST_F(CopyInsertionTest, DynamicUpdateSliceCopy) {
  constexpr char kModuleText[] = R"(
HloModule Module

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate = f32[1280,1,128] negate(param)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  add = f32[1280,1,128] add(negate, negate)
  dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(add, dynamic-update-slice.5)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));
  HloModule* const m = module.get();
  InsertCopies(m);
  EXPECT_EQ(CountCopies(*m), 1);
}
2647 
// The dynamic-update-slice writes into an entry parameter for which no
// input/output alias is configured. Exactly one copy is expected.
TEST_F(CopyInsertionTest, DynamicUpdateSliceParameterShareCopy) {
  constexpr char kModuleText[] = R"(
HloModule Module

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3, constant.3)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));
  HloModule* const m = module.get();
  InsertCopies(m);
  EXPECT_EQ(CountCopies(*m), 1);
}
2665 
// `negate` is both the operand of a fused dynamic-update-slice and element 0
// of the entry root tuple. Exactly one copy is expected.
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy) {
  constexpr char kModuleText[] = R"(
HloModule Module

fused_computation {
  param0 = f32[1280,1,128] parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
}

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate = f32[1280,1,128] negate(param)
  add = f32[1280,1,128] add(negate, negate)
  fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation
  ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(negate, fusion)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));
  HloModule* const m = module.get();
  InsertCopies(m);
  EXPECT_EQ(CountCopies(*m), 1);
}
2691 
// Two chained dynamic-update-slices on an element of the (unaliased) entry
// parameter. Only one copy is expected for the whole chain.
TEST_F(CopyInsertionTest, ChainDynamicUpdateSliceCopy) {
  constexpr char kModuleText[] = R"(
HloModule Module

ENTRY main {
  state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
  get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
  get-tuple-element.3 = s32[] get-tuple-element(state), index=0
  constant.2 = s32[] constant(128)
  add.5 = s32[] add(get-tuple-element.3, constant.2)
  constant.3 = s32[] constant(0)
  dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
  dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));
  HloModule* const m = module.get();
  InsertCopies(m);
  EXPECT_EQ(CountCopies(*m), 1);
}
2715 
// fusion1 runs a dynamic-update-slice on `negate`, while fusion2 still reads
// `negate` (sliced) as its second operand. Exactly one copy is expected.
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy2) {
  constexpr char kModuleText[] = R"(
HloModule Module

fused_computation.1 {
  param0 = f32[1280,1,128] parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
}

fused_computation.2 {
  param0 = f32[1280,1,128] parameter(0)
  param1 = f32[1280,1,128] parameter(1)
  slice = f32[128,1,128] slice(param1), slice={[0:128], [0:1], [0:128]}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, slice, constant.3, constant.3, constant.3)
}

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate = f32[1280,1,128] negate(param)
  add = f32[1280,1,128] add(negate, negate)
  fusion1 = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation.1
  ROOT fusion2 = f32[1280,1,128] fusion(fusion1, negate), kind=kLoop, calls=fused_computation.2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));
  HloModule* const m = module.get();
  InsertCopies(m);
  EXPECT_EQ(CountCopies(*m), 1);
}
2749 
// Multi-output fusion with two dynamic-update-slice outputs (on negate1 and
// negate2). Both of those operands are read again after the fusion, so two
// copies are expected.
TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceCopy) {
  constexpr char kModuleText[] = R"(
HloModule Module

fused_computation {
  param0 = f32[1280,1,128] parameter(0)
  param1 = f32[1280,1,128] parameter(1)
  param2 = f32[1280,1,128] parameter(2)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  add.1 = f32[1280,1,128] add(param0, param0)
  dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
  dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
}

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate0 = f32[1280,1,128] negate(param)
  negate1 = f32[1280,1,128] negate(param)
  negate2 = f32[1280,1,128] negate(param)
  fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
  gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0
  gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1
  gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2
  add0 = f32[1280,1,128] add(negate0, gte0)
  add1 = f32[1280,1,128] add(negate1, gte1)
  add2 = f32[1280,1,128] add(negate2, gte2)
  ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));
  HloModule* const m = module.get();
  InsertCopies(m);
  EXPECT_EQ(CountCopies(*m), 2);
}
2788 
TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceNoCopy) {
  // Same HLO as MultiOutputFusedDynamicUpdateSliceCopy, except that negate1 is
  // not used after the fusion (add1 reads gte1 twice). Only negate2, the
  // dynamic-update-slice operand that add2 reads afterwards, still needs a
  // copy. negate0 feeds a plain add inside the fusion and needs none.
  absl::string_view hlo_string = R"(
HloModule Module

fused_computation {
  param0 = f32[1280,1,128] parameter(0)
  param1 = f32[1280,1,128] parameter(1)
  param2 = f32[1280,1,128] parameter(2)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  add.1 = f32[1280,1,128] add(param0, param0)
  dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
  dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
}

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate0 = f32[1280,1,128] negate(param)
  negate1 = f32[1280,1,128] negate(param)
  negate2 = f32[1280,1,128] negate(param)
  fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
  gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0
  gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1
  gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2
  add0 = f32[1280,1,128] add(negate0, gte0)
  add1 = f32[1280,1,128] add(gte1, gte1)
  add2 = f32[1280,1,128] add(negate2, gte2)
  ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);
}
2828 
TEST_F(CopyInsertionTest, ScatterSharedOperand) {
  // The scatter's in-place operand (iota.2) also appears as its updates
  // operand. Only one of the two uses should receive a copy: the in-place
  // operand.
  constexpr char kModuleText[] = R"(
HloModule Module

update_s32 {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

fused_computation {
  iota.1 = s32[73729]{0} iota(), iota_dimension=0
  ROOT indices.1 = s32[73729]{0} reverse(iota.1), dimensions={0}
}

ENTRY main {
  iota.2 = s32[73729]{0} iota(), iota_dimension=0
  fusion = s32[73729]{0} fusion(), kind=kLoop, calls=fused_computation
  ROOT scatter = s32[73729]{0} scatter(iota.2, fusion, iota.2), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=update_s32
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));
  HloModule* const m = module.get();
  InsertCopies(m);
  EXPECT_EQ(CountCopies(*m), 1);
  // Read the root only after the pass has run.
  const HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              op::Scatter(op::Copy(op::Iota()), op::Fusion(), op::Iota()));
}
2858 
// Horizontally fused loop whose two outputs are bitcast and returned.
// Output {0} is aliased to parameter 0 and output {1} to parameter 3.
// No copies are expected.
TEST_F(CopyInsertionTest, HorizontalLoopFusionNoCopy) {
  constexpr char kModuleText[] = R"(
    HloModule test

    fused_computation {
      p0 = f32[10,20] parameter(0)
      p1 = f32[10,20] parameter(1)
      p2 = f32[10,10] parameter(2)
      p3 = f32[10,10] parameter(3)
      add0 = f32[10, 20] add(p0, p1)
      sub0 = f32[10, 10] subtract(p2, p3)
      reshape0 = f32[200] reshape(add0)
      reshape1 = f32[100] reshape(sub0)
      concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0}
      slice0 = f32[200] slice(concat0), slice={[0:200]}
      slice1 = f32[100] slice(concat0), slice={[200:300]}
      ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1)
    }

    ENTRY test {
      p0 = f32[10,20] parameter(0)
      p1 = f32[10,20] parameter(1)
      p2 = f32[10,10] parameter(2)
      p3 = f32[10,10] parameter(3)
      fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation
      gte0 = f32[200] get-tuple-element(fusion), index=0
      gte1 = f32[100] get-tuple-element(fusion), index=1
      bitcast0 = f32[10,20] bitcast(gte0)
      bitcast1 = f32[10,10] bitcast(gte1)
      ROOT tuple = (f32[10,20], f32[10,10]) tuple(bitcast0, bitcast1)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));
  auto& aliases = module->input_output_alias_config();
  ASSERT_IS_OK(aliases.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
                                  /*param_index=*/{}));
  ASSERT_IS_OK(aliases.SetUpAlias(/*output_index=*/{1}, /*param_number=*/3,
                                  /*param_index=*/{}));

  HloModule* const m = module.get();
  InsertCopies(m);
  EXPECT_EQ(CountCopies(*m), 0);
}
2908 
// A while loop whose body calls a conditional, one of whose branches contains
// a nested conditional. Four copies are expected in total.
TEST_F(CopyInsertionTest, NestedWhileAndConditional3) {
  constexpr char kModuleText[] = R"(
HloModule TestModule

on_true.1
 {
  ROOT v1 = f32[2] parameter(0)
}

on_false.1
 {
  v1 = f32[2] parameter(0)
  ROOT v2 = f32[2] multiply(v1,v1)
}

on_true
 {
  v1 = f32[2] parameter(0)
  v2 = f32[2] add(v1,v1)
  v3 = (f32[2],f32[2]) tuple(v1,v2)
  v4 = f32[2] get-tuple-element(v3), index=1
  v5 = f32[2] multiply(v4,v2)
   ROOT t1 = (f32[2], f32[2]) tuple(v5,v2)
}

on_false
 {
  v1 = f32[2] parameter(0)
  v2 = f32[2] multiply(v1,v1)
  pred.1 = pred[] constant(true)
  v4 = f32[2] conditional(pred.1, v1, v2), true_computation=on_true.1, false_computation=on_false.1
  v5 = f32[2] multiply(v4,v2)
  ROOT t2 = (f32[2], f32[2]) tuple(v2,v5)

}

cond.outer {
  param.1 = (pred[], f32[2], f32[2]) parameter(0)
  ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
}

body.outer {
  param.1 = (pred[], f32[2], f32[2]) parameter(0)
  pred.1 = pred[] get-tuple-element(param.1), index=0
  arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
  if = (f32[2], f32[2]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
  e1 = f32[2] get-tuple-element(if), index=0
  e2 = f32[2] get-tuple-element(if), index=1
  ROOT res = (pred[], f32[2], f32[2]) tuple(pred.1,e1, e2)
}

ENTRY TestComputation {
  entry_param.1 = pred[] parameter(0)
  float_param = f32[2] parameter(1)
  entry_param = (pred[], f32[2], f32[2]) tuple(entry_param.1, float_param, float_param)
  ROOT while = (pred[], f32[2], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));
  HloModule* const m = module.get();
  InsertCopies(m);
  // One of the copies has to stay inside the loop because of the uses in the
  // conditional.
  EXPECT_EQ(CountCopies(*m), 4);
}
2973 
TEST_F(CopyInsertionTest, ConditionalBranchMustCopy1) {
  absl::string_view hlo_string = R"(
HloModule TestModule

 branch_0_comp.5.clone {
 %parameter.0 = (s32[2]{0:T(128)}) parameter(0)
 %get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0
 %negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element)
 %copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate)
 ROOT tuple.5 = (s32[2]{0:T(128)}) tuple(%copy)
 }

 branch_1_comp.12.clone {
  %parameter.4 = (s32[2]{0:T(128)}) parameter(0)
  %get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0
  %copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5)
  ROOT tuple.6 = (s32[2]{0:T(128)}) tuple(%copy.1)
 }

ENTRY TestComputation {
  %parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3)
  %tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2)
  %conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}, metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %gte.1 = s32[2]{0:T(128)} get-tuple-element(conditional.18), index=0
  ROOT tuple.4 = (s32[2]{0:T(128)},s32[2]{0:T(128)}) tuple(parameter.2, gte.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // NOTE(review): the fixture helper InsertCopies() runs first, and the pass is
  // then run a second time explicitly. The expectations below apply to the
  // module after both runs.
  InsertCopies(module.get());
  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/-1);
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  VLOG(3) << module->ToString();

  // copy.1 must be kept. Without it, branch 1 would forward its operand
  // unchanged, and that operand (parameter.2, via tuple.3) is also returned
  // directly as element 0 of the entry root.
  //
  // ASSERT_*/EXPECT_* are used instead of CHECK_* so that a failure is
  // reported by the test framework rather than aborting the test binary.
  HloInstruction* conditional18 =
      FindInstruction(module.get(), "conditional.18");
  ASSERT_NE(conditional18, nullptr);
  const HloInstruction* tuple6 =
      conditional18->branch_computation(1)->root_instruction();
  ASSERT_EQ(tuple6->opcode(), HloOpcode::kTuple);
  const HloInstruction* copy1 = tuple6->operand(0);
  EXPECT_EQ(copy1->opcode(), HloOpcode::kCopy);
}
3019 
TEST_F(CopyInsertionTest, ConditionalBranchMustCopy2) {
  absl::string_view hlo_string = R"(
HloModule TestModule

 branch_0_comp.5.clone {
 %parameter.0 = (s32[2]{0:T(128)}) parameter(0)
 %get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0
 %negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element)
 %copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate)
 ROOT tuple.5 = (s32[2]{0:T(128)}) tuple(%copy)
 }

 branch_1_comp.12.clone {
  %parameter.4 = (s32[2]{0:T(128)}) parameter(0)
  %get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0
  %copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5)
  %constant.1 = s32[] constant(0)
  %broadcast.6 = s32[2] broadcast(constant.1), dimensions={}
  dynamic-update-slice.5 = s32[2]{0:T(128)} dynamic-update-slice(%copy.1, %broadcast.6, %constant.1)
  %add.1 = s32[2]{0:T(128)} add(dynamic-update-slice.5, %copy.1)
  ROOT tuple.6 = (s32[2]{0:T(128)}) tuple(%add.1)
 }

ENTRY TestComputation {
  %parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3)
  %tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2)
  %conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}
  %gte.1 = s32[2]{0:T(128)} get-tuple-element(conditional.18), index=0
  ROOT tuple.4 = (s32[2]{0:T(128)},s32[2]{0:T(128)}) tuple(parameter.2, gte.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/-1);
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());

  // copy.1 must be kept. Branch 1 overwrites it with dynamic-update-slice.5
  // and then reads it again in add.1, and the value it copies (parameter.2,
  // via tuple.3) is also returned directly from the entry computation.
  //
  // ASSERT_*/EXPECT_* are used instead of CHECK_* so that a failure is
  // reported by the test framework rather than aborting the test binary.
  HloInstruction* conditional18 =
      FindInstruction(module.get(), "conditional.18");
  ASSERT_NE(conditional18, nullptr);
  const HloInstruction* tuple6 =
      conditional18->branch_computation(1)->root_instruction();
  ASSERT_EQ(tuple6->opcode(), HloOpcode::kTuple);
  const HloInstruction* add1 = tuple6->operand(0);
  ASSERT_EQ(add1->opcode(), HloOpcode::kAdd);
  // Verify the structure before following the DUS's buffer operand.
  const HloInstruction* dus = add1->operand(0);
  ASSERT_EQ(dus->opcode(), HloOpcode::kDynamicUpdateSlice);
  const HloInstruction* copy1 = dus->operand(0);
  EXPECT_EQ(copy1->opcode(), HloOpcode::kCopy);
}
3070 
TEST_F(CopyInsertionTest, ConditionalBranchMustCopy3) {
  absl::string_view hlo_string = R"(
HloModule primitive_computation_cond.19
%branch_0_comp.5.clone (parameter.0: (s32[2])) -> (s32[2]) {
  %parameter.0 = (s32[2]{0:T(128)}) parameter(0)
  %get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0
  %negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element)
  %copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate)
  ROOT %tuple.5 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %copy)
}

%branch_1_comp.12.clone (parameter.4: (s32[2])) -> (s32[2]) {
  %parameter.4 = (s32[2]{0:T(128)}) parameter(0)
  %get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0
  %copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5)
  ROOT %tuple.6 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %copy.1)
}

ENTRY %primitive_computation_cond.19 (parameter.1: s32[], parameter.2: s32[2], parameter.3: s32[2]) -> (s32[2]) {
  %parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3)
  %parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2)
  ROOT %conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}, metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // NOTE(review): the fixture helper InsertCopies() runs first, and the pass is
  // then run a second time explicitly. The expectations below apply to the
  // module after both runs.
  InsertCopies(module.get());
  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/-1);
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  VLOG(3) << module->ToString();

  // copy.1 must be kept because aliasing a parameter with the root is not
  // allowed.
  //
  // ASSERT_*/EXPECT_* are used instead of CHECK_* so that a failure is
  // reported by the test framework rather than aborting the test binary.
  HloInstruction* conditional18 =
      FindInstruction(module.get(), "conditional.18");
  ASSERT_NE(conditional18, nullptr);
  const HloInstruction* tuple6 =
      conditional18->branch_computation(1)->root_instruction();
  ASSERT_EQ(tuple6->opcode(), HloOpcode::kTuple);
  const HloInstruction* copy1 = tuple6->operand(0);
  EXPECT_EQ(copy1->opcode(), HloOpcode::kCopy);
}
3113 
TEST_F(CopyInsertionTest, ConditionalBranchDoNotCopy1) {
  absl::string_view hlo_string = R"(
HloModule TestModule

 branch_0_comp.5.clone {
 %parameter.0 = (s32[2]{0:T(128)}) parameter(0)
 %get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0
 %negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element)
 %copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate)
 ROOT tuple.5 = (s32[2]{0:T(128)}) tuple(%copy)
 }

 branch_1_comp.12.clone {
  %parameter.4 = (s32[2]{0:T(128)}) parameter(0)
  %get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0
  %copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5)
  ROOT tuple.6 = (s32[2]{0:T(128)}) tuple(%copy.1)
 }

ENTRY TestComputation {
  %parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3)
  %tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2)
  %conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}, metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %gte.1 = s32[2]{0:T(128)} get-tuple-element(conditional.18), index=0
  ROOT tuple.4 = (s32[2]{0:T(128)},s32[2]{0:T(128)}) tuple(gte.1, gte.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/-1);
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  VLOG(3) << module->ToString() << "\n";

  // copy.1 is expected to be removed. Here the entry root returns only the
  // conditional's result (gte.1, twice), not parameter.2 itself, so branch 1
  // presumably can forward its operand. Its root is then expected to collapse
  // to the branch parameter.
  //
  // ASSERT_*/EXPECT_* are used instead of CHECK_* so that a failure is
  // reported by the test framework rather than aborting the test binary.
  HloInstruction* conditional18 =
      FindInstruction(module.get(), "conditional.18");
  ASSERT_NE(conditional18, nullptr);
  const HloInstruction* branch1_root =
      conditional18->branch_computation(1)->root_instruction();
  EXPECT_EQ(branch1_root->opcode(), HloOpcode::kParameter);
}
3157 
// Regression test for b/189219227. In `body`, the ROOT instruction is
// scheduled before other instructions, but it is still live-out. The copy
// feeding while.1 must therefore survive RemoveUnnecessaryCopies.
TEST_F(CopyInsertionTest, RootInstructionNotLast) {
  constexpr char kModuleText[] = R"(
HloModule module, is_scheduled=true

body2 {
  p_body2 = (f32[2]{0}) parameter(0)
  p_body2.1 = f32[2]{0} get-tuple-element(p_body2), index=0
  add.3 = f32[2]{0} add(p_body2.1, p_body2.1)
  ROOT root2 = (f32[2]{0}) tuple(add.3)
}

condition2 {
  p_cond2 = (f32[2]{0}) parameter(0)
  ROOT result = pred[] constant(true)
}

body {
  p_body = (f32[2]{0}) parameter(0)
  p_body.1 = f32[2]{0} get-tuple-element(p_body), index=0
  ROOT root = (f32[2]{0}) tuple(p_body.1)
  copy = f32[2]{0} copy(p_body.1)
  tuple = (f32[2]{0}) tuple(copy)
  while.1 = (f32[2]{0}) while(tuple), condition=condition2, body=body2
}

condition {
  p_cond = (f32[2]{0}) parameter(0)
  ROOT result = pred[] constant(true)
}

ENTRY entry {
  const0 = f32[2]{0} constant({1, 2})
  while_init = (f32[2]{0}) tuple(const0)
  ROOT while.0 = (f32[2]{0}) while(while_init), condition=condition, body=body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));
  CopyInsertion pass(nullptr,
                     /*use_region_based_live_range_analysis=*/-1);
  SequentialHloOrdering schedule_order(module->schedule());
  ASSERT_IS_OK(pass.RemoveUnnecessaryCopies(&schedule_order, module.get()));
  EXPECT_THAT(FindInstruction(module.get(), "while.1"),
              op::While(op::Tuple(op::Copy())));
}
3206 
TEST_F(CopyInsertionTest, InPlaceCollectivePermuteCopy) {
  // Two collective-permutes use the four-operand form:
  //   (input tuple, output-buffer tuple, input offsets, output offsets).
  // Both receive the very same `tuple.input` and `tuple.output` operands and
  // differ only in their final offsets tuple (tuple.9 vs. tuple.10).
  // Presumably copy insertion must add copies so the two permutes do not
  // write into shared buffers; the resulting copy count is pinned at 4.
  const char* const kHloString = R"(
HloModule hlo_runner_test_0.1
ENTRY hlo_runner_test_0.1 {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.input = (u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, u32[2,8,128]{2,1,0:T(2,128)} broadcast.0)
    tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    constant.4 = s32[] constant(2)
    tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
    tuple.6 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.5)
    tuple.7 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.2)
    tuple.8 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.7)
    tuple.9 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.6)
    tuple.10 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.7)
    collective-permute.0 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute((u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple.input, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.8, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.9), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
    collective-permute.1 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute((u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple.input, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.8, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.10), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
    ROOT tuple = ((u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}), (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)})) tuple(collective-permute.0, collective-permute.1)
  }
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloString));

  // Run copy insertion through the fixture helper and count what it added.
  InsertCopies(module.get());
  const int64_t copies_after_pass = CountCopies(*module);
  EXPECT_EQ(copies_after_pass, 4);
}
3240 
TEST_F(CopyInsertionTest, KeepCopyOfBroadcast) {
  // Both dynamic-update-slice.5 and dynamic-update-slice.4 take broadcast.6
  // as the buffer being updated (operand 0). Presumably copy insertion has to
  // keep a separate copy so the two updates do not share one buffer; the test
  // pins the resulting copy count at 2. The entry has no explicit ROOT, so the
  // last instruction (`tuple`) is the root.
  const char* const kHloString = R"(
HloModule Module

ENTRY main {
  param = f32[128,1,128] parameter(0)
  negate = f32[128,1,128] negate(param)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  broadcast.7 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  dynamic-update-slice.5 = f32[128,1,128] dynamic-update-slice(broadcast.6, broadcast.7, constant.3, constant.3, constant.3)
  add1 = f32[128,1,128] add(dynamic-update-slice.5, dynamic-update-slice.5)
  dynamic-update-slice.4 = f32[128,1,128] dynamic-update-slice(broadcast.6, broadcast.7, constant.3, constant.3, constant.3)
  add2 = f32[128,1,128] add(dynamic-update-slice.4, dynamic-update-slice.4)
  tuple = (f32[128,1,128], f32[128,1,128]) tuple(add1, add2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloString));

  // Run the pass directly, with region-based live range analysis configured
  // via the -1 argument.
  CopyInsertion pass(nullptr,
                     /*use_region_based_live_range_analysis=*/-1);
  const auto run_result = pass.Run(module.get());
  ASSERT_IS_OK(run_result.status());

  EXPECT_EQ(CountCopies(*module), 2);
}
3266 
TEST_F(CopyInsertionTest, CustomCallAliasingCopyInsertedAliasedParam) {
  // The root custom-call declares that its output aliases operand 0, which is
  // an entry parameter. The computation does not own a parameter's buffer, so
  // copy insertion has to put a precautionary copy between parameter(0) and
  // the custom-call.
  const char* const kHlo = R"(
    HloModule xla_computation_f

    ENTRY xla_computation_f {
      parameter.1 = f32[2,3,4,5] parameter(0)
      parameter.2 = f32[2,3,4,5] parameter(1)
      ROOT custom-call = f32[2,3,4,5] custom-call(parameter.1, parameter.2), custom_call_target="dm_softmax", operand_layout_constraints={f32[2,3,4,5], f32[2,3,4,5]}, output_to_operand_aliasing={{}: (0, {})}
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  InsertCopies(module.get());

  // The custom-call is still the root; its aliased operand must now be a copy
  // of the parameter rather than the parameter itself.
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root->operand(0), op::Copy(op::Parameter(0)));
}
3287 
TEST_F(CopyInsertionTest, CustomCallAliasingCopyInsertedAliasedReuse) {
  // The custom call specifies aliasing for an operand (add.1) that is later
  // re-used by a different instruction (add.2). A copy must be inserted so the
  // correct HloValue is passed to add.2, and not the result of the aliased
  // call.
  const char* const kModuleString = R"(
    HloModule xla_computation_f

    ENTRY xla_computation_f {
      parameter.1 = f32[2,3,4,5] parameter(0)
      parameter.2 = f32[2,3,4,5] parameter(1)
      add.1 = f32[2,3,4,5] add(parameter.1, parameter.2)
      custom-call = f32[2,3,4,5] custom-call(add.1, parameter.2), custom_call_target="dm_softmax", operand_layout_constraints={f32[2,3,4,5], f32[2,3,4,5]}, output_to_operand_aliasing={{}: (0, {})}
      ROOT add.2 = f32[2,3,4,5] add(custom-call, add.1)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
                          ParseAndReturnUnverifiedModule(kModuleString));

  InsertCopies(module.get());
  HloInstruction* custom_call = FindInstruction(module.get(), "custom-call");
  // Use a gtest ASSERT rather than CHECK: a missing instruction should fail
  // this test, not abort the entire test binary.
  ASSERT_NE(custom_call, nullptr);
  // The aliased operand must be a copy of add.1, so add.1's own value stays
  // intact for its later use in add.2.
  EXPECT_THAT(custom_call->operand(0), op::Copy(op::Add()));
}
3312 
TEST_F(CopyInsertionTest, CustomCallAliasingCopyRemoved) {
  // Here the aliased operand is `add`, an intermediate value with no user
  // other than the custom-call. Nothing else can observe the in-place
  // overwrite, so copy insertion should leave the operand untouched.
  const char* const kHlo = R"(
    HloModule xla_computation_f__1
    ENTRY xla_computation_f {
      parameter.1 = f32[2,3,4,5] parameter(0)
      parameter.2 = f32[2,3,4,5] parameter(1)
      add = f32[2,3,4,5] add(parameter.1, parameter.2)
      ROOT custom-call = f32[2,3,4,5] custom-call(add, parameter.2), custom_call_target="dm_softmax", operand_layout_constraints={f32[2,3,4,5], f32[2,3,4,5]}, output_to_operand_aliasing={{}: (0, {})}
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  InsertCopies(module.get());

  // Operand 0 of the root custom-call is still the add itself, with no copy.
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root->operand(0), op::Add());
}
3333 
TEST_F(CopyInsertionTest, ReverseInConditional) {
  // The entry passes the same parameter (%Arg_1.2) to both branches of a
  // conditional. Branch %region_0.4 returns its argument unchanged, and
  // %region_1.9 returns a reverse of it. The test pins that, after copy
  // insertion, the reverse reads from a copy rather than directly from the
  // branch parameter.
  const char* const kModuleString = R"(
HloModule jit_f.0

%region_0.4 (Arg_.5: u8[300,451,3]) -> (u8[300,451,3]) {
  %Arg_.5 = u8[300,451,3]{1,0,2:T(8,128)(4,1)} parameter(0)
  ROOT %tuple = (u8[300,451,3]{1,0,2:T(8,128)(4,1)}) tuple(u8[300,451,3]{1,0,2:T(8,128)(4,1)} %Arg_.5)
}

%region_1.9 (Arg_.10: u8[300,451,3]) -> (u8[300,451,3]) {
  %Arg_.10 = u8[300,451,3]{1,0,2:T(8,128)(4,1)} parameter(0)
  %reverse = u8[300,451,3]{1,0,2:T(8,128)(4,1)} reverse(u8[300,451,3]{1,0,2:T(8,128)(4,1)} %Arg_.10), dimensions={0}
  ROOT %tuple.1 = (u8[300,451,3]{1,0,2:T(8,128)(4,1)}) tuple(u8[300,451,3]{1,0,2:T(8,128)(4,1)} %reverse)
}

ENTRY %main.13 (Arg_0.1: pred[], Arg_1.2: u8[300,451,3]) -> u8[300,451,3] {
  %Arg_0.1 = pred[]{:T(1024)} parameter(0)
  %convert.3 = s32[]{:T(256)} convert(pred[]{:T(1024)} %Arg_0.1)
  %Arg_1.2 = u8[300,451,3]{1,0,2:T(8,128)(4,1)} parameter(1)
  %conditional.12.clone = (u8[300,451,3]{1,0,2:T(8,128)(4,1)}) conditional(s32[]{:T(256)} %convert.3, u8[300,451,3]{1,0,2:T(8,128)(4,1)} %Arg_1.2, u8[300,451,3]{1,0,2:T(8,128)(4,1)} %Arg_1.2), branch_computations={%region_0.4, %region_1.9}
  ROOT %get-tuple-element = u8[300,451,3]{1,0,2:T(8,128)(4,1)} get-tuple-element((u8[300,451,3]{1,0,2:T(8,128)(4,1)}) %conditional.12.clone), index=0
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
                          ParseAndReturnUnverifiedModule(kModuleString));

  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/-1);
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  VLOG(2) << module->ToString();

  HloInstruction* reverse = FindInstruction(module.get(), "reverse");
  // FindInstruction returns nullptr when no instruction has this name. Fail the
  // test cleanly instead of dereferencing a null pointer below.
  ASSERT_NE(reverse, nullptr);
  EXPECT_THAT(reverse->operand(0), op::Copy());
}
3368 
TEST_F(CopyInsertionTest, InputOutputAliasCopy) {
  // The entry's root is its tuple parameter, returned as-is. The module
  // declares output {0} as may-aliasing parameter 0 at tuple index {1}. This is
  // a smoke test: it only requires copy insertion to finish with an OK status,
  // and dumps the module at VLOG(2) for inspection.
  const char* const kHlo = R"(
HloModule main_tf2xla.11, input_output_alias={ {0}: (0, {1}, may-alias) }

ENTRY %main_tf2xla.11 (arg_tuple.1: (f32[], f32[])) -> (f32[], f32[]) {
ROOT %arg_tuple.1 = (f32[]{:T(256)}, f32[]{:T(256)}) parameter(0), parameter_replication={false,false}, sharding={{replicated}, {replicated}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  CopyInsertion pass(nullptr,
                     /*use_region_based_live_range_analysis=*/-1);
  const auto pass_status = pass.Run(module.get()).status();
  ASSERT_IS_OK(pass_status);
  VLOG(2) << module->ToString();
}
3386 
3387 }  // namespace
3388 }  // namespace xla
3389